diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 038f340..d93c20f 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -67,4 +67,30 @@ jobs: name: ${{ matrix.python-version }}-cov-lcov path: coverage.lcov if-no-files-found: warn + mypy: + runs-on: ubuntu-latest + strategy: + # don't cancel any remaining jobs when one fails + fail-fast: false + # how you define a matrix strategy + matrix: + # use these pythons + python-version: [ "3.11", "3.14" ] + steps: + - uses: actions/checkout@v6 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + dev-requirements.txt + setup.py + + - name: Install dev dependencies + run: python -m pip install -r dev-requirements.txt + - name: Install self (dpytest) + run: python -m pip install . + - name: Run mypy + run: mypy --install-types --non-interactive diff --git a/dev-requirements.txt b/dev-requirements.txt index 47c4cd3..9f9c5a9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,3 +9,4 @@ build flake8~=7.0.0 pynacl typing-extensions +mypy diff --git a/discord/ext/test/_types.py b/discord/ext/test/_types.py index c4a1967..6def7fc 100644 --- a/discord/ext/test/_types.py +++ b/discord/ext/test/_types.py @@ -1,29 +1,53 @@ """ Internal module for type-hinting aliases. Ensures single common definitions. """ +from enum import Enum +import typing +from typing import Callable, Literal, Self, TypeVar, ParamSpec, Protocol import discord -import typing -T = typing.TypeVar('T') +T = TypeVar('T') +P = ParamSpec('P') + +AnyChannel = (discord.abc.GuildChannel | discord.TextChannel | discord.VoiceChannel | discord.StageChannel + | discord.DMChannel | discord.Thread | discord.GroupChannel) + + +class Wrapper(Protocol[P, T]): + __wrapped__: Callable[P, T] + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + ... + + +class FnWithOld(Protocol[P, T]): + __old__: Callable[P, T] | None + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + ... + + +class Undef(Enum): + undefined = None + + +undefined: Literal[Undef.undefined] = Undef.undefined -Callback = typing.Callable[..., typing.Coroutine[None, None, None]] -AnyChannel = (discord.TextChannel | discord.CategoryChannel | discord.abc.GuildChannel - | discord.abc.PrivateChannel | discord.Thread) if typing.TYPE_CHECKING: from discord.types import ( - role, gateway, appinfo, user, guild, emoji, channel, message, sticker, # noqa: F401 - scheduled_event, member # noqa: F401 + role, gateway, appinfo, user, guild, emoji, channel, message, sticker, snowflake, # noqa: F401 + scheduled_event, member, poll # noqa: F401 ) AnyChannelJson = channel.VoiceChannel | channel.TextChannel | channel.DMChannel | channel.CategoryChannel else: class OpenNamespace: - def __getattr__(self, item: str) -> typing.Self: + def __getattr__(self, item: str) -> Self: return self - def __subclasscheck__(self, subclass: type) -> typing.Literal[True]: + def __subclasscheck__(self, subclass: type) -> Literal[True]: return True def __or__(self, other: T) -> T: diff --git a/discord/ext/test/backend.py b/discord/ext/test/backend.py index ddf2178..74ae187 100644 --- a/discord/ext/test/backend.py +++ b/discord/ext/test/backend.py @@ -11,7 +11,6 @@ import sys import logging import re -import typing import datetime import discord @@ -20,13 +19,18 @@ import urllib.parse import urllib.request -from discord.abc import Snowflake +from discord.types import member from requests import Response +from typing import NamedTuple, Any, ClassVar, NoReturn, Literal, Pattern, overload, Sequence, Iterable from . import factories as facts, state as dstate, callbacks, websocket, _types +from ._types import Undef, undefined +from discord.types.snowflake import Snowflake +from .callbacks import CallbackEvent -class BackendState(typing.NamedTuple): + +class BackendState(NamedTuple): """ The dpytest backend, with all the state it needs to hold to be able to pretend to be discord. Generally only used internally, but exposed through :py:func:`get_state` @@ -37,10 +41,9 @@ class BackendState(typing.NamedTuple): log = logging.getLogger("discord.ext.tests") _cur_config: BackendState | None = None -_undefined = object() # default value for when NoneType has special meaning -def _get_higher_locs(num: int) -> dict[str, typing.Any]: +def _get_higher_locs(num: int) -> dict[str, Any]: """ Get the local variables from higher in the call-stack. Should only be used in FakeHttp for retrieving information not passed to it by its caller. @@ -72,53 +75,61 @@ class FakeHttp(dhttp.HTTPClient): a runner callback and calls the ``dpytest`` backend to update any necessary state and trigger any necessary fake messages to the client. """ - fileno: typing.ClassVar[int] = 0 + fileno: ClassVar[int] = 0 state: dstate.FakeState - def __init__(self, loop: asyncio.AbstractEventLoop = None) -> None: + def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None: if loop is None: loop = asyncio.get_event_loop() - self.state = None + self.state = None # type: ignore[assignment] super().__init__(connector=None, loop=loop) - async def request(self, *args: typing.Any, **kwargs: typing.Any) -> typing.NoReturn: + async def request( + self, + route: discord.http.Route, + *, + files: Sequence[discord.File] | None = None, + form: Iterable[dict[str, Any]] | None = None, + **kwargs: Any, + ) -> NoReturn: """ Overloaded to raise a NotImplemented error informing the user that the requested operation isn't yet supported by ``dpytest``. To fix this, the method call that triggered this error should be overloaded below to instead trigger a callback and call the appropriate backend function. - :param args: Arguments provided to the request - :param kwargs: Keyword arguments provided to the request + :param route: The route to request + :param files: Sequence of files in the request + :param form: Form input data + :param kwargs: Any other request arguments """ - route: discord.http.Route = args[0] raise NotImplementedError( - f"Operation occured that isn't captured by the tests framework. This is dpytest's fault, please report" + f"Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report" f"an issue on github. Debug Info: {route.method} {route.url} with {kwargs}" ) async def create_channel( self, - guild_id: int, - channel_type: discord.ChannelType, + guild_id: Snowflake, + channel_type: _types.channel.ChannelType, *, reason: str | None = None, - **options: typing.Any - ) -> _types.channel.PartialChannel: + **options: Any + ) -> _types.channel.GuildChannel: locs = _get_higher_locs(1) - guild = locs.get("self", None) - name = locs.get("name", None) + guild = locs["self"] + name = locs["name"] perms = options.get("permission_overwrites", None) parent_id = options.get("parent_id", None) + channel: discord.abc.GuildChannel if channel_type == discord.ChannelType.text.value: channel = make_text_channel(name, guild, permission_overwrites=perms, parent_id=parent_id) elif channel_type == discord.ChannelType.category.value: channel = make_category_channel(name, guild, permission_overwrites=perms) elif channel_type == discord.ChannelType.voice.value: channel = make_voice_channel(name, guild, permission_overwrites=perms) - else: raise NotImplementedError( "Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report" @@ -126,9 +137,9 @@ async def create_channel( ) return facts.dict_from_object(channel) - async def delete_channel(self, channel_id: int, *, reason: str = None) -> None: + async def delete_channel(self, channel_id: Snowflake, *, reason: str | None = None) -> None: locs = _get_higher_locs(1) - channel = locs.get("self", None) + channel = locs["self"] if channel.type.value == discord.ChannelType.text.value: delete_channel(channel) if channel.type.value == discord.ChannelType.category.value: @@ -138,11 +149,11 @@ async def delete_channel(self, channel_id: int, *, reason: str = None) -> None: if channel.type.value == discord.ChannelType.voice.value: delete_channel(channel) - async def get_channel(self, channel_id: int) -> _types.channel.PartialChannel: - await callbacks.dispatch_event("get_channel", channel_id) + async def get_channel(self, channel_id: Snowflake) -> _types.channel.Channel: + await callbacks.dispatch_event(CallbackEvent.get_channel, channel_id) find = None - for guild in _cur_config.state.guilds: + for guild in get_state().guilds: for channel in guild.channels: if channel.id == channel_id: find = facts.dict_from_object(channel) @@ -150,50 +161,52 @@ async def get_channel(self, channel_id: int) -> _types.channel.PartialChannel: raise discord.errors.NotFound(FakeRequest(404, "Not Found"), "Unknown Channel") return find - async def start_private_message(self, user_id: int) -> _types.channel.DMChannel: + async def start_private_message(self, user_id: Snowflake) -> _types.channel.DMChannel: locs = _get_higher_locs(1) - user = locs.get("self", None) + user = locs["self"] - await callbacks.dispatch_event("start_private_message", user) + await callbacks.dispatch_event(CallbackEvent.start_private_message, user) return facts.make_dm_channel_dict(user) async def send_message( self, - channel_id: int, + channel_id: Snowflake, *, params: dhttp.MultipartParameters ) -> _types.message.Message: locs = _get_higher_locs(1) - channel = locs.get("channel", None) + channel = locs["channel"] + poll = locs["poll"] payload = params.payload embeds = [] attachments = [] - content = None + content = "" tts = False nonce = None # EMBEDS if payload: - content = params.payload.get("content") - tts = params.payload.get("tts") - nonce = params.payload.get("nonce") + content = payload.get("content") or "" + tts = payload.get("tts") or False + nonce = payload.get("nonce") if payload.get("embeds"): - embeds = [discord.Embed.from_dict(e) for e in params.payload.get("embeds")] + embeds = [discord.Embed.from_dict(e) for e in payload.get("embeds", [])] # ATTACHMENTS if params.files: + paths = [] for file in params.files: - path = pathlib.Path(f"./dpytest_{self.fileno}.dat") - self.fileno += 1 + path = pathlib.Path(f"./dpytest_{FakeHttp.fileno}.dat") + FakeHttp.fileno += 1 if file.fp.seekable(): file.fp.seek(0) with open(path, "wb") as nfile: nfile.write(file.fp.read()) - attachments.append((path, file.filename)) - attachments = list(map(lambda x: make_attachment(*x), attachments)) + paths.append((path, file.filename)) + attachments = list(map(lambda x: make_attachment(*x), paths)) user = self.state.user if channel.guild: @@ -208,78 +221,85 @@ async def send_message( tts=tts, embeds=embeds, attachments=attachments, + poll=poll, nonce=nonce ) - await callbacks.dispatch_event("send_message", message) + await callbacks.dispatch_event(CallbackEvent.send_message, message) return facts.dict_from_object(message) - async def send_typing(self, channel_id: int) -> None: + async def send_typing(self, channel_id: Snowflake) -> None: locs = _get_higher_locs(1) channel = locs.get("channel", None) - await callbacks.dispatch_event("send_typing", channel) + await callbacks.dispatch_event(CallbackEvent.send_typing, channel) - async def delete_message(self, channel_id: int, message_id: int, *, reason: str | None = None) -> None: + async def delete_message(self, channel_id: Snowflake, message_id: Snowflake, *, + reason: str | None = None) -> None: locs = _get_higher_locs(1) - message = locs.get("self", None) + message = locs["self"] - await callbacks.dispatch_event("delete_message", message.channel, message, reason=reason) + await callbacks.dispatch_event(CallbackEvent.delete_message, message.channel, message, reason=reason) delete_message(message) - async def edit_message(self, channel_id: int, message_id: int, + async def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: dhttp.MultipartParameters) -> _types.message.Message: # noqa: E501 locs = _get_higher_locs(1) - message = locs.get("self", None) + message = locs["self"] - await callbacks.dispatch_event("edit_message", message.channel, message, fields) + await callbacks.dispatch_event(CallbackEvent.edit_message, message.channel, message, fields) return edit_message(message, **fields) - async def add_reaction(self, channel_id: int, message_id: int, emoji: str) -> None: + async def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, + emoji: str) -> None: locs = _get_higher_locs(1) - message = locs.get("self") + message = locs["self"] # normally only the connected user can add a reaction, but for testing purposes we want to be able to force # the call from a specific user. user = locs.get("member", self.state.user) emoji = emoji # TODO: Turn this back into class? - await callbacks.dispatch_event("add_reaction", message, emoji) + await callbacks.dispatch_event(CallbackEvent.add_reaction, message, emoji) add_reaction(message, user, emoji) - async def remove_reaction(self, channel_id: int, message_id: int, emoji: str, member_id: int) -> None: + async def remove_reaction(self, channel_id: Snowflake, message_id: Snowflake, + emoji: str, + member_id: Snowflake) -> None: locs = _get_higher_locs(1) - message = locs.get("self") - member = locs.get("member") + message = locs["self"] + member = locs["member"] - await callbacks.dispatch_event("remove_reaction", message, emoji, member) + await callbacks.dispatch_event(CallbackEvent.remove_reaction, message, emoji, member) remove_reaction(message, member, emoji) - async def remove_own_reaction(self, channel_id: int, message_id: int, emoji: str) -> None: + async def remove_own_reaction(self, channel_id: Snowflake, message_id: Snowflake, + emoji: str) -> None: locs = _get_higher_locs(1) - message = locs.get("self") - member = locs.get("member") + message = locs["self"] + member = locs["member"] - await callbacks.dispatch_event("remove_own_reaction", message, emoji, member) + await callbacks.dispatch_event(CallbackEvent.remove_own_reaction, message, emoji, member) remove_reaction(message, self.state.user, emoji) - async def clear_reactions(self, channel_id: int, message_id: int) -> None: + async def clear_reactions(self, channel_id: Snowflake, message_id: Snowflake) -> None: locs = _get_higher_locs(1) - message = locs.get("self") + message = locs["self"] clear_reactions(message) - async def get_message(self, channel_id: int, message_id: int) -> _types.message.Message: + async def get_message(self, channel_id: Snowflake, + message_id: Snowflake) -> _types.message.Message: locs = _get_higher_locs(1) - channel = locs.get("self") + channel = locs["self"] - await callbacks.dispatch_event("get_message", channel, message_id) + await callbacks.dispatch_event(CallbackEvent.get_message, channel, message_id) - messages = _cur_config.messages[channel_id] + messages = get_config().messages[int(channel_id)] find = next(filter(lambda m: m["id"] == message_id, messages), None) if find is None: raise discord.errors.NotFound(FakeRequest(404, "Not Found"), "Unknown Message") @@ -287,18 +307,18 @@ async def get_message(self, channel_id: int, message_id: int) -> _types.message. async def logs_from( self, - channel_id: int, + channel_id: Snowflake, limit: int, - before: int | None = None, - after: int | None = None, - around: int | None = None + before: Snowflake | None = None, + after: Snowflake | None = None, + around: Snowflake | None = None ) -> list[_types.message.Message]: locs = _get_higher_locs(1) - channel = locs.get("self", None) + channel = locs["self"] - await callbacks.dispatch_event("logs_from", channel, limit, before=None, after=None, around=None) + await callbacks.dispatch_event(CallbackEvent.logs_from, channel, limit, before=None, after=None, around=None) - messages = _cur_config.messages[channel_id] + messages = get_config().messages[int(channel_id)] if after is not None: start = next(i for i, v in enumerate(messages) if v["id"] == after) return messages[start:start + limit] @@ -312,117 +332,134 @@ async def logs_from( start = next(i for i, v in enumerate(messages) if v["id"] == before) return messages[start - limit:start] - async def kick(self, user_id: int, guild_id: int, reason: str | None = None) -> None: + async def kick(self, user_id: Snowflake, guild_id: Snowflake, + reason: str | None = None) -> None: locs = _get_higher_locs(1) - guild = locs.get("self", None) - member = locs.get("user", None) + guild = locs["self"] + member = locs["user"] - await callbacks.dispatch_event("kick", guild, member, reason=reason) + await callbacks.dispatch_event(CallbackEvent.kick, guild, member, reason=reason) delete_member(member) - async def ban(self, user_id: int, guild_id: int, delete_message_days: int = 1, + async def ban(self, user_id: Snowflake, guild_id: Snowflake, + delete_message_days: int = 1, reason: str | None = None) -> None: locs = _get_higher_locs(1) - guild = locs.get("self", None) - member = locs.get("user", None) + guild = locs["self"] + member = locs["user"] - await callbacks.dispatch_event("ban", guild, member, delete_message_days, reason=reason) + await callbacks.dispatch_event(CallbackEvent.ban, guild, member, delete_message_days, reason=reason) delete_member(member) - async def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: str | None = None) -> None: + async def unban(self, user_id: Snowflake, guild_id: Snowflake, *, + reason: str | None = None) -> None: locs = _get_higher_locs(1) - guild = locs.get("self", None) - member = locs.get("user", None) - await callbacks.dispatch_event("unban", guild, member, reason=reason) + guild = locs["self"] + member = locs["user"] + await callbacks.dispatch_event(CallbackEvent.unban, guild, member, reason=reason) - async def change_my_nickname(self, guild_id: int, nickname: str, *, + async def change_my_nickname(self, guild_id: Snowflake, nickname: str, *, reason: str | None = None) -> _types.member.Nickname: locs = _get_higher_locs(1) - me = locs.get("self", None) + me = locs["self"] me.nick = nickname - await callbacks.dispatch_event("change_nickname", nickname, me, reason=reason) + await callbacks.dispatch_event(CallbackEvent.change_nickname, nickname, me, reason=reason) return {"nick": nickname} - async def edit_member(self, guild_id: int, user_id: int, *, reason: str | None = None, - **fields: typing.Any) -> _types.guild.Member: + async def edit_member(self, guild_id: Snowflake, user_id: Snowflake, *, + reason: str | None = None, + **fields: Any) -> _types.member.MemberWithUser: locs = _get_higher_locs(1) - member = locs.get("self", None) + member = locs["self"] - await callbacks.dispatch_event("edit_member", fields, member, reason=reason) + await callbacks.dispatch_event(CallbackEvent.edit_member, fields, member, reason=reason) member = update_member(member, nick=fields.get('nick'), roles=fields.get('roles')) return facts.dict_from_object(member) - async def get_member(self, guild_id: int, member_id: int) -> _types.guild.Member: + async def get_members( + self, guild_id: Snowflake, limit: int, after: Snowflake | None + ) -> list[member.MemberWithUser]: + locs = _get_higher_locs(1) + guild = locs["self"] + return list(map(facts.dict_from_object, guild.members)) + + async def get_member(self, guild_id: Snowflake, + member_id: Snowflake) -> _types.member.MemberWithUser: locs = _get_higher_locs(1) - guild: discord.Guild = locs.get("self", None) + guild = locs["self"] member = discord.utils.get(guild.members, id=member_id) return facts.dict_from_object(member) - async def edit_role(self, guild_id: int, role_id: int, *, reason: str | None = None, - **fields: typing.Any) -> _types.role.Role: + async def edit_role(self, guild_id: Snowflake, role_id: Snowflake, *, + reason: str | None = None, + **fields: Any) -> _types.role.Role: locs = _get_higher_locs(1) - role = locs.get("self") + role = locs["self"] guild = role.guild - await callbacks.dispatch_event("edit_role", guild, role, fields, reason=reason) + await callbacks.dispatch_event(CallbackEvent.edit_role, guild, role, fields, reason=reason) update_role(role, **fields) return facts.dict_from_object(role) - async def delete_role(self, guild_id: int, role_id: int, *, reason: str | None = None) -> None: + async def delete_role(self, guild_id: Snowflake, role_id: Snowflake, *, + reason: str | None = None) -> None: locs = _get_higher_locs(1) - role = locs.get("self") + role = locs["self"] guild = role.guild - await callbacks.dispatch_event("delete_role", guild, role, reason=reason) + await callbacks.dispatch_event(CallbackEvent.delete_role, guild, role, reason=reason) delete_role(role) - async def create_role(self, guild_id: int, *, reason: str | None = None, - **fields: typing.Any) -> _types.role.Role: + async def create_role(self, guild_id: Snowflake, *, reason: str | None = None, + **fields: Any) -> _types.role.Role: locs = _get_higher_locs(1) - guild = locs.get("self", None) + guild = locs["self"] role = make_role(guild=guild, **fields) - await callbacks.dispatch_event("create_role", guild, role, reason=reason) + await callbacks.dispatch_event(CallbackEvent.create_role, guild, role, reason=reason) return facts.dict_from_object(role) - async def move_role_position(self, guild_id: int, positions: list[_types.guild.RolePositionUpdate], *, + async def move_role_position(self, guild_id: Snowflake, + positions: list[_types.guild.RolePositionUpdate], *, reason: str | None = None) -> list[_types.role.Role]: locs = _get_higher_locs(1) - role = locs.get("self", None) + role = locs["self"] guild = role.guild - await callbacks.dispatch_event("move_role", guild, role, positions, reason=reason) + await callbacks.dispatch_event(CallbackEvent.move_role, guild, role, positions, reason=reason) for pair in positions: guild._roles[pair["id"]].position = pair["position"] return list(guild._roles.values()) - async def add_role(self, guild_id: int, user_id: int, role_id: int, *, reason: str | None = None) -> None: + async def add_role(self, guild_id: Snowflake, user_id: Snowflake, + role_id: Snowflake, *, reason: str | None = None) -> None: locs = _get_higher_locs(1) - member = locs.get("self", None) - role = locs.get("role", None) + member = locs["self"] + role = locs["role"] - await callbacks.dispatch_event("add_role", member, role, reason=reason) + await callbacks.dispatch_event(CallbackEvent.add_role, member, role, reason=reason) roles = [role] + [x for x in member.roles if x.id != member.guild.id] update_member(member, roles=roles) - async def remove_role(self, guild_id: int, user_id: int, role_id: int, *, + async def remove_role(self, guild_id: Snowflake, user_id: Snowflake, + role_id: Snowflake, *, reason: str | None = None) -> None: locs = _get_higher_locs(1) - member = locs.get("self", None) - role = locs.get("role", None) + member = locs["self"] + role = locs["role"] - await callbacks.dispatch_event("remove_role", member, role, reason=reason) + await callbacks.dispatch_event(CallbackEvent.remove_role, member, role, reason=reason) roles = [x for x in member.roles if x != role and x.id != member.guild.id] update_member(member, roles=roles) @@ -445,18 +482,22 @@ async def application_info(self) -> _types.appinfo.AppInfo: } appinfo = discord.AppInfo(self.state, data) - await callbacks.dispatch_event("app_info", appinfo) + await callbacks.dispatch_event(CallbackEvent.app_info, appinfo) return data - async def delete_channel_permissions(self, channel_id: int, target_id: int, *, + async def delete_channel_permissions(self, channel_id: Snowflake, + target_id: Snowflake, *, reason: str | None = None) -> None: locs = _get_higher_locs(1) - channel: discord.TextChannel = locs.get("self", None) - target = locs.get("target", None) + channel: discord.TextChannel = locs["self"] + target = locs["target"] user = self.state.user - perm: discord.Permissions = channel.permissions_for(channel.guild.get_member(user.id)) + member = channel.guild.get_member(user.id) + if member is None: + raise RuntimeError(f"Couldn't find user {user.id} in guild {channel.guild.id}") + perm: discord.Permissions = channel.permissions_for(member) if not (perm.administrator or perm.manage_permissions): raise discord.errors.Forbidden(FakeRequest(403, "missing manage_roles"), "manage_roles") @@ -464,24 +505,28 @@ async def delete_channel_permissions(self, channel_id: int, target_id: int, *, async def edit_channel_permissions( self, - channel_id: int, - target_id: int, - allow_value: int, - deny_value: int, - perm_type: str, + channel_id: Snowflake, + target_id: Snowflake, + allow_value: str, + deny_value: str, + perm_type: Literal[0, 1], *, reason: str | None = None ) -> None: locs = _get_higher_locs(1) - channel: discord.TextChannel = locs.get("self", None) - target = locs.get("target", None) + channel: discord.TextChannel = locs["self"] + target = locs["target"] user = self.state.user - perm: discord.Permissions = channel.permissions_for(channel.guild.get_member(user.id)) + member = channel.guild.get_member(user.id) + if member is None: + raise RuntimeError(f"Couldn't find user {user.id} in guild {channel.guild.id}") + perm: discord.Permissions = channel.permissions_for(member) if not (perm.administrator or perm.manage_permissions): raise discord.errors.Forbidden(FakeRequest(403, "missing manage_roles"), "manage_roles") - ovr = discord.PermissionOverwrite.from_pair(discord.Permissions(allow_value), discord.Permissions(deny_value)) + ovr = discord.PermissionOverwrite.from_pair(discord.Permissions(int(allow_value)), + discord.Permissions(int(deny_value))) update_text_channel(channel, target, ovr) async def get_from_cdn(self, url: str) -> bytes: @@ -490,50 +535,40 @@ async def get_from_cdn(self, url: str) -> bytes: with open(path, 'rb') as fd: return fd.read() - async def get_user(self, user_id: int) -> _types.user.User: + async def get_user(self, user_id: Snowflake) -> _types.user.User: # return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) locs = _get_higher_locs(1) - client = locs.get("self", None) + client = locs["self"] guild = client.guilds[0] member = discord.utils.get(guild.members, id=user_id) return facts.dict_from_object(member._user) - async def pin_message(self, channel_id: int, message_id: int, reason: str | None = None) -> None: + async def pin_message(self, channel_id: Snowflake, message_id: Snowflake, + reason: str | None = None) -> None: # return self.request(Route('PUT', '/channels/{channel_id}/pins/{message_id}', # channel_id=channel_id, message_id=message_id), reason=reason) pin_message(channel_id, message_id) - async def unpin_message(self, channel_id: int, message_id: int, reason: str | None = None) -> None: + async def unpin_message(self, channel_id: Snowflake, message_id: Snowflake, + reason: str | None = None) -> None: # return self.request(Route('DELETE', '/channels/{channel_id}/pins/{message_id}', # channel_id=channel_id, message_id=message_id), reason=reason) unpin_message(channel_id, message_id) - async def get_guilds(self, limit: int, before: Snowflake | None = None, after: Snowflake | None = None, - with_counts: bool = True): + async def get_guilds(self, limit: int, before: Snowflake | None = None, + after: Snowflake | None = None, + with_counts: bool = True) -> list[_types.guild.Guild]: # self.request(Route('GET', '/users/@me/guilds') - await callbacks.dispatch_event("get_guilds", limit, before=before, after=after, with_counts=with_counts) + await callbacks.dispatch_event( + CallbackEvent.get_guilds, + limit, + before=before, + after=after, + with_counts=with_counts, + ) guilds = get_state().guilds # List[] - guilds_new = [{ - 'id': guild.id, - 'name': guild.name, - 'icon': guild.icon, - 'splash': guild.splash, - 'owner_id': guild.owner_id, - 'region': guild.region, - 'afk_channel_id': guild.afk_channel.id if guild.afk_channel else None, - 'afk_timeout': guild.afk_timeout, - 'verification_level': guild.verification_level, - 'default_message_notifications': guild.default_notifications.value, - 'explicit_content_filter': guild.explicit_content_filter, - 'roles': list(map(facts.dict_from_object, guild.roles)), - 'emojis': list(map(facts.dict_from_object, guild.emojis)), - 'features': guild.features, - 'mfa_level': guild.mfa_level, - 'application_id': None, - 'system_channel_id': guild.system_channel.id if guild.system_channel else None, - 'owner': guild.owner_id == get_state().user.id - } for guild in guilds] + guilds_new = [facts.dict_from_object(guild) for guild in guilds] if not limit: limit = 100 @@ -551,8 +586,10 @@ async def get_guild(self, guild_id: Snowflake, *, with_counts: bool = True) -> _ # return self.request(Route('GET', '/guilds/{guild_id}', guild_id=guild_id)) # TODO: Respect with_counts locs = _get_higher_locs(1) - client: discord.Client = locs.get("self", None) + client: discord.Client = locs["self"] guild = discord.utils.get(client.guilds, id=guild_id) + if guild is None: + raise RuntimeError(f"Couldn't find guild with ID {guild_id} in test client") return facts.dict_from_object(guild) @@ -567,11 +604,17 @@ def get_state() -> dstate.FakeState: return _cur_config.state +def get_config() -> BackendState: + if _cur_config is None: + raise ValueError("Dpytest backend not configured") + return _cur_config + + def make_guild( name: str, - members: list[discord.Member] = None, - channels: list[_types.AnyChannel] = None, - roles: list[discord.Role] = None, + members: list[discord.Member] | None = None, + channels: list[_types.AnyChannel] | None = None, + roles: list[_types.role.Role] | None = None, owner: bool = False, id_num: int = -1, ) -> discord.Guild: @@ -600,16 +643,16 @@ def make_guild( owner_id = state.user.id if owner else 0 - data: _types.gateway.GuildCreateEvent = facts.make_guild_dict( + data: _types.gateway.Guild = facts.make_guild_dict( name, owner_id, roles, id_num=id_num, member_count=member_count, members=members, channels=channels ) state.parse_guild_create(data) - return state._get_guild(id_num) + return state._get_guild(id_num) # type: ignore[return-value] -def update_guild(guild: discord.Guild, roles: list[discord.Role] = None) -> discord.Guild: +def update_guild(guild: discord.Guild, roles: list[discord.Role] | None = None) -> discord.Guild: """ Update an existing guild with new information, triggers a guild update but not any individual item create/edit calls @@ -661,7 +704,7 @@ def make_role( # r_dict["position"] = max(map(lambda x: x.position, guild._roles.values())) + 1 r_dict["position"] = 1 - data = { + data: _types.gateway._GuildRoleEvent = { "guild_id": guild.id, "role": r_dict } @@ -669,7 +712,7 @@ def make_role( state = get_state() state.parse_guild_role_create(data) - return guild.get_role(r_dict["id"]) + return guild.get_role(int(r_dict["id"])) # type: ignore[return-value] def update_role( @@ -696,7 +739,7 @@ def update_role( :param name: New name for the role :return: Role that was updated """ - data = { + data: _types.gateway._GuildRoleEvent = { "guild_id": role.guild.id, "role": facts.dict_from_object(role), } @@ -707,8 +750,7 @@ def update_role( if colors is not None: data["role"]["colors"] = colors if permissions is not None: - data["role"]["permissions"] = int(permissions) - data["role"]["permissions_new"] = int(permissions) + data["role"]["permissions"] = str(permissions) if hoist is not None: data["role"]["hoist"] = hoist @@ -750,7 +792,7 @@ def make_text_channel( state = get_state() state.parse_channel_create(c_dict) - return guild.get_channel(c_dict["id"]) + return guild.get_channel(int(c_dict["id"])) # type: ignore[return-value] def make_category_channel( @@ -767,7 +809,7 @@ def make_category_channel( state = get_state() state.parse_channel_create(c_dict) - return guild.get_channel(c_dict["id"]) + return guild.get_channel(int(c_dict["id"])) # type: ignore[return-value] def make_voice_channel( @@ -789,10 +831,10 @@ def make_voice_channel( state = get_state() state.parse_channel_create(c_dict) - return guild.get_channel(c_dict["id"]) + return guild.get_channel(int(c_dict["id"])) # type: ignore[return-value] -def delete_channel(channel: _types.AnyChannel) -> None: +def delete_channel(channel: discord.abc.GuildChannel) -> None: c_dict = facts.make_text_channel_dict(channel.name, id_num=channel.id, guild_id=channel.guild.id) state = get_state() @@ -801,17 +843,17 @@ def delete_channel(channel: _types.AnyChannel) -> None: def update_text_channel( channel: discord.TextChannel, - target: discord.User | discord.Role, - override: discord.PermissionOverwrite | None = _undefined + target: discord.Member | discord.Role | discord.Object, + override: discord.PermissionOverwrite | None | Undef = undefined ) -> None: c_dict = facts.dict_from_object(channel) - if override is not _undefined: + if override is not undefined: ovr = c_dict.get("permission_overwrites", []) existing = [o for o in ovr if o.get("id") == target.id] if existing: ovr.remove(existing[0]) if override: - ovr = ovr + [facts.dict_from_overwrite(target, override)] + ovr = ovr + [facts.dict_from_object(override, target=target)] c_dict["permission_overwrites"] = ovr state = get_state() @@ -831,24 +873,27 @@ def make_user(username: str, discrim: str | int, avatar: str | None = None, return user -def make_member(user: discord.user.BaseUser | discord.abc.User, guild: discord.Guild, +def make_member(user: discord.user.BaseUser, guild: discord.Guild, nick: str | None = None, roles: list[discord.Role] | None = None) -> discord.Member: if roles is None: roles = [] - roles = list(map(lambda x: x.id, roles)) + role_ids: list[Snowflake] = list(map(lambda x: x.id, roles)) - data = facts.make_member_dict(guild, user, roles, nick=nick) + data: _types.gateway.GuildMemberAddEvent = { + 'guild_id': guild.id, + **facts.make_member_dict(user, role_ids, nick=nick), + } state = get_state() state.parse_guild_member_add(data) - return guild.get_member(user.id) + return guild.get_member(user.id) # type: ignore[return-value] def update_member(member: discord.Member, nick: str | None = None, roles: list[discord.Role] | None = None) -> discord.Member: - data = facts.dict_from_object(member) + data = facts.dict_from_object(member, guild=True) if nick is not None: data["nick"] = nick if roles is not None: @@ -861,81 +906,85 @@ def update_member(member: discord.Member, nick: str | None = None, def delete_member(member: discord.Member) -> None: - out = facts.dict_from_object(member) + out = facts.dict_from_object(member, guild=True) state = get_state() state.parse_guild_member_remove(out) def make_message( content: str, - author: discord.user.BaseUser | discord.abc.User, + author: discord.user.BaseUser | discord.Member, channel: _types.AnyChannel, tts: bool = False, embeds: list[discord.Embed] | None = None, attachments: list[discord.Attachment] | None = None, + poll: discord.Poll | None = None, nonce: int | None = None, id_num: int = -1, ) -> discord.Message: guild = channel.guild if hasattr(channel, "guild") else None guild_id = guild.id if guild else None - mentions = find_user_mentions(content, guild) + mentions = find_member_mentions(content, guild) role_mentions = find_role_mentions(content, guild) channel_mentions = find_channel_mentions(content, guild) - kwargs = {} + kwargs: dict[str, Any] = {} if nonce is not None: kwargs["nonce"] = nonce data = facts.make_message_dict( channel, author, id_num, content=content, mentions=mentions, tts=tts, embeds=embeds, attachments=attachments, - mention_roles=role_mentions, mention_channels=channel_mentions, guild_id=guild_id, **kwargs + poll=facts.dict_from_object(poll) if poll else None, mention_roles=role_mentions, + mention_channels=channel_mentions, guild_id=guild_id, **kwargs ) state = get_state() state.parse_message_create(data) - if channel.id not in _cur_config.messages: - _cur_config.messages[channel.id] = [] - _cur_config.messages[channel.id].append(data) + messages = get_config().messages + if channel.id not in messages: + messages[channel.id] = [] + messages[channel.id].append(data) - return state._get_message(data["id"]) + return state._get_message(int(data["id"])) # type: ignore[return-value] def edit_message( message: discord.Message, **fields: dhttp.MultipartParameters ) -> _types.message.Message: data = facts.dict_from_object(message) - payload = fields.get("params").payload + payload = fields["params"].payload # TODO : do something for files and stuff. # if params.files: # return self.request(r, files=params.files, form=params.multipart) # else: # return self.request(r, json=params.payload) - data.update(payload) + data.update(payload) # type: ignore[typeddict-item] + config = get_config() i = 0 - while i < len(_cur_config.messages[message.channel.id]): - if _cur_config.messages[message.channel.id][i].get("id") == data.get("id"): - _cur_config.messages[message.channel.id][i] = data + while i < len(config.messages[message.channel.id]): + if config.messages[message.channel.id][i].get("id") == data.get("id"): + config.messages[message.channel.id][i] = data i += 1 return data -MEMBER_MENTION: typing.Pattern = re.compile(r"<@!?[0-9]{17,21}>", re.MULTILINE) -ROLE_MENTION: typing.Pattern = re.compile(r"<@&([0-9]{17,21})>", re.MULTILINE) -CHANNEL_MENTION: typing.Pattern = re.compile(r"<#[0-9]{17,21}>", re.MULTILINE) +MEMBER_MENTION: Pattern[str] = re.compile(r"<@!?([0-9]{17,21})>", re.MULTILINE) +ROLE_MENTION: Pattern[str] = re.compile(r"<@&([0-9]{17,21})>", re.MULTILINE) +CHANNEL_MENTION: Pattern[str] = re.compile(r"<#([0-9]{17,21})>", re.MULTILINE) -def find_user_mentions(content: str | None, guild: discord.Guild | None) -> list[discord.Member]: +def find_member_mentions(content: str | None, guild: discord.Guild | None) -> list[discord.Member | discord.User]: if guild is None or content is None: return [] # TODO: Check for dm user mentions matches = re.findall(MEMBER_MENTION, content) - return [discord.utils.get(guild.members, id=int(re.search(r'\d+', match)[0])) for match in matches] # noqa: E501 + return [discord.utils.get(guild.members, id=int(match)) for match in matches] # type: ignore[misc] -def find_role_mentions(content: str | None, guild: discord.Guild | None) -> list[int]: +def find_role_mentions(content: str | None, guild: discord.Guild | None) -> list[Snowflake]: if guild is None or content is None: return [] matches = re.findall(ROLE_MENTION, content) @@ -948,11 +997,11 @@ def find_channel_mentions(content: str | None, if guild is None or content is None: return [] matches = re.findall(CHANNEL_MENTION, content) - return [discord.utils.get(guild.channels, mention=match) for match in matches] + return [discord.utils.get(guild.channels, id=int(match)) for match in matches] # type: ignore[misc] def delete_message(message: discord.Message) -> None: - data = { + data: _types.gateway.MessageDeleteEvent = { "id": message.id, "channel_id": message.channel.id } @@ -962,9 +1011,9 @@ def delete_message(message: discord.Message) -> None: state = get_state() state.parse_message_delete(data) - messages = _cur_config.messages[message.channel.id] + messages = get_config().messages[message.channel.id] index = next(i for i, v in enumerate(messages) if v["id"] == message.id) - del _cur_config.messages[message.channel.id][index] + del get_config().messages[message.channel.id][index] def make_attachment(filename: pathlib.Path, name: str | None = None, id_num: int = -1) -> discord.Attachment: @@ -984,12 +1033,12 @@ def add_reaction(message: discord.Message, user: discord.user.BaseUser | discord emoji: str) -> None: if ":" in emoji: temp = emoji.split(":") - emoji: _types.message.PartialEmoji = { + partial: _types.message.PartialEmoji = { "id": temp[0], "name": temp[1] } else: - emoji = { + partial = { "id": None, "name": emoji } @@ -998,7 +1047,7 @@ def add_reaction(message: discord.Message, user: discord.user.BaseUser | discord "message_id": message.id, "channel_id": message.channel.id, "user_id": user.id, - "emoji": emoji, + "emoji": partial, "burst": False, "type": 0, } @@ -1011,7 +1060,7 @@ def add_reaction(message: discord.Message, user: discord.user.BaseUser | discord state = get_state() state.parse_message_reaction_add(data) - messages = _cur_config.messages[message.channel.id] + messages = get_config().messages[message.channel.id] message_data = next(filter(lambda x: x["id"] == message.id, messages), None) if message_data is not None: if "reactions" not in message_data: @@ -1019,14 +1068,14 @@ def add_reaction(message: discord.Message, user: discord.user.BaseUser | discord react: _types.message.Reaction | None = None for react in message_data["reactions"]: - if react["emoji"]["id"] == emoji["id"] and react["emoji"]["name"] == emoji["name"]: + if react["emoji"]["id"] == partial["id"] and react["emoji"]["name"] == partial["name"]: break if react is None: - react: _types.message.Reaction = { + react = { "count": 0, "me": False, - "emoji": emoji, + "emoji": partial, "me_burst": False, "count_details": { "burst": 0, @@ -1042,15 +1091,15 @@ def add_reaction(message: discord.Message, user: discord.user.BaseUser | discord react["me"] = True -def remove_reaction(message: discord.Message, user: discord.user.BaseUser, emoji: str) -> None: +def remove_reaction(message: discord.Message, user: discord.abc.Snowflake, emoji: str) -> None: if ":" in emoji: temp = emoji.split(":") - emoji: _types.message.PartialEmoji = { + partial: _types.message.PartialEmoji = { "id": temp[0], "name": temp[1] } else: - emoji = { + partial = { "id": None, "name": emoji } @@ -1059,7 +1108,7 @@ def remove_reaction(message: discord.Message, user: discord.user.BaseUser, emoji "message_id": message.id, "channel_id": message.channel.id, "user_id": user.id, - "emoji": emoji, + "emoji": partial, "burst": False, "type": 0, } @@ -1069,7 +1118,7 @@ def remove_reaction(message: discord.Message, user: discord.user.BaseUser, emoji state = get_state() state.parse_message_reaction_remove(data) - messages = _cur_config.messages[message.channel.id] + messages = get_config().messages[message.channel.id] message_data = next(filter(lambda x: x["id"] == message.id, messages), None) if message_data is not None: if "reactions" not in message_data: @@ -1077,7 +1126,7 @@ def remove_reaction(message: discord.Message, user: discord.user.BaseUser, emoji react: _types.message.Reaction | None = None for react in message_data["reactions"]: - if react["emoji"]["id"] == emoji["id"] and react["emoji"]["name"] == emoji["name"]: + if react["emoji"]["id"] == partial["id"] and react["emoji"]["name"] == partial["name"]: break if react is None: return @@ -1091,8 +1140,8 @@ def remove_reaction(message: discord.Message, user: discord.user.BaseUser, emoji message_data["reactions"].remove(react) -def clear_reactions(message: discord.Message): - data = { +def clear_reactions(message: discord.Message) -> None: + data: _types.gateway.MessageReactionRemoveAllEvent = { "message_id": message.id, "channel_id": message.channel.id } @@ -1102,14 +1151,14 @@ def clear_reactions(message: discord.Message): state = get_state() state.parse_message_reaction_remove_all(data) - messages = _cur_config.messages[message.channel.id] + messages = get_config().messages[message.channel.id] message_data = next(filter(lambda x: x["id"] == message.id, messages), None) if message_data is not None: message_data["reactions"] = [] -def pin_message(channel_id: int, message_id: int): - data = { +def pin_message(channel_id: Snowflake, message_id: Snowflake) -> None: + data: _types.gateway.ChannelPinsUpdateEvent = { "channel_id": channel_id, "last_pin_timestamp": datetime.datetime.now().isoformat(), } @@ -1117,8 +1166,8 @@ def pin_message(channel_id: int, message_id: int): state.parse_channel_pins_update(data) -def unpin_message(channel_id: int, message_id: int): - data = { +def unpin_message(channel_id: Snowflake, message_id: Snowflake) -> None: + data: _types.gateway.ChannelPinsUpdateEvent = { "channel_id": channel_id, "last_pin_timestamp": None, } @@ -1126,11 +1175,11 @@ def unpin_message(channel_id: int, message_id: int): state.parse_channel_pins_update(data) -@typing.overload +@overload def configure(client: discord.Client) -> None: ... -@typing.overload +@overload def configure(client: discord.Client | None, *, use_dummy: bool = ...) -> None: ... @@ -1141,11 +1190,11 @@ def configure(client: discord.Client | None, *, use_dummy: bool = False) -> None :param client: Client to use, or None :param use_dummy: Whether to use a dummy if client param is None, or error """ - global _cur_config, _messages + global _cur_config if client is None and use_dummy: log.info("None passed to backend configuration, dummy client will be used") - client = discord.Client() + client = discord.Client(intents=discord.Intents.all()) if not isinstance(client, discord.Client): raise TypeError("Runner client must be an instance of discord.Client") diff --git a/discord/ext/test/callbacks.py b/discord/ext/test/callbacks.py index 8a43467..e7660f2 100644 --- a/discord/ext/test/callbacks.py +++ b/discord/ext/test/callbacks.py @@ -5,15 +5,52 @@ """ import logging -import typing +import discord +from enum import Enum +from typing import Callable, overload, Literal, Any, Awaitable + from . import _types +GetChannelCallback = Callable[[_types.snowflake.Snowflake], Awaitable[None]] +SendMessageCallback = Callable[[discord.Message], Awaitable[None]] +EditMemberCallback = Callable[[dict[str, Any], discord.Member, str | None], Awaitable[None]] +Callback = GetChannelCallback | SendMessageCallback | EditMemberCallback | Callable[..., Awaitable[None]] + log = logging.getLogger("discord.ext.tests") -_callbacks = {} + +class CallbackEvent(Enum): + get_channel = "get_channel" + presence = "presence" + start_private_message = "start_private_message" + send_message = "send_message" + send_typing = "send_typing" + delete_message = "delete_message" + edit_message = "edit_message" + add_reaction = "add_reaction" + remove_reaction = "remove_reaction" + remove_own_reaction = "remove_own_reaction" + get_message = "get_message" + logs_from = "logs_from" + kick = "kick" + ban = "ban" + unban = "unban" + change_nickname = "change_nickname" + edit_member = "edit_member" + create_role = "create_role" + edit_role = "edit_role" + delete_role = "delete_role" + move_role = "move_role" + add_role = "add_role" + remove_role = "remove_role" + app_info = "app_info" + get_guilds = "get_guilds" -async def dispatch_event(event: str, *args: typing.Any, **kwargs: typing.Any) -> None: +_callbacks: dict[CallbackEvent, Callback] = {} + + +async def dispatch_event(event: CallbackEvent, *args: Any, **kwargs: Any) -> None: """ Dispatch an event to a set handler, if one exists. Will ignore handler errors, just print a log @@ -30,7 +67,19 @@ async def dispatch_event(event: str, *args: typing.Any, **kwargs: typing.Any) -> log.error(f"Error in handler for event {event}: {e}") -def set_callback(cb: _types.Callback, event: str) -> None: +@overload +def set_callback(cb: GetChannelCallback, event: Literal[CallbackEvent.get_channel]) -> None: ... + + +@overload +def set_callback(cb: SendMessageCallback, event: Literal[CallbackEvent.send_message]) -> None: ... + + +@overload +def set_callback(cb: EditMemberCallback, event: Literal[CallbackEvent.edit_member]) -> None: ... + + +def set_callback(cb: Callback, event: CallbackEvent) -> None: """ Set the callback to use for a specific event @@ -40,7 +89,7 @@ def set_callback(cb: _types.Callback, event: str) -> None: _callbacks[event] = cb -def get_callback(event: str) -> _types.Callback: +def get_callback(event: CallbackEvent) -> Callback: """ Get the current callback for an event, or raise an exception if one isn't set @@ -52,7 +101,7 @@ def get_callback(event: str) -> _types.Callback: return _callbacks[event] -def remove_callback(event: str) -> _types.Callback | None: +def remove_callback(event: CallbackEvent) -> Callback | None: """ Remove the callback set for an event, returning it, or None if one isn't set diff --git a/discord/ext/test/factories.py b/discord/ext/test/factories.py index 6a10e86..ffd73b5 100644 --- a/discord/ext/test/factories.py +++ b/discord/ext/test/factories.py @@ -3,11 +3,16 @@ for the rest of the library, which often needs to convert between objects and JSON at various stages. """ import functools -import typing import datetime as dt +from typing import Any, Literal, overload, Iterable, Protocol, NoReturn, Callable, ParamSpec, TypeVar + import discord from . import _types + +P = ParamSpec('P') +T = TypeVar('T') + generated_ids: int = 0 @@ -28,88 +33,119 @@ def make_id() -> int: return int(discord_epoch + worker + process + generated, 2) -@typing.overload +@overload def _fill_optional( data: _types.user.User, - obj: discord.User | dict[str, typing.Any], - items: typing.Iterable[str] + obj: discord.user.BaseUser | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( - data: _types.user.User, - obj: discord.ClientUser | dict[str, typing.Any], - items: typing.Iterable[str] + data: _types.member.Member | _types.member.MemberWithUser, + obj: discord.Member | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( - data: _types.guild.Member, - obj: discord.Member | dict[str, typing.Any], - items: typing.Iterable[str] + data: _types.gateway.GuildMemberUpdateEvent, + obj: discord.Member | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( data: _types.guild.Guild, - obj: discord.Guild | dict[str, typing.Any], - items: typing.Iterable[str] + obj: discord.Guild | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( data: _types.channel.PartialChannel, - obj: _types.AnyChannel | dict[str, typing.Any], - items: typing.Iterable[str] + obj: _types.AnyChannel | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( data: _types.message.Message, - obj: discord.Message | dict[str, typing.Any], - items: typing.Iterable[str] + obj: discord.Message | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( data: _types.emoji.Emoji, - obj: discord.Emoji | dict[str, typing.Any], - items: typing.Iterable[str] + obj: discord.Emoji | dict[str, object], + items: Iterable[str] ) -> None: ... -@typing.overload +@overload def _fill_optional( data: _types.sticker.GuildSticker, - obj: discord.GuildSticker | dict[str, typing.Any], - items: typing.Iterable[str] + obj: discord.GuildSticker | dict[str, object], + items: Iterable[str] ) -> None: ... -def _fill_optional(data: dict[str, typing.Any], obj: typing.Any, items: typing.Iterable[str]) -> None: - if isinstance(obj, dict): - obj: dict[str, typing.Any] - for item in items: - result = obj.pop(item, None) - if result is None: - continue - data[item] = result - if len(obj) > 0: - print("Warning: Invalid attributes passed") - else: - for item in items: - if hasattr(obj, item): - data[item] = getattr(obj, item) +@overload +def _fill_optional( + data: _types.poll.PollMedia, + obj: discord.PollMedia | dict[str, object], + items: Iterable[str] +) -> None: ... -def make_user_dict(username: str, discrim: str | int, avatar: str, id_num: int = -1, flags: int = 0, - **kwargs: typing.Any) -> _types.user.User: +def _fill_optional( # type: ignore[misc] + data: dict[str, object], + obj: object | dict[str, object], + items: Iterable[str] +) -> None: + if isinstance(obj, dict): + _fill_optional_dict(data, obj, items) + else: + _fill_optional_value(data, obj, items) + + +def _fill_optional_dict( + data: dict[str, object], + obj: dict[str, object], + items: Iterable[str], +) -> None: + for item in items: + result = obj.pop(item, None) + if result is None: + continue + data[item] = result + if len(obj) > 0: + print("Warning: Invalid attributes passed") + + +def _fill_optional_value( + data: dict[str, object], + obj: object, + items: Iterable[str], +) -> None: + for item in items: + if (val := getattr(obj, item, None)) is None and (val := getattr(obj, f"_{item}", None)) is None: + continue + if isinstance(val, discord.Poll): + data[item] = dict_from_object(val) + else: + data[item] = val + + +def make_user_dict(username: str, discrim: str | int, avatar: str | None, id_num: int = -1, flags: int = 0, + **kwargs: Any) -> _types.user.User: if isinstance(discrim, int): assert 0 < discrim < 10000 discrim = f"{discrim:04}" @@ -125,23 +161,21 @@ def make_user_dict(username: str, discrim: str | int, avatar: str, id_num: int = 'avatar': avatar, 'flags': flags, } - items: typing.Final = ("bot", "mfa_enabled", "locale", "verified", "email", "premium_type") + items = ("bot", "system", "mfa_enabled", "locale", "verified", "email", "premium_type", "public_flags") _fill_optional(out, kwargs, items) return out def make_member_dict( - guild: discord.Guild, - user: discord.User, - roles: list[int], + user: discord.user.BaseUser, + roles: list[_types.gateway.Snowflake], joined: str | None = None, deaf: bool = False, mute: bool = False, flags: int = 0, - **kwargs: typing.Any, -) -> _types.guild.Member: - out: _types.guild.Member = { - 'guild_id': guild.id, + **kwargs: Any, +) -> _types.member.MemberWithUser: + out: _types.member.MemberWithUser = { 'user': dict_from_object(user), 'roles': roles, 'joined_at': joined, @@ -149,101 +183,116 @@ def make_member_dict( 'mute': mute, 'flags': flags, } - items = ("nick",) + items = ("avatar", "nick", "premium_since", "pending", "permissions", "communication_disabled_until", + "avatar_decoration_data") _fill_optional(out, kwargs, items) return out def user_with_member(user: discord.User | discord.Member) -> _types.member.UserWithMember: if isinstance(user, discord.Member): - member = dict_from_object(user) + member: _types.member.MemberWithUser | None = dict_from_object(user) user = user._user else: member = None - out = dict_from_object(user) + out: _types.member.UserWithMember = dict_from_object(user) if member: out['member'] = member return out -@typing.overload -def dict_from_object(obj: discord.User) -> _types.user.User: ... - - -@typing.overload -def dict_from_object(obj: discord.Member) -> _types.member.MemberWithUser: ... - - -@typing.overload -def dict_from_object(obj: discord.Role) -> _types.role.Role: ... - - -@typing.overload -def dict_from_object(obj: discord.TextChannel) -> _types.channel.TextChannel: ... - - -@typing.overload -def dict_from_object(obj: discord.DMChannel) -> _types.channel.DMChannel: ... - - -@typing.overload -def dict_from_object(obj: discord.CategoryChannel) -> _types.channel.CategoryChannel: ... - - -@typing.overload -def dict_from_object(obj: discord.VoiceChannel) -> _types.channel.VoiceChannel: ... - - -@typing.overload -def dict_from_object(obj: discord.Message) -> _types.message.Message: ... - - -@typing.overload -def dict_from_object(obj: discord.Attachment) -> _types.message.Attachment: ... - - -@typing.overload -def dict_from_object(obj: discord.Emoji) -> _types.emoji.Emoji: ... - - -@typing.overload -def dict_from_object(obj: discord.Sticker) -> _types.sticker.Sticker: ... - - -@typing.overload -def dict_from_object(obj: discord.StageInstance) -> _types.channel.StageInstance: ... - - -@typing.overload -def dict_from_object(obj: discord.ScheduledEvent) -> _types.guild.GuildScheduledEvent: ... - - -@typing.overload -def dict_from_object(obj: discord.Guild) -> _types.guild.Guild: ... - - -@functools.singledispatch -def dict_from_object(obj: typing.Any) -> typing.Never: +class DictFromObject(Protocol): + @overload + def __call__(self, obj: discord.user.BaseUser) -> _types.member.UserWithMember: ... + @overload + def __call__(self, obj: discord.Member, *, guild: Literal[False] = ...) -> _types.member.MemberWithUser: ... + @overload + def __call__(self, obj: discord.Member, *, guild: Literal[True] = ...) -> _types.gateway.GuildMemberUpdateEvent: ... + + @overload + def __call__( + self, + obj: discord.Member, + *, + guild: bool = ..., + ) -> _types.member.MemberWithUser | _types.gateway.GuildMemberUpdateEvent: ... + + @overload + def __call__(self, obj: discord.Role) -> _types.role.Role: ... + + @overload + def __call__(self, obj: discord.TextChannel) -> _types.channel.TextChannel: ... + @overload + def __call__(self, obj: discord.DMChannel) -> _types.channel.DMChannel: ... + @overload + def __call__(self, obj: discord.CategoryChannel) -> _types.channel.CategoryChannel: ... + @overload + def __call__(self, obj: discord.VoiceChannel) -> _types.channel.VoiceChannel: ... + @overload + def __call__(self, obj: _types.AnyChannel) -> _types.channel.Channel: ... + + @overload + def __call__(self, obj: discord.Message) -> _types.message.Message: ... + @overload + def __call__(self, obj: discord.Attachment) -> _types.message.Attachment: ... + @overload + def __call__(self, obj: discord.Emoji) -> _types.emoji.Emoji: ... + + @overload + def __call__(self, obj: discord.GuildSticker) -> _types.sticker.GuildSticker: ... + @overload + def __call__(self, obj: discord.Sticker) -> _types.sticker.Sticker: ... + + @overload + def __call__(self, obj: discord.StageInstance) -> _types.guild.StageInstance: ... + @overload + def __call__(self, obj: discord.ScheduledEvent) -> _types.guild.GuildScheduledEvent: ... + @overload + def __call__(self, obj: discord.Guild) -> _types.guild.Guild: ... + + @overload + def __call__( + self, + obj: discord.PermissionOverwrite, + *, + target: discord.Member | discord.Role | discord.Object, + ) -> _types.channel.PermissionOverwrite: ... + + @overload + def __call__(self, obj: discord.Poll) -> _types.poll.Poll: ... + + @overload + def __call__(self, obj: discord.PollAnswer, *, count: Literal[True] = ...) -> _types.poll.PollAnswerCount: ... + @overload + def __call__(self, obj: discord.PollAnswer, *, count: Literal[False] = ...) -> _types.poll.PollAnswerWithID: ... + + @overload + def __call__( + self, + obj: discord.PollAnswer, + *, + count: bool = False, + ) -> _types.poll.PollAnswerWithID | _types.poll.PollAnswerCount: ... + + @overload + def __call__(self, obj: discord.PollMedia) -> _types.poll.PollMedia: ... + + def __call__(self, obj: object, **_kwargs: Any) -> NoReturn: ... + + def register(self, ty: type) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + +dict_from_object: DictFromObject + + +@functools.singledispatch # type: ignore[no-redef] +def dict_from_object(obj: object, **_kwargs: Any) -> Any: raise TypeError(f"Unrecognized discord model type {type(obj)}") -@dict_from_object.register(discord.ClientUser) -def _from_client_user(user: discord.ClientUser) -> _types.user.User: - out: _types.user.User = { - 'id': user.id, - 'global_name': user.global_name, - 'username': user.name, - 'discriminator': user.discriminator, - 'avatar': user.avatar.url if user.avatar else None, - } - items = ("bot", "mfa_enabled", "locale", "verified", "email", "premium_type") - _fill_optional(out, user, items) - return out - - -@dict_from_object.register(discord.User) -def _from_user(user: discord.User) -> _types.user.User: - out: _types.user.User = { +@dict_from_object.register(discord.user.BaseUser) +def _from_base_user(user: discord.user.BaseUser) -> _types.member.UserWithMember: + out: _types.member.UserWithMember = { 'id': user.id, 'global_name': user.global_name, 'username': user.name, @@ -256,21 +305,42 @@ def _from_user(user: discord.User) -> _types.user.User: @dict_from_object.register(discord.Member) -def _from_member(member: discord.Member) -> _types.member.MemberWithUser: +def _from_member( + member: discord.Member, + *, + guild: bool = False, +) -> _types.member.MemberWithUser | _types.gateway.GuildMemberUpdateEvent: # discord code adds default role to every member later on in Member constructor roles_no_default = list(filter(lambda r: not r == member.guild.default_role, member.roles)) - out: _types.guild.Member = { - 'guild_id': member.guild.id, - 'user': dict_from_object(member._user), - 'roles': list(map(lambda role: int(role.id), roles_no_default)), - 'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None, - 'flags': member.flags.value, - 'deaf': member.voice.deaf if member.voice else False, - 'mute': member.voice.mute if member.voice else False, - } - items = ("nick",) - _fill_optional(out, member, items) - return out + items: tuple[str, ...] + if guild: + out: _types.gateway.GuildMemberUpdateEvent = { + 'guild_id': member.guild.id, + 'user': dict_from_object(member._user), + 'avatar': member.avatar.url if member.avatar else "", + 'roles': list(map(lambda role: int(role.id), roles_no_default)), + 'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None, + 'flags': member.flags.value, + 'deaf': member.voice.deaf if member.voice else False, + 'mute': member.voice.mute if member.voice else False, + } + items = ("nick", "premium_since", "pending", "permissions", "communication_disabled_until", + "avatar_decoration_data") + _fill_optional(out, member, items) + return out + else: + mem_user: _types.member.MemberWithUser = { + 'user': dict_from_object(member._user), + 'roles': list(map(lambda role: int(role.id), roles_no_default)), + 'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None, + 'flags': member.flags.value, + 'deaf': member.voice.deaf if member.voice else False, + 'mute': member.voice.mute if member.voice else False, + } + items = ("avatar", "nick", "premium_since", "pending", "permissions", "communication_disabled_until", + "avatar_decoration_data") + _fill_optional(mem_user, member, items) + return mem_user @dict_from_object.register(discord.Role) @@ -301,8 +371,8 @@ def _from_text_channel(channel: discord.TextChannel) -> _types.channel.TextChann 'position': channel.position, 'id': channel.id, 'guild_id': channel.guild.id, - 'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()], - 'type': channel.type, + 'permission_overwrites': [dict_from_object(v, target=k) for k, v in channel.overwrites.items()], + 'type': channel.type.value, 'parent_id': channel.category_id, 'nsfw': channel.nsfw, } @@ -310,7 +380,14 @@ def _from_text_channel(channel: discord.TextChannel) -> _types.channel.TextChann @dict_from_object.register(discord.DMChannel) def _from_dm_channel(channel: discord.DMChannel) -> _types.channel.DMChannel: - pass + return { + 'id': channel.id, + 'name': "", + 'type': channel.type.value, + # TODO: Map this correctly? + 'last_message_id': 0, + 'recipients': list(map(dict_from_object, channel.recipients)) + } @dict_from_object.register(discord.CategoryChannel) @@ -320,8 +397,8 @@ def _from_category_channel(channel: discord.CategoryChannel) -> _types.channel.C 'position': channel.position, 'id': channel.id, 'guild_id': channel.guild.id, - 'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()], - 'type': channel.type, + 'permission_overwrites': [dict_from_object(v, target=k) for k, v in channel.overwrites.items()], + 'type': channel.type.value, 'nsfw': channel.nsfw, 'parent_id': channel.category_id, } @@ -334,8 +411,8 @@ def _from_voice_channel(channel: discord.VoiceChannel) -> _types.channel.VoiceCh 'position': channel.position, 'id': channel.id, 'guild_id': channel.guild.id, - 'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()], - 'type': channel.type, + 'permission_overwrites': [dict_from_object(v, target=k) for k, v in channel.overwrites.items()], + 'type': channel.type.value, 'nsfw': channel.nsfw, 'parent_id': channel.category_id, 'bitrate': channel.bitrate, @@ -346,7 +423,7 @@ def _from_voice_channel(channel: discord.VoiceChannel) -> _types.channel.VoiceCh @dict_from_object.register(discord.Message) def _from_message(message: discord.Message) -> _types.message.Message: if isinstance(message.author, discord.Member): - member = dict_from_object(message.author) + member: _types.member.MemberWithUser | None = dict_from_object(message.author) user = message.author._user else: member = None @@ -366,20 +443,20 @@ def _from_message(message: discord.Message) -> _types.message.Message: 'tts': message.tts, 'mention_everyone': message.mention_everyone, 'pinned': message.pinned, - 'type': message.type.value, + 'type': message.type.value, # type: ignore[typeddict-item] } if member: - out['member'] = member + out['member'] = {**member} items = ('content', 'pinned', 'activity', - 'mention_everyone', 'tts', 'type', 'nonce') + 'mention_everyone', 'tts', 'type', 'nonce', 'poll') _fill_optional(out, message, items) return out @dict_from_object.register(discord.Attachment) def _from_attachment(attachment: discord.Attachment) -> _types.message.Attachment: - return { + out: _types.message.Attachment = { 'id': attachment.id, 'filename': attachment.filename, 'size': attachment.size, @@ -387,8 +464,10 @@ def _from_attachment(attachment: discord.Attachment) -> _types.message.Attachmen 'proxy_url': attachment.proxy_url, 'height': attachment.height, 'width': attachment.width, - 'content_type': attachment.content_type, } + if attachment.content_type: + out['content_type'] = attachment.content_type + return out @dict_from_object.register(discord.Emoji) @@ -405,7 +484,7 @@ def _from_emoji(emoji: discord.Emoji) -> _types.emoji.Emoji: @dict_from_object.register(discord.Sticker) def _from_sticker(sticker: discord.Sticker) -> _types.sticker.Sticker: if isinstance(sticker, discord.StandardSticker): - out: _types.sticker.StandardSticker = { + standard: _types.sticker.StandardSticker = { 'id': sticker.id, 'name': sticker.name, 'description': sticker.description, @@ -415,22 +494,23 @@ def _from_sticker(sticker: discord.Sticker) -> _types.sticker.Sticker: 'sort_value': sticker.sort_value, 'pack_id': sticker.pack_id, } + return standard elif isinstance(sticker, discord.GuildSticker): - out: _types.sticker.GuildSticker = { + guild: _types.sticker.GuildSticker = { 'id': sticker.id, 'name': sticker.name, 'description': sticker.description, - 'tags': sticker.tags, + 'tags': sticker.emoji, 'format_type': sticker.format.value, 'type': 2, 'available': sticker.available, 'guild_id': sticker.guild_id, } items = ("user",) - _fill_optional(out, sticker, items) + _fill_optional(guild, sticker, items) + return guild else: raise TypeError(f"Invalid type for sticker {type(sticker)}") - return out @dict_from_object.register(discord.StageInstance) @@ -449,7 +529,7 @@ def _from_stage_instance(stage_instance: discord.StageInstance) -> _types.channe @dict_from_object.register(discord.ScheduledEvent) def _from_scheduled_event(event: discord.ScheduledEvent) -> _types.guild.GuildScheduledEvent: if event.entity_type == discord.EntityType.stage_instance: - out: _types.scheduled_event.StageInstanceScheduledEvent = { + stage: _types.scheduled_event.StageInstanceScheduledEvent = { 'id': event.id, 'guild_id': event.guild_id, 'entity_id': event.entity_id, @@ -462,9 +542,10 @@ def _from_scheduled_event(event: discord.ScheduledEvent) -> _types.guild.GuildSc 'entity_metadata': None, } if event.end_time: - out["scheduled_end_time"] = str(int(event.end_time.timestamp())) + stage["scheduled_end_time"] = str(int(event.end_time.timestamp())) + return stage elif event.entity_type == discord.EntityType.voice: - out: _types.scheduled_event.VoiceScheduledEvent = { + voice: _types.scheduled_event.VoiceScheduledEvent = { 'id': event.id, 'guild_id': event.guild_id, 'entity_id': event.entity_id, @@ -477,9 +558,10 @@ def _from_scheduled_event(event: discord.ScheduledEvent) -> _types.guild.GuildSc 'entity_metadata': None, } if event.end_time: - out["scheduled_end_time"] = str(int(event.end_time.timestamp())) + voice["scheduled_end_time"] = str(int(event.end_time.timestamp())) + return voice else: - out: _types.scheduled_event.ExternalScheduledEvent = { + external: _types.scheduled_event.ExternalScheduledEvent = { 'id': event.id, 'guild_id': event.guild_id, 'entity_id': event.entity_id, @@ -489,10 +571,11 @@ def _from_scheduled_event(event: discord.ScheduledEvent) -> _types.guild.GuildSc 'status': event.status.value, 'entity_type': 3, 'channel_id': None, - 'scheduled_end_time': str(int(event.end_time.timestamp())), + # end_time guaranteed non-None for external events + 'scheduled_end_time': str(int(event.end_time.timestamp())), # type: ignore[union-attr] 'entity_metadata': {"location": event.location or ""} } - return out + return external @dict_from_object.register(discord.Guild) @@ -500,10 +583,10 @@ def _from_guild(guild: discord.Guild) -> _types.guild.Guild: return { 'id': guild.id, 'name': guild.name, - 'icon': guild.icon.url, - 'splash': guild.splash.url, - 'owner_id': guild.owner_id, - 'region': guild.region, + 'icon': guild.icon.url if guild.icon else None, + 'splash': guild.splash.url if guild.splash else None, + 'owner_id': guild.owner_id or 0, + 'region': "us-west", # deprecated? 'afk_channel_id': guild.afk_channel.id if guild.afk_channel else None, 'afk_timeout': guild.afk_timeout, 'verification_level': guild.verification_level.value, @@ -525,7 +608,7 @@ def _from_guild(guild: discord.Guild) -> _types.guild.Guild: 'system_channel_flags': guild.system_channel_flags.value, 'rules_channel_id': guild.rules_channel.id if guild.rules_channel else None, 'vanity_url_code': guild.vanity_url_code, - 'premium_tier': guild.premium_tier, + 'premium_tier': guild.premium_tier, # type: ignore[typeddict-item] 'preferred_locale': guild.preferred_locale.value, 'public_updates_channel_id': guild.public_updates_channel.id if guild.public_updates_channel else None, 'stage_instances': list(map(dict_from_object, guild.stage_instances)), @@ -533,6 +616,67 @@ def _from_guild(guild: discord.Guild) -> _types.guild.Guild: } +@dict_from_object.register(discord.PermissionOverwrite) +def _from_overwrite( + overwrite: discord.PermissionOverwrite, + *, + target: discord.Member | discord.Role | discord.Object, +) -> _types.channel.PermissionOverwrite: + allow, deny = overwrite.pair() + ovr: _types.channel.PermissionOverwrite = { + 'id': target.id, + 'allow': str(allow.value), + 'deny': str(deny.value), + 'type': 0 if isinstance(target, discord.Role) else 1 + } + return ovr + + +@dict_from_object.register(discord.Poll) +def _from_poll(poll: discord.Poll) -> _types.poll.Poll: + out: _types.poll.Poll = { + 'allow_multiselect': poll.multiple, + 'answers': [dict_from_object(answer, count=False) for answer in poll.answers], + 'expiry': (poll.expires_at or (dt.datetime.now(tz=dt.timezone.utc) + poll.duration)).isoformat(), + 'layout_type': poll.layout_type, # type: ignore[typeddict-item] + 'question': dict_from_object(poll._question_media), + 'results': { + 'is_finalized': poll.is_finalized(), + 'answer_counts': [dict_from_object(answer, count=True) for answer in poll.answers], + }, + } + return out + + +@dict_from_object.register(discord.PollAnswer) +def _from_poll_answer( + answer: discord.PollAnswer, + *, + count: bool = False, +) -> _types.poll.PollAnswerWithID | _types.poll.PollAnswerCount: + if count: + return { + 'id': answer.id, + 'count': answer.vote_count, + 'me_voted': answer.self_voted, + } + else: + return { + 'answer_id': answer.id, + 'poll_media': dict_from_object(answer.media), + } + + +@dict_from_object.register(discord.PollMedia) +def _from_poll_media(media: discord.PollMedia) -> _types.poll.PollMedia: + out: _types.poll.PollMedia = { + 'text': media.text, + } + items = ("emoji",) + _fill_optional(out, media, items) + return out + + # discord.py 1.7 bump requires the 'permissions_new', but if we keep 'permissions' then we seem to work on pre 1.7 def make_role_dict( name: str, @@ -554,7 +698,7 @@ def make_role_dict( raise ValueError("Both 'colour' and 'color' can be supplied at the same time") colour = color if colors is None: - colors: _types.role.RoleColours = { + colors = { 'primary_color': colour, 'secondary_color': None, 'tertiary_color': None, @@ -573,40 +717,43 @@ def make_role_dict( } -@typing.overload +@overload def make_channel_dict( - ctype: typing.Literal[0], + ctype: Literal[0], id_num: int = ..., - **kwargs: typing.Any, + **kwargs: Any, ) -> _types.channel.TextChannel: ... -@typing.overload +@overload def make_channel_dict( - ctype: typing.Literal[1], + ctype: Literal[1], id_num: int = ..., - **kwargs: typing.Any, + **kwargs: Any, ) -> _types.channel.DMChannel: ... -@typing.overload +@overload def make_channel_dict( - ctype: typing.Literal[2], + ctype: Literal[2], id_num: int = ..., - **kwargs: typing.Any, + **kwargs: Any, ) -> _types.channel.VoiceChannel: ... -@typing.overload +@overload def make_channel_dict( - ctype: typing.Literal[4], + ctype: Literal[4], id_num: int = ..., - **kwargs: typing.Any, + **kwargs: Any, ) -> _types.channel.CategoryChannel: ... -def make_channel_dict(ctype: typing.Literal[0, 1, 2, 3], id_num: int = -1, - **kwargs: typing.Any) -> _types.channel.PartialChannel: +def make_channel_dict( + ctype: Literal[0, 1, 2, 3, 4], + id_num: int = -1, + **kwargs: Any, +) -> _types.channel.Channel: if id_num < 0: id_num = make_id() out: _types.channel.PartialChannel = { @@ -618,55 +765,43 @@ def make_channel_dict(ctype: typing.Literal[0, 1, 2, 3], id_num: int = -1, "user_limit", "rate_limit_per_user", "recipients", "icon", "owner_id", "application_id", "parent_id", "last_pin_timestamp") _fill_optional(out, kwargs, items) - return out + return out # type: ignore[return-value] -def make_text_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.channel.TextChannel: +def make_text_channel_dict(name: str, id_num: int = -1, **kwargs: Any) -> _types.channel.TextChannel: return make_channel_dict(discord.ChannelType.text.value, id_num, name=name, **kwargs) -def make_category_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.channel.CategoryChannel: +def make_category_channel_dict(name: str, id_num: int = -1, **kwargs: Any) -> _types.channel.CategoryChannel: return make_channel_dict(discord.ChannelType.category.value, id_num, name=name, **kwargs) -def make_dm_channel_dict(user: discord.User, id_num: int = -1, **kwargs: typing.Any) -> _types.channel.DMChannel: +def make_dm_channel_dict(user: discord.User, id_num: int = -1, **kwargs: Any) -> _types.channel.DMChannel: return make_channel_dict(discord.ChannelType.private.value, id_num, recipients=[dict_from_object(user)], **kwargs) -def make_voice_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.channel.VoiceChannel: +def make_voice_channel_dict(name: str, id_num: int = -1, **kwargs: Any) -> _types.channel.VoiceChannel: return make_channel_dict(discord.ChannelType.voice.value, id_num, name=name, **kwargs) -def dict_from_overwrite(target: discord.Member | discord.Role, - overwrite: discord.PermissionOverwrite) -> _types.channel.PermissionOverwrite: - allow, deny = overwrite.pair() - ovr: _types.channel.PermissionOverwrite = { - 'id': target.id, - 'allow': str(allow.value), - 'deny': str(deny.value), - 'type': 0 if isinstance(target, discord.Role) else 1 - } - return ovr - - # TODO: Convert reactions, activity, and application to a dict. def make_message_dict( channel: _types.AnyChannel, - author: discord.user.BaseUser, + author: discord.user.BaseUser | discord.Member, id_num: int = -1, - content: str = None, - timestamp: str = None, + content: str = "", + timestamp: str | None = None, edited_timestamp: str | None = None, tts: bool = False, mention_everyone: bool = False, - mentions: list[discord.User | discord.Member] = None, - mention_roles: list[int] = None, - mention_channels: list[_types.AnyChannel] = None, - attachments: list[discord.Attachment] = None, + mentions: list[discord.User | discord.Member] | None = None, + mention_roles: list[_types.gateway.Snowflake] | None = None, + mention_channels: list[_types.AnyChannel] | None = None, + attachments: list[discord.Attachment] | None = None, embeds: list[discord.Embed] | None = None, pinned: bool = False, type: int = 0, - **kwargs, + **kwargs: Any, ) -> _types.message.Message: if mentions is None: mentions = [] @@ -677,8 +812,6 @@ def make_message_dict( if attachments is None: attachments = [] - if not content: - content = "" if id_num < 0: id_num = make_id() if isinstance(channel, discord.abc.GuildChannel): @@ -688,10 +821,10 @@ def make_message_dict( kwargs["member"] = dict_from_object(author) if timestamp is None: timestamp = str(int(discord.utils.snowflake_time(id_num).timestamp())) - mentions = list(map(user_with_member, mentions)) if mentions else [] - mention_channels = list(map(_mention_from_channel, mention_channels)) if mention_channels else [] - attachments = list(map(dict_from_object, attachments)) if attachments else [] - embeds = list(map(discord.Embed.to_dict, embeds)) if embeds else [] + mentions_json = list(map(user_with_member, mentions)) if mentions else [] + mention_channels_json = list(map(_mention_from_channel, mention_channels)) if mention_channels else [] + attachments_json = list(map(dict_from_object, attachments)) if attachments else [] + embeds_json = list(map(discord.Embed.to_dict, embeds)) if embeds else [] out: _types.message.Message = { 'id': id_num, @@ -701,16 +834,16 @@ def make_message_dict( 'timestamp': timestamp, 'edited_timestamp': edited_timestamp, 'tts': tts, - 'mention_channels': mention_channels, + 'mention_channels': mention_channels_json, 'mention_everyone': mention_everyone, - 'mentions': mentions, + 'mentions': mentions_json, 'mention_roles': mention_roles, - 'attachments': attachments, - 'embeds': embeds, + 'attachments': attachments_json, + 'embeds': embeds_json, 'pinned': pinned, - 'type': type, + 'type': type, # type: ignore[typeddict-item] } - items = ('guild_id', 'member', 'reactions', 'nonce', 'webhook_id', 'activity', 'application') + items = ('guild_id', 'member', 'reactions', 'nonce', 'webhook_id', 'activity', 'application', 'poll') _fill_optional(out, kwargs, items) return out @@ -718,14 +851,14 @@ def make_message_dict( def _mention_from_channel(channel: _types.AnyChannel) -> _types.message.ChannelMention: out: _types.message.ChannelMention = { "id": channel.id, - "type": str(channel.type), + "type": channel.type.value, "guild_id": 0, "name": "" } if hasattr(channel, "guild"): - out["guild_id"] = channel.guild.id + out["guild_id"] = channel.guild.id if channel.guild else 0 if hasattr(channel, "name"): - out["name"] = channel.name + out["name"] = channel.name or "" return out @@ -769,14 +902,14 @@ def make_guild_dict( region: str = "en_north", afk_channel_id: int | None = None, afk_timeout: int = 600, - verification_level: int = 0, - default_message_notifications: int = 0, - explicit_content_filter: int = 0, - features: list[str] | None = None, - mfa_level: int = 0, + verification_level: Literal[0, 1, 2, 3, 4] = 0, + default_message_notifications: Literal[0, 1] = 0, + explicit_content_filter: Literal[0, 1, 2] = 0, + features: list[_types.guild.GuildFeature] | None = None, + mfa_level: Literal[0, 1] = 0, application_id: int | None = None, system_channel_id: int | None = None, - **kwargs: typing.Any, + **kwargs: Any, ) -> _types.guild.Guild: if id_num < 0: id_num = make_id() diff --git a/discord/ext/test/runner.py b/discord/ext/test/runner.py index a757c3d..a21bd56 100644 --- a/discord/ext/test/runner.py +++ b/discord/ext/test/runner.py @@ -12,17 +12,24 @@ import sys import asyncio import logging +from typing import NamedTuple, Callable, Any + import discord -import typing import pathlib from itertools import count +from discord.ext import commands +from discord.ext.commands import CommandError +from discord.ext.commands._types import BotT +from typing_extensions import ParamSpec, TypeVar + from . import backend as back, callbacks, _types +from .callbacks import CallbackEvent from .utils import PeekableQueue -class RunnerConfig(typing.NamedTuple): +class RunnerConfig(NamedTuple): """ Exposed discord test configuration Contains the current client, and lists of faked objects @@ -36,11 +43,17 @@ class RunnerConfig(typing.NamedTuple): log = logging.getLogger("discord.ext.tests") _cur_config: RunnerConfig | None = None -sent_queue: PeekableQueue = PeekableQueue() -error_queue: PeekableQueue = PeekableQueue() +sent_queue: PeekableQueue[discord.Message] = PeekableQueue() +error_queue: PeekableQueue[tuple[ + commands.Context[commands.Bot | commands.AutoShardedBot], CommandError +]] = PeekableQueue() + + +T = TypeVar('T') +P = ParamSpec('P') -def require_config(func: typing.Callable[..., _types.T]) -> typing.Callable[..., _types.T]: +def require_config(func: Callable[P, T]) -> Callable[P, T]: """ Decorator to enforce that configuration is completed before the decorated function is called. @@ -49,7 +62,9 @@ def require_config(func: typing.Callable[..., _types.T]) -> typing.Callable[..., :return: Function with added check for configuration being setup """ - def wrapper(*args, **kwargs): + wrapper: _types.Wrapper[P, T] + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # type: ignore[no-redef] if _cur_config is None: log.error("Attempted to make call before runner configured") raise RuntimeError(f"Configure runner before calling {func.__name__}") @@ -61,11 +76,11 @@ def wrapper(*args, **kwargs): return wrapper -def _task_coro_name(task: asyncio.Task) -> str | None: +def _task_coro_name(task: asyncio.Task[Any]) -> str | None: """ - Uses getattr() to avoid AttributeErrors when the _coro doesn't have a __name__ + Uses getattr() to avoid AttributeErrors when the coroutine doesn't have a __name__ """ - return getattr(task._coro, "__name__", None) + return getattr(task.get_coro(), "__name__", None) async def run_all_events() -> None: @@ -149,7 +164,7 @@ async def _message_callback(message: discord.Message) -> None: await sent_queue.put(message) -async def _edit_member_callback(fields: typing.Any, member: discord.Member, reason: str | None): +async def _edit_member_callback(fields: Any, member: discord.Member, reason: str | None) -> None: """ Internal callback. Updates a guild's voice states to reflect the given Member connecting to the given channel. Other updates to members are handled in http.edit_member(). @@ -162,7 +177,8 @@ async def _edit_member_callback(fields: typing.Any, member: discord.Member, reas guild = member.guild channel = fields.get('channel_id') if not fields.get('nick') and not fields.get('roles'): - guild._update_voice_state(data, channel) + # Data is allowed to not contain fields + guild._update_voice_state(data, channel) # type: ignore[arg-type] counter = count(0) @@ -173,7 +189,7 @@ async def message( content: str, channel: _types.AnyChannel | int = 0, member: discord.Member | int = 0, - attachments: list[pathlib.Path | str] = None + attachments: list[pathlib.Path | str] | None = None ) -> discord.Message: """ Fake a message being sent by some user to a channel. @@ -185,19 +201,19 @@ async def message( :return: New message that was sent """ if isinstance(channel, int): - channel = _cur_config.channels[channel] + channel = get_config().channels[channel] if isinstance(member, int): - member = _cur_config.members[member] + member = get_config().members[member] import os if attachments is None: attachments = [] - attachments = [ + attachments_model = [ discord.Attachment( data={ 'id': counter.__next__(), 'filename': os.path.basename(attachment), 'size': 0, - 'url': attachment, + 'url': str(attachment), 'proxy_url': "", 'height': 0, 'width': 0 @@ -206,7 +222,7 @@ async def message( ) for attachment in attachments ] - mes = back.make_message(content, member, channel, attachments=attachments) + mes = back.make_message(content, member, channel, attachments=attachments_model) await run_all_events() @@ -219,10 +235,10 @@ async def message( @require_config async def set_permission_overrides( - target: discord.User | discord.Role, - channel: discord.abc.GuildChannel, + target: discord.Member | discord.Role | int, + channel: discord.abc.GuildChannel | int, overrides: discord.PermissionOverwrite | None = None, - **kwargs: typing.Any, + **kwargs: Any, ) -> None: """ Set the permission override for a channel, as if set by another user. @@ -239,21 +255,21 @@ async def set_permission_overrides( overrides = discord.PermissionOverwrite(**kwargs) if isinstance(target, int): - target = _cur_config.members[target] + target = get_config().members[target] if isinstance(channel, int): - channel = _cur_config.channels[channel] + channel = get_config().channels[channel] - if not isinstance(channel, discord.abc.GuildChannel): - raise TypeError(f"channel '{channel}' must be a abc.GuildChannel, not '{type(channel)}''") + if not isinstance(channel, discord.TextChannel): + raise TypeError(f"channel '{channel}' must be a discord.TextChannel, not '{type(channel)}''") if not isinstance(target, (discord.abc.User, discord.Role)): - raise TypeError(f"target '{target}' must be a abc.User or Role, not '{type(target)}''") + raise TypeError(f"target '{target}' must be an abc.User or Role, not '{type(target)}''") # TODO: This will probably break for video channels/non-standard text channels back.update_text_channel(channel, target, overrides) @require_config -async def add_role(member: discord.Member, role: discord.Role) -> None: +async def add_role(member: discord.Member | int, role: discord.Role) -> None: """ Add a role to a member, as if added by another user. @@ -261,7 +277,7 @@ async def add_role(member: discord.Member, role: discord.Role) -> None: :param role: Role to be added """ if isinstance(member, int): - member = _cur_config.members[member] + member = get_config().members[member] if not isinstance(role, discord.Role): raise TypeError("Role argument must be of type discord.Role") @@ -270,7 +286,7 @@ async def add_role(member: discord.Member, role: discord.Role) -> None: @require_config -async def remove_role(member: discord.Member, role: discord.Role) -> None: +async def remove_role(member: discord.Member | int, role: discord.Role) -> None: """ Remove a role from a member, as if removed by another user. @@ -278,7 +294,7 @@ async def remove_role(member: discord.Member, role: discord.Role) -> None: :param role: Role to remove """ if isinstance(member, int): - member = _cur_config.members[member] + member = get_config().members[member] if not isinstance(role, discord.Role): raise TypeError("Role argument must be of type discord.Role") @@ -301,7 +317,7 @@ async def add_reaction(user: discord.user.BaseUser | discord.abc.User, @require_config -async def remove_reaction(user: discord.user.BaseUser | discord.abc.User, +async def remove_reaction(user: discord.user.BaseUser | discord.Member, message: discord.Message, emoji: str) -> None: """ Remove a reaction from a message, as if done by another user @@ -319,8 +335,8 @@ async def member_join( guild: discord.Guild | int = 0, user: discord.User | None = None, *, - name: str = None, - discrim: str | int = None + name: str | None = None, + discrim: str | int | None = None ) -> discord.Member: """ Have a new member join a guild, either an existing or new user for the framework @@ -332,7 +348,7 @@ async def member_join( """ import random if isinstance(guild, int): - guild = _cur_config.guilds[guild] + guild = _cur_config.guilds[guild] # type: ignore[union-attr] if user is None: if name is None: @@ -352,6 +368,8 @@ def get_config() -> RunnerConfig: :return: Current runner config """ + if _cur_config is None: + raise RuntimeError("Runner not configured yet") return _cur_config @@ -384,7 +402,9 @@ def configure(client: discord.Client, if hasattr(client, "on_command_error"): old_error = client.on_command_error - async def on_command_error(ctx, error): + on_command_error: _types.FnWithOld[[commands.Context[commands.Bot | commands.AutoShardedBot], CommandError], None] + + async def on_command_error(ctx: commands.Context[BotT], error: CommandError) -> None: # type: ignore[no-redef] try: if old_error: await old_error(ctx, error) @@ -392,11 +412,12 @@ async def on_command_error(ctx, error): await error_queue.put((ctx, error)) on_command_error.__old__ = old_error - client.on_command_error = on_command_error + + client.on_command_error = on_command_error # type: ignore[attr-defined] # Configure global callbacks - callbacks.set_callback(_message_callback, "send_message") - callbacks.set_callback(_edit_member_callback, "edit_member") + callbacks.set_callback(_message_callback, CallbackEvent.send_message) + callbacks.set_callback(_edit_member_callback, CallbackEvent.edit_member) back.get_state().stop_dispatch() @@ -410,28 +431,28 @@ async def on_command_error(ctx, error): guild = back.make_guild(guild_name) _guilds.append(guild) - _channels = [] + _channels: list[discord.abc.GuildChannel] = [] _members = [] for guild in _guilds: # Text channels if isinstance(text_channels, int): for num in range(text_channels): - channel = back.make_text_channel(f"TextChannel_{num}", guild) - _channels.append(channel) + text = back.make_text_channel(f"TextChannel_{num}", guild) + _channels.append(text) if isinstance(text_channels, list): for chan in text_channels: - channel = back.make_text_channel(chan, guild) - _channels.append(channel) + text = back.make_text_channel(chan, guild) + _channels.append(text) # Voice channels if isinstance(voice_channels, int): for num in range(voice_channels): - channel = back.make_voice_channel(f"VoiceChannel_{num}", guild) - _channels.append(channel) + voice = back.make_voice_channel(f"VoiceChannel_{num}", guild) + _channels.append(voice) if isinstance(voice_channels, list): for chan in voice_channels: - channel = back.make_voice_channel(chan, guild) - _channels.append(channel) + voice = back.make_voice_channel(chan, guild) + _channels.append(voice) # Members if isinstance(members, int): @@ -445,7 +466,9 @@ async def on_command_error(ctx, error): member = back.make_member(user, guild, nick=f"{user.name}_{str(num)}_nick") _members.append(member) - back.make_member(back.get_state().user, guild, nick=f"{client.user.name}_nick") + client_user = back.get_state().user + if client_user is not None: + back.make_member(client_user, guild, nick=f"{client_user.name}_nick") back.get_state().start_dispatch() diff --git a/discord/ext/test/state.py b/discord/ext/test/state.py index 91397dc..6a86c45 100644 --- a/discord/ext/test/state.py +++ b/discord/ext/test/state.py @@ -4,15 +4,23 @@ """ import asyncio +from asyncio import Future +from typing import TypeVar, ParamSpec, Any, Literal, overload + import discord import discord.http as dhttp import discord.state as dstate +from . import _types from . import factories as facts from . import backend as back from .voice import FakeVoiceChannel +P = ParamSpec('P') +T = TypeVar('T') + + class FakeState(dstate.ConnectionState): """ A mock implementation of a ``ConnectionState``. Overrides methods that would otherwise cause issues, and @@ -20,13 +28,14 @@ class FakeState(dstate.ConnectionState): """ http: 'back.FakeHttp' # String because of circular import + user: discord.ClientUser - def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord.ClientUser = None, - loop: asyncio.AbstractEventLoop = None) -> None: + def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord.ClientUser | None = None, + loop: asyncio.AbstractEventLoop | None = None) -> None: if loop is None: loop = asyncio.get_event_loop() super().__init__(dispatch=client.dispatch, - handlers=None, hooks=None, + handlers={}, hooks={}, syncer=None, http=http, loop=loop, intents=client.intents, member_cache_flags=client._connection.member_cache_flags) @@ -41,9 +50,9 @@ def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord real_disp = self.dispatch - def dispatch(*args, **kwargs): + def dispatch(*args: Any, **kwargs: Any) -> T | None: if not self._do_dispatch: - return + return None return real_disp(*args, **kwargs) self.dispatch = dispatch @@ -61,21 +70,40 @@ def start_dispatch(self) -> None: self._do_dispatch = True # TODO: Respect limit parameters - async def query_members(self, guild: discord.Guild, query: str, limit: int, user_ids: int, - cache: bool, presences: bool) -> None: - guild: discord.Guild = discord.utils.get(self.guilds, id=guild.id) - return guild.members - - async def chunk_guild(self, guild: discord.Guild, *, wait: bool = True, cache: bool | None = None): - pass - - def _guild_needs_chunking(self, guild: discord.Guild): + async def query_members(self, guild: discord.Guild, query: str | None, limit: int, user_ids: list[int] | None, + cache: bool, presences: bool) -> list[discord.Member]: + guild = discord.utils.get(self.guilds, id=guild.id) # type: ignore[assignment] + return list(guild.members) + + @overload + async def chunk_guild( + self, + guild: discord.Guild, + *, + wait: Literal[True] = ..., + cache: bool | None = ..., + ) -> list[discord.Member]: ... + + @overload + async def chunk_guild( + self, guild: discord.Guild, *, wait: Literal[False] = ..., cache: bool | None = ... + ) -> asyncio.Future[list[discord.Member]]: ... + + async def chunk_guild( + self, + guild: discord.Guild, + *, wait: bool = True, + cache: bool | None = None, + ) -> list[discord.Member] | Future[list[discord.Member]]: + return [] + + def _guild_needs_chunking(self, guild: discord.Guild) -> bool: """ Prevents chunking which can throw asyncio wait_for errors with tests under 60 seconds """ return False - def parse_channel_create(self, data) -> None: + def parse_channel_create(self, data: _types.gateway._ChannelEvent | _types.channel.Channel) -> None: """ Need to make sure that FakeVoiceChannels are created when this is called to create VoiceChannels. Otherwise, guilds would not be set up correctly. @@ -94,8 +122,8 @@ def parse_channel_create(self, data) -> None: guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here - channel = factory(guild=guild, state=self, data=data) # type: ignore - guild._add_channel(channel) # type: ignore + channel = factory(guild=guild, state=self, data=data) # type: ignore[arg-type] + guild._add_channel(channel) self.dispatch('guild_channel_create', channel) else: return diff --git a/discord/ext/test/utils.py b/discord/ext/test/utils.py index f579c6b..35e32b2 100644 --- a/discord/ext/test/utils.py +++ b/discord/ext/test/utils.py @@ -4,16 +4,14 @@ """ import asyncio +from typing import TypeVar + import discord def embed_eq(embed1: discord.Embed | None, embed2: discord.Embed | None) -> bool: - if embed1 == embed2: - return True - elif embed1 is None and embed2 is not None: - return False - elif embed2 is None and embed1 is not None: - return False + if embed1 is None or embed2 is None: + return embed1 == embed2 return all([embed1.title == embed2.title, embed1.description == embed2.description, @@ -23,37 +21,69 @@ def embed_eq(embed1: discord.Embed | None, embed2: discord.Embed | None) -> bool embed1.fields == embed2.fields]) -def activity_eq(act1: discord.Activity | None, act2: discord.Activity | None) -> bool: - if act1 == act2: - return True - elif act1 is None and act2 is not None: - return False - elif act2 is None and act1 is not None: - return False +def activity_eq(act1: discord.activity.ActivityTypes | None, act2: discord.activity.ActivityTypes | None) -> bool: + if act1 is None or act2 is None: + return act1 == act2 - return all([ - act1.name == act2.name, - act1.url == act2.url, - act1.type == act2.type, - act1.details == act2.details, - act1.emoji == act2.emoji, - ]) + match (act1, act2): + case (discord.Activity(), discord.Activity()): + return all([ + act1.name == act2.name, + act1.url == act2.url, + act1.type == act2.type, + act1.details == act2.details, + act1.emoji == act2.emoji, + ]) + case (discord.Game(), discord.Game()): + return all([ + act1.name == act2.name, + act1.platform == act2.platform, + act1.assets == act2.assets, + ]) + case (discord.CustomActivity(), discord.CustomActivity()): + return all([ + act1.name == act2.name, + act1.emoji == act2.emoji, + ]) + case (discord.Streaming(), discord.Streaming()): + return all([ + act1.platform == act2.platform, + act1.name == act2.name, + act1.details == act2.details, + act1.game == act2.game, + act1.url == act2.url, + act1.assets == act2.assets, + ]) + case (discord.Spotify(), discord.Spotify()): + return all([ + act1.title == act2.title, + act1.artist == act2.artist, + act1.album == act2.album, + act1.album_cover_url == act2.album_cover_url, + act1.track_id == act2.track_id, + act1.start == act2.start, + act1.end == act2.end, + ]) + return False -def embed_proxy_eq(embed_proxy1, embed_proxy2): +def embed_proxy_eq(embed_proxy1: discord.embeds.EmbedProxy, embed_proxy2: discord.embeds.EmbedProxy) -> bool: return embed_proxy1.__repr__ == embed_proxy2.__repr__ -class PeekableQueue(asyncio.Queue): +T = TypeVar('T') + + +class PeekableQueue(asyncio.Queue[T]): """ An extension of an asyncio queue with a peek message, so other code doesn't need to rely on unstable internal artifacts """ - def peek(self): + def peek(self) -> T: """ Peek the current last value in the queue, or raise an exception if there are no values :return: Last value in the queue, assuming there are any """ - return self._queue[-1] + return self._queue[-1] # type: ignore[attr-defined] diff --git a/discord/ext/test/verify.py b/discord/ext/test/verify.py index b92b873..f792c77 100644 --- a/discord/ext/test/verify.py +++ b/discord/ext/test/verify.py @@ -11,10 +11,13 @@ import asyncio import pathlib +from typing import TypeVar, Callable + import discord from .runner import sent_queue, get_config from .utils import embed_eq, activity_eq +from ._types import Undef, undefined def _msg_to_str(msg: discord.Message) -> str: @@ -32,19 +35,16 @@ def _msg_to_str(msg: discord.Message) -> str: return f"Message({inner})" -class _Undef: - _singleton = None - - def __new__(cls): - if cls._singleton is None: - cls._singleton = super().__new__(cls) - return cls._singleton +T = TypeVar('T') - def __eq__(self, other): - return self is other - -_undefined = _Undef() +def opt_undef_or(start: str, v: T | Undef | None, f: Callable[[T], str]) -> str: + if v is undefined: + return "" + elif v is None: + return f"{start}=Empty" + else: + return f"{start}={f(v)}" class VerifyMessage: @@ -55,24 +55,24 @@ class VerifyMessage: ``assert dpytest.verify().message().content("Hello World!")`` """ - _used: discord.Message | int | _Undef | None + _used: discord.Message | int | Undef | None _contains: bool _peek: bool _nothing: bool - _content: str | _Undef | None - _embed: discord.Embed | _Undef | None - _attachment: str | pathlib.Path | _Undef | None + _content: str | Undef | None + _embed: discord.Embed | Undef | None + _attachment: str | pathlib.Path | Undef | None def __init__(self) -> None: - self._used = _undefined + self._used = undefined self._contains = False self._peek = False self._nothing = False - self._content = _undefined - self._embed = _undefined - self._attachment = _undefined + self._content = undefined + self._embed = undefined + self._attachment = undefined def __del__(self) -> None: if not self._used: @@ -80,10 +80,10 @@ def __del__(self) -> None: warnings.warn("VerifyMessage dropped without being used, did you forget an `assert`?", RuntimeWarning) def __repr__(self) -> str: - if self._used is not _undefined: - return f"" + if self._used is not undefined: + return f"" else: - return f"" + return f"" def __bool__(self) -> bool: self._used = None @@ -109,19 +109,20 @@ def _expectation(self) -> str: return "no messages" else: contains = "contains" - content = f"content=\"{self._content}\"" if self._content is not _undefined else "" - embed = f"embed={str(self._embed.to_dict())}" if self._embed is not _undefined else "" - attachment = f"attachment={self._attachment}" if self._attachment is not _undefined else "" + content = opt_undef_or("content", self._content, lambda x: f'"{x}"') + embed = opt_undef_or("embed", self._embed, lambda x: str(x.to_dict())) + attachment = opt_undef_or("attachment", self._attachment, lambda x: str(x)) event = " ".join(filter(lambda x: x, [contains, content, embed, attachment])) return f"{event}" def _diff_msg(self) -> str: - if self._nothing: + if isinstance(self._used, int): return f"{self._used} messages" + elif isinstance(self._used, discord.Message): + return f"{_msg_to_str(self._used)}" elif self._used is None: return "no message" - else: - return str(self._used) + return "" def _check_msg(self, msg: discord.Message) -> bool: # If any attributes are 'None', check that they don't exist @@ -132,20 +133,21 @@ def _check_msg(self, msg: discord.Message) -> bool: if self._attachment is None and msg.attachments: return False - # For any attributes that aren't None or _undefined, check that they match - if self._content is not None and self._content is not _undefined: + # For any attributes that aren't None or undefined, check that they match + if self._content is not None and self._content is not undefined: if self._contains and self._content not in msg.content: return False if not self._contains and self._content != msg.content: return False - if self._embed is not None and self._embed is not _undefined: - if self._contains and not any(map(lambda e: embed_eq(self._embed, e), msg.embeds)): + _embed = self._embed + if _embed is not None and _embed is not undefined: + if self._contains and not any(map(lambda e: embed_eq(_embed, e), msg.embeds)): return False - if not self._contains and (len(msg.embeds) != 1 or not embed_eq(self._embed, msg.embeds[0])): + if not self._contains and (len(msg.embeds) != 1 or not embed_eq(_embed, msg.embeds[0])): return False # TODO: Support contains for attachments, 'contains' should mean 'any number of which one matches', # while 'exact' should be 'only one which must match' - if self._attachment is not None and self._attachment is not _undefined: + if self._attachment is not None and self._attachment is not undefined: import urllib.request as request with open(self._attachment, "rb") as file: expected = file.read() @@ -182,7 +184,7 @@ def nothing(self) -> 'VerifyMessage': :return: Self for chaining """ - if self._content is not _undefined or self._embed is not _undefined or self._attachment is not _undefined: + if self._content is not undefined or self._embed is not undefined or self._attachment is not undefined: raise ValueError("Verify nothing conflicts with verifying some content, embed, or attachment") self._nothing = True return self @@ -232,13 +234,18 @@ class VerifyActivity: ``assert not dpytest.verify().activity().name("Foobar")`` """ + _activity: discord.activity.ActivityTypes | None | Undef + _name: str | None | Undef + _url: str | None | Undef + _type: discord.ActivityType | None | Undef + def __init__(self) -> None: self._used = False - self._activity = _undefined - self._name = _undefined - self._url = _undefined - self._type = _undefined + self._activity = undefined + self._name = undefined + self._url = undefined + self._type = undefined def __del__(self) -> None: if not self._used: @@ -250,62 +257,74 @@ def __bool__(self) -> bool: bot_act = get_config().guilds[0].me.activity - if self._activity is not _undefined: + if self._activity is not undefined: return activity_eq(self._activity, bot_act) - if self._name is not _undefined and self._name != bot_act.name: - return False - if self._url is not _undefined and self._url != bot_act.url: - return False - if self._type is not _undefined and self._type != bot_act.type: - return False + if bot_act is None: + return (self._name not in [undefined, None] + and self._url not in [undefined, None] + and self._type not in [undefined, None]) + + if isinstance(bot_act, discord.Game): + pass + elif isinstance(bot_act, discord.CustomActivity): + pass + elif isinstance(bot_act, discord.Spotify): + pass + else: + if self._name is not undefined and self._name != bot_act.name: + return False + if self._url is not undefined and self._url != bot_act.url: + return False + if self._type is not undefined and self._type != bot_act.type: + return False return True - def matches(self, activity) -> 'VerifyActivity': + def matches(self, activity: discord.activity.ActivityTypes | None) -> 'VerifyActivity': """ Ensure that the bot activity exactly matches the passed activity. Most restrictive possible check. :param activity: Activity to compare against :return: Self for chaining """ - if self._name is not _undefined or self._url is not _undefined or self._type is not _undefined: + if self._name is not undefined or self._url is not undefined or self._type is not undefined: raise ValueError("Verify exact match conflicts with verifying attributes") self._activity = activity return self - def name(self, name: str) -> 'VerifyActivity': + def name(self, name: str | None) -> 'VerifyActivity': """ Check that the activity name matches the input :param name: Name to match against :return: Self for chaining """ - if self._activity is not _undefined: + if self._activity is not undefined: raise ValueError("Verify name conflicts with verifying exact match") self._name = name return self - def url(self, url: str) -> 'VerifyActivity': + def url(self, url: str | None) -> 'VerifyActivity': """ Check the the activity url matches the input :param url: Url to match against :return: Self for chaining """ - if self._activity is not _undefined: + if self._activity is not undefined: raise ValueError("Verify url conflicts with verifying exact match") self._url = url return self - def type(self, type: discord.ActivityType) -> 'VerifyActivity': + def type(self, type: discord.ActivityType | None) -> 'VerifyActivity': """ Check the activity type matches the input :param type: Type to match against :return: Self for chaining """ - if self._activity is not _undefined: + if self._activity is not undefined: raise ValueError("Verify type conflicts with verifying exact match") self._type = type return self @@ -317,7 +336,7 @@ class Verify: intermediate step for the return of verify(). """ - def __init__(self): + def __init__(self) -> None: pass def message(self) -> VerifyMessage: diff --git a/discord/ext/test/voice.py b/discord/ext/test/voice.py index 70701b9..e78b2d7 100644 --- a/discord/ext/test/voice.py +++ b/discord/ext/test/voice.py @@ -28,7 +28,7 @@ async def connect( *, timeout: float = 60.0, reconnect: bool = True, - cls: Callable[[Client, Connectable], T] = FakeVoiceClient, + cls: Callable[[Client, Connectable], T] = FakeVoiceClient, # type: ignore[assignment] self_deaf: bool = False, self_mute: bool = False, ) -> T: diff --git a/discord/ext/test/websocket.py b/discord/ext/test/websocket.py index 661a75d..da30c1b 100644 --- a/discord/ext/test/websocket.py +++ b/discord/ext/test/websocket.py @@ -3,11 +3,13 @@ hooking of its methods to update the backend and provide callbacks. """ -import typing +from typing import Any + import discord import discord.gateway as gateway from . import callbacks +from .callbacks import CallbackEvent class FakeWebSocket(gateway.DiscordWebSocket): @@ -16,9 +18,13 @@ class FakeWebSocket(gateway.DiscordWebSocket): it simply triggers calls to the ``dpytest`` backend, as well as triggering runner callbacks. """ - def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: + cur_event: CallbackEvent | None + event_args: tuple[Any, ...] + event_kwargs: dict[str, Any] + + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.cur_event = "" + self.cur_event = None self.event_args = () self.event_kwargs = {} @@ -38,6 +44,6 @@ async def change_presence( status: str | None = None, since: float = 0.0 ) -> None: - self.cur_event = "presence" + self.cur_event = CallbackEvent.presence self.event_args = (activity, status, since) await super().change_presence(activity=activity, status=status, since=since) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..df64f4f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,15 @@ +[mypy] +packages = discord.ext.test, tests + +warn_redundant_casts = True +warn_unused_ignores = True + +strict_equality = True + +disallow_any_generics = True +disallow_any_unimported = True +disallow_subclassing_any = True + +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_untyped_decorators = True diff --git a/tests/conftest.py b/tests/conftest.py index 38ea755..93ac55d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,16 @@ from pathlib import Path +from typing import AsyncGenerator + import pytest_asyncio import discord import discord.ext.commands as commands import discord.ext.test as dpytest +from pytest import FixtureRequest from discord.client import _LoopSentinel @pytest_asyncio.fixture -async def bot(request) -> commands.Bot: +async def bot(request: FixtureRequest) -> commands.Bot: intents = discord.Intents.default() intents.members = True intents.message_content = True @@ -32,12 +35,12 @@ async def bot(request) -> commands.Bot: @pytest_asyncio.fixture(autouse=True) -async def cleanup(): +async def cleanup() -> AsyncGenerator[None, None]: yield await dpytest.empty_queue() -def pytest_sessionfinish(session, exitstatus): +def pytest_sessionfinish() -> None: """ Code to execute after all tests. """ # dat files are created when using attachements diff --git a/tests/internal/cogs/echo.py b/tests/internal/cogs/echo.py index 3169773..0c1fcb9 100644 --- a/tests/internal/cogs/echo.py +++ b/tests/internal/cogs/echo.py @@ -1,16 +1,18 @@ +from discord.ext import commands +from discord.ext.commands._types import BotT from discord.ext.commands import Cog, command class Echo(Cog): # Silence the default on_error handler - async def cog_command_error(self, ctx, error): + async def cog_command_error(self, ctx: commands.Context[BotT], error: Exception) -> None: pass @command() - async def echo(self, ctx, *, text: str): + async def echo(self, ctx: commands.Context[BotT], *, text: str) -> None: await ctx.send(text) -async def setup(bot): +async def setup(bot: commands.Bot) -> None: await bot.add_cog(Echo()) diff --git a/tests/internal/cogs/greeting.py b/tests/internal/cogs/greeting.py index fed1c89..1325edd 100644 --- a/tests/internal/cogs/greeting.py +++ b/tests/internal/cogs/greeting.py @@ -1,14 +1,15 @@ +import discord from discord.ext import commands class Greeting(commands.Cog): @commands.Cog.listener() - async def on_member_join(self, member): + async def on_member_join(self, member: discord.Member) -> None: channel = member.guild.text_channels[0] if channel is not None: await channel.send(f"Welcome {member.mention}.") -async def setup(bot): +async def setup(bot: commands.Bot) -> None: await bot.add_cog(Greeting()) diff --git a/tests/internal/cogs/poll.py b/tests/internal/cogs/poll.py new file mode 100644 index 0000000..6dc4aeb --- /dev/null +++ b/tests/internal/cogs/poll.py @@ -0,0 +1,19 @@ +from datetime import timedelta + +import discord +from discord.ext import commands +from discord.ext.commands._types import BotT +from discord.ext.commands import Cog, command + + +class Misc(Cog): + @command() + async def pollme(self, ctx: commands.Context[BotT]) -> None: + poll = discord.Poll(question="Test?", duration=timedelta(hours=1)) + poll.add_answer(text="Yes") + poll.add_answer(text="No") + await ctx.send("Poll test", poll=poll) + + +async def setup(bot: commands.Bot) -> None: + await bot.add_cog(Misc()) diff --git a/tests/test_activity.py b/tests/test_activity.py index 98ccb97..06fdfa6 100644 --- a/tests/test_activity.py +++ b/tests/test_activity.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_verify_activity_matches(bot): +async def test_verify_activity_matches(bot: discord.Client) -> None: fake_act = discord.Activity(name="Streaming", url="http://mystreamingfeed.xyz", type=discord.ActivityType.streaming) @@ -17,6 +17,6 @@ async def test_verify_activity_matches(bot): @pytest.mark.asyncio -async def test_verify_no_activity(bot): +async def test_verify_no_activity(bot: discord.Client) -> None: await bot.change_presence(activity=None) assert dpytest.verify().activity().matches(None) diff --git a/tests/test_ban.py b/tests/test_ban.py index e9fe017..35cc6ff 100644 --- a/tests/test_ban.py +++ b/tests/test_ban.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_ban_user(bot: discord.Client): +async def test_ban_user(bot: discord.Client) -> None: guild = bot.guilds[0] member = guild.members[0] await guild.ban(member) @@ -13,7 +13,7 @@ async def test_ban_user(bot: discord.Client): @pytest.mark.asyncio -async def test_unban_user(bot: discord.Client): +async def test_unban_user(bot: discord.Client) -> None: guild = bot.guilds[0] member = guild.members[0] await guild.ban(member) diff --git a/tests/test_configure.py b/tests/test_configure.py index 889ee13..993f823 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_configure_guilds(bot): +async def test_configure_guilds(bot: discord.Client) -> None: dpytest.configure(bot, guilds=2) assert len(bot.guilds) == 2 assert bot.guilds[0].name == "Test Guild 0" @@ -23,12 +23,12 @@ async def test_configure_guilds(bot): @pytest.mark.asyncio -async def test_configure_text_channels(bot): +async def test_configure_text_channels(bot: discord.Client) -> None: dpytest.configure(bot, text_channels=3) guild = bot.guilds[0] assert len(guild.text_channels) == 3 - for num, channel in enumerate(guild.text_channels): - assert channel.name == f"TextChannel_{num}" + for num, chan in enumerate(guild.text_channels): + assert chan.name == f"TextChannel_{num}" dpytest.configure(bot, text_channels=["Fruits", "Videogames", "Coding", "Fun"]) guild = bot.guilds[0] @@ -40,18 +40,19 @@ async def test_configure_text_channels(bot): # we can even use discord.utils.get channel = discord.utils.get(guild.text_channels, name='Videogames') + assert channel is not None assert channel.name == "Videogames" await channel.send("Test Message") assert dpytest.verify().message().content("Test Message") @pytest.mark.asyncio -async def test_configure_voice_channels(bot): +async def test_configure_voice_channels(bot: discord.Client) -> None: dpytest.configure(bot, voice_channels=3) guild = bot.guilds[0] assert len(guild.voice_channels) == 3 - for num, channel in enumerate(guild.voice_channels): - assert channel.name == f"VoiceChannel_{num}" + for num, chan in enumerate(guild.voice_channels): + assert chan.name == f"VoiceChannel_{num}" dpytest.configure(bot, voice_channels=["Fruits", "Videogames", "Coding", "Fun"]) guild = bot.guilds[0] @@ -63,11 +64,12 @@ async def test_configure_voice_channels(bot): # we can even use discord.utils.get channel = discord.utils.get(guild.voice_channels, name='Videogames') + assert channel is not None assert channel.name == "Videogames" @pytest.mark.asyncio -async def test_configure_members(bot): +async def test_configure_members(bot: discord.Client) -> None: dpytest.configure(bot, members=3) guild = bot.guilds[0] assert len(guild.members) == 3 + 1 # because the bot is a member too @@ -84,20 +86,23 @@ async def test_configure_members(bot): # we can even use discord.utils.get william_member = discord.utils.get(guild.members, name='William') + assert william_member is not None assert william_member.name == "William" @pytest.mark.asyncio @pytest.mark.cogs("cogs.echo") -async def test_configure_all(bot): +async def test_configure_all(bot: discord.Client) -> None: dpytest.configure(bot, guilds=["CoolGuild", "LameGuild"], text_channels=["Fruits", "Videogames"], voice_channels=["Apples", "Bananas"], members=["Joe", "Jack", "William", "Averell"]) guild = bot.guilds[1] - channel: discord.TextChannel = discord.utils.get(guild.text_channels, name='Videogames') - jack: discord.Member = discord.utils.get(guild.members, name="Jack") + channel = discord.utils.get(guild.text_channels, name='Videogames') + assert channel is not None + jack = discord.utils.get(guild.members, name="Jack") + assert jack is not None mess = await dpytest.message("!echo Hello, my name is Jack", channel=channel, member=jack) assert mess.author.name == "Jack" - assert mess.channel.name == "Videogames" + assert mess.channel.name == "Videogames" # type: ignore[union-attr] assert dpytest.verify().message().content("Hello, my name is Jack") diff --git a/tests/test_create_channel.py b/tests/test_create_channel.py index 15be798..2e578cd 100644 --- a/tests/test_create_channel.py +++ b/tests/test_create_channel.py @@ -4,20 +4,20 @@ @pytest.mark.asyncio -async def test_create_voice_channel(bot): +async def test_create_voice_channel(bot: discord.Client) -> None: guild = bot.guilds[0] http = bot.http # create_channel checks the value of variables in the parent call context, so we need to set these for it to work self = guild # noqa: F841 name = "voice_channel_1" - channel = await http.create_channel(guild, channel_type=discord.ChannelType.voice.value) - assert channel['type'] == discord.ChannelType.voice + channel = await http.create_channel(guild.id, channel_type=discord.ChannelType.voice.value) + assert channel['type'] == discord.ChannelType.voice.value assert channel['name'] == name @pytest.mark.asyncio -async def test_make_voice_channel(bot): +async def test_make_voice_channel(bot: discord.Client) -> None: guild = bot.guilds[0] bitrate = 100 user_limit = 5 diff --git a/tests/test_dmchannel.py b/tests/test_dmchannel.py index cf619ad..104e620 100644 --- a/tests/test_dmchannel.py +++ b/tests/test_dmchannel.py @@ -1,10 +1,10 @@ - +import discord import pytest import discord.ext.test as dpytest @pytest.mark.asyncio -async def test_dm_send(bot): +async def test_dm_send(bot: discord.Client) -> None: guild = bot.guilds[0] await guild.members[0].send("hi") @@ -13,7 +13,7 @@ async def test_dm_send(bot): @pytest.mark.asyncio @pytest.mark.cogs("cogs.echo") -async def test_dm_message(bot): +async def test_dm_message(bot: discord.Client) -> None: guild = bot.guilds[0] member = guild.members[0] dm = await member.create_dm() diff --git a/tests/test_edit.py b/tests/test_edit.py index 5ba6341..8c8f882 100644 --- a/tests/test_edit.py +++ b/tests/test_edit.py @@ -1,12 +1,14 @@ +import discord import pytest import discord.ext.test as dpytest # noqa: F401 import discord.ext.commands as commands @pytest.mark.asyncio -async def test_edit(bot: commands.Bot): +async def test_edit(bot: commands.Bot) -> None: guild = bot.guilds[0] channel = guild.channels[0] + assert isinstance(channel, discord.TextChannel) mes = await channel.send("Test Message") persisted_mes1 = await channel.fetch_message(mes.id) diff --git a/tests/test_fetch_message.py b/tests/test_fetch_message.py index 1da686c..bb80f51 100644 --- a/tests/test_fetch_message.py +++ b/tests/test_fetch_message.py @@ -4,9 +4,9 @@ @pytest.mark.asyncio -async def test_get_message(bot): +async def test_get_message(bot: discord.Client) -> None: guild = bot.guilds[0] - channel = guild.channels[0] + channel: discord.TextChannel = guild.channels[0] # type: ignore[assignment] message = await channel.send("Test Message") message2 = await channel.fetch_message(message.id) diff --git a/tests/test_get.py b/tests/test_get.py index 75fa5d1..85d494d 100644 --- a/tests/test_get.py +++ b/tests/test_get.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_get_message(bot): +async def test_get_message(bot: discord.Client) -> None: """Dont use this in your code, it's just dummy test. Use verify_message() instead of 'get_message' and 'message.content' """ @@ -17,7 +17,7 @@ async def test_get_message(bot): @pytest.mark.asyncio -async def test_get_message_peek(bot): +async def test_get_message_peek(bot: discord.Client) -> None: """Dont use this in your code, it's just dummy test. Use verify_message() instead of 'get_message' and 'message.content' """ @@ -30,7 +30,7 @@ async def test_get_message_peek(bot): @pytest.mark.asyncio -async def test_get_embed(bot): +async def test_get_embed(bot: discord.Client) -> None: """Dont use this in your code, it's just dummy test. Use verify_embed() instead of 'get_embed' """ @@ -46,7 +46,7 @@ async def test_get_embed(bot): @pytest.mark.asyncio -async def test_get_embed_peek(bot): +async def test_get_embed_peek(bot: discord.Client) -> None: """Dont use this in your code, it's just dummy test. Use verify_embed() instead of 'get_embed' """ diff --git a/tests/test_get_channel_history.py b/tests/test_get_channel_history.py index 3f158d4..330fa6d 100644 --- a/tests/test_get_channel_history.py +++ b/tests/test_get_channel_history.py @@ -1,10 +1,11 @@ import pytest +import discord import discord.ext.test as dpytest # noqa: F401 from discord.utils import get @pytest.mark.asyncio -async def test_get_channel(bot): +async def test_get_channel(bot: discord.Client) -> None: guild = bot.guilds[0] channel_0 = guild.channels[0] @@ -14,11 +15,11 @@ async def test_get_channel(bot): @pytest.mark.asyncio -async def test_get_channel_history(bot): +async def test_get_channel_history(bot: discord.Client) -> None: guild = bot.guilds[0] channel_0 = guild.channels[0] - channel_get = get(guild.channels, name=channel_0.name) + channel_get: discord.TextChannel | None = get(guild.channels, name=channel_0.name) # type: ignore[assignment] assert channel_0 == channel_get diff --git a/tests/test_member_join.py b/tests/test_member.py similarity index 72% rename from tests/test_member_join.py rename to tests/test_member.py index 757d40e..0a41b57 100644 --- a/tests/test_member_join.py +++ b/tests/test_member.py @@ -1,10 +1,11 @@ +import discord import pytest import discord.ext.test as dpytest @pytest.mark.asyncio @pytest.mark.cogs("cogs.greeting") -async def test_member_join(bot): +async def test_member_join(bot: discord.Client) -> None: """Dont use this in your code, it's just dummy test. Use verify_message() instead of 'get_message' and 'message.content' """ @@ -21,3 +22,11 @@ async def test_member_join(bot): await dpytest.run_all_events() # requires for the cov Greeting listner to be executed # noqa: E501 assert dpytest.verify().message().content(f"Welcome {new_member.mention}.") + + +@pytest.mark.asyncio +async def test_fetch_members(bot: discord.Client) -> None: + guild = bot.guilds[0] + + async for member in guild.fetch_members(): + assert member.guild.id == guild.id diff --git a/tests/test_mentions.py b/tests/test_mentions.py index 19b5eba..da0fbd4 100644 --- a/tests/test_mentions.py +++ b/tests/test_mentions.py @@ -1,10 +1,11 @@ import pytest +import discord import discord.ext.test as dpytest @pytest.mark.asyncio -async def test_user_mention(bot): +async def test_user_mention(bot: discord.Client) -> None: guild = bot.guilds[0] mes = await dpytest.message(f"<@{guild.me.id}>") @@ -17,7 +18,7 @@ async def test_user_mention(bot): @pytest.mark.asyncio -async def test_role_mention(bot): +async def test_role_mention(bot: discord.Client) -> None: guild = bot.guilds[0] role = await guild.create_role(name="Test Role") mes = await dpytest.message(f"<@&{role.id}>") @@ -31,7 +32,7 @@ async def test_role_mention(bot): @pytest.mark.asyncio -async def test_channel_mention(bot): +async def test_channel_mention(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.channels[0] mes = await dpytest.message(f"<#{channel.id}>") @@ -45,7 +46,8 @@ async def test_channel_mention(bot): @pytest.mark.asyncio -async def test_bot_mention(bot): +async def test_bot_mention(bot: discord.Client) -> None: + assert bot.user mes = await dpytest.message(f"<@{bot.user.id}>") assert len(mes.mentions) == 1 diff --git a/tests/test_message.py b/tests/test_message.py index 1224b4b..e93af3b 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -4,12 +4,12 @@ @pytest.mark.asyncio -async def test_messasge(bot): +async def test_messasge(bot: discord.Client) -> None: """Test make_message_dict from factory. """ guild = bot.guilds[0] author: discord.Member = guild.members[0] - channel = guild.channels[0] + channel: discord.TextChannel = guild.channels[0] # type: ignore[assignment] attach: discord.Attachment = discord.Attachment( state=dpytest.back.get_state(), data=dpytest.back.facts.make_attachment_dict( @@ -27,3 +27,13 @@ async def test_messasge(bot): message: discord.Message = discord.Message(state=dpytest.back.get_state(), channel=channel, data=message_dict) # noqa: E501,F841 (variable never used) except Exception as err: pytest.fail(str(err)) + + +@pytest.mark.asyncio +@pytest.mark.cogs("cogs.poll") +async def test_message_poll(bot: discord.Client) -> None: + """Test that messages with polls round-trip""" + await dpytest.message("!pollme") + message = dpytest.get_message() + assert message.content == "Poll test" + assert message.poll is not None diff --git a/tests/test_permissions.py b/tests/test_permissions.py index c4a694c..1c1ff9f 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -4,10 +4,8 @@ from discord import PermissionOverwrite -# TODO : fix this -@pytest.mark.skip(reason="test is currently broken, probably set_permission_overrides doing something wrong.") @pytest.mark.asyncio -async def test_permission_setting(bot): +async def test_permission_setting(bot: discord.Client) -> None: """tests, that the framework sets overrides correctly""" g = bot.guilds[0] c = g.text_channels[0] @@ -38,11 +36,9 @@ async def test_permission_setting(bot): assert perm.ban_members is False -# TODO : fix this -@pytest.mark.skip(reason="test is currently broken, probably set_permission_overrides doing something wrong.") @pytest.mark.asyncio @pytest.mark.cogs("cogs.echo") -async def test_bot_send_not_allowed(bot): +async def test_bot_send_not_allowed(bot: discord.Client) -> None: """tests, that a bot gets an Exception, if not allowed to send a message""" g = bot.guilds[0] c = g.text_channels[0] diff --git a/tests/test_reactions.py b/tests/test_reactions.py index 6a59fc4..131b7bc 100644 --- a/tests/test_reactions.py +++ b/tests/test_reactions.py @@ -1,9 +1,10 @@ import pytest +import discord import discord.ext.test as dpytest @pytest.mark.asyncio -async def test_add_reaction(bot): +async def test_add_reaction(bot: discord.Client) -> None: g = bot.guilds[0] c = g.text_channels[0] @@ -16,7 +17,7 @@ async def test_add_reaction(bot): @pytest.mark.asyncio -async def test_remove_reaction(bot): +async def test_remove_reaction(bot: discord.Client) -> None: g = bot.guilds[0] c = g.text_channels[0] @@ -29,7 +30,7 @@ async def test_remove_reaction(bot): @pytest.mark.asyncio -async def test_user_add_reaction(bot): +async def test_user_add_reaction(bot: discord.Client) -> None: g = bot.guilds[0] c = g.text_channels[0] m = g.members[0] @@ -45,7 +46,7 @@ async def test_user_add_reaction(bot): @pytest.mark.asyncio -async def test_user_remove_reaction(bot): +async def test_user_remove_reaction(bot: discord.Client) -> None: g = bot.guilds[0] c = g.text_channels[0] m = g.members[0] diff --git a/tests/test_role.py b/tests/test_role.py index 0c86f77..0090e34 100644 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_add_role(bot): +async def test_add_role(bot: discord.Client) -> None: guild = bot.guilds[0] staff_role = await guild.create_role(name="Staff") # Role object member1 = guild.members[0] # Member @@ -14,7 +14,7 @@ async def test_add_role(bot): @pytest.mark.asyncio -async def test_edit_role(bot): +async def test_edit_role(bot: discord.Client) -> None: await test_add_role(bot=bot) await bot.guilds[0].create_role(name="TestRole") # Role object assert len(bot.guilds[0].roles) == 3 @@ -29,7 +29,7 @@ async def test_edit_role(bot): @pytest.mark.asyncio -async def test_remove_role(bot): +async def test_remove_role(bot: discord.Client) -> None: guild = bot.guilds[0] staff_role = await guild.create_role(name="Staff") # Role object member1 = guild.members[0] # Member @@ -44,7 +44,7 @@ async def test_remove_role(bot): @pytest.mark.asyncio -async def test_remove_role2(bot): +async def test_remove_role2(bot: discord.Client) -> None: guild = bot.guilds[0] staff_role = await guild.create_role(name="Staff") # Role object @@ -55,3 +55,17 @@ async def test_remove_role2(bot): # then remove_role await dpytest.remove_role(0, staff_role) assert staff_role not in guild.members[0].roles + + +@pytest.mark.asyncio +async def test_member_add_roles(bot: discord.Client) -> None: + guild = bot.guilds[0] + member = guild.members[0] + + staff_role = await guild.create_role(name="Staff") + user_role = await guild.create_role(name="User") + + await member.add_roles(*[staff_role, user_role]) + + assert staff_role in member.roles + assert user_role in member.roles diff --git a/tests/test_send.py b/tests/test_send.py index b2fa3c3..999dc1d 100644 --- a/tests/test_send.py +++ b/tests/test_send.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_message(bot): +async def test_message(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -12,7 +12,7 @@ async def test_message(bot): @pytest.mark.asyncio -async def test_embed(bot): +async def test_embed(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] diff --git a/tests/test_utils.py b/tests/test_utils.py index 6623170..536c0a0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,51 +1,52 @@ from copy import deepcopy import pytest from discord import Embed +from discord.ext import commands from discord.ext.test.utils import embed_eq @pytest.mark.asyncio -async def test_embed_eq_direct(bot) -> None: +async def test_embed_eq_direct(bot: commands.Bot) -> None: embed_1: Embed = Embed() embed_2: Embed = Embed() assert embed_eq(embed_1, embed_2) is True @pytest.mark.asyncio -async def test_embed_eq_embed1_is_none(bot) -> None: +async def test_embed_eq_embed1_is_none(bot: commands.Bot) -> None: embed_2: Embed = Embed() assert embed_eq(None, embed_2) is False @pytest.mark.asyncio -async def test_embed_eq_embed2_is_none(bot) -> None: +async def test_embed_eq_embed2_is_none(bot: commands.Bot) -> None: embed_1: Embed = Embed() assert embed_eq(embed_1, None) is False @pytest.mark.asyncio -async def test_embed_eq_attr_title(bot) -> None: +async def test_embed_eq_attr_title(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo") embed_2: Embed = Embed(title="Bar") assert embed_eq(embed_1, embed_2) is False @pytest.mark.asyncio -async def test_embed_eq_attr_description(bot) -> None: +async def test_embed_eq_attr_description(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.") embed_2: Embed = Embed(title="Foo", description="This is a slightly different Foo.") assert embed_eq(embed_1, embed_2) is False @pytest.mark.asyncio -async def test_embed_eq_attr_url(bot) -> None: +async def test_embed_eq_attr_url(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.foo") embed_2: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.bar") assert embed_eq(embed_1, embed_2) is False @pytest.mark.asyncio -async def test_embed_eq_attr_footer(bot) -> None: +async def test_embed_eq_attr_footer(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.foo") embed_1.set_footer(text="This is the footer for Foo.") embed_2: Embed = deepcopy(embed_1) @@ -54,7 +55,7 @@ async def test_embed_eq_attr_footer(bot) -> None: @pytest.mark.asyncio -async def test_embed_eq_attr_image(bot) -> None: +async def test_embed_eq_attr_image(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.foo") embed_1.set_footer(text="This is the footer for Foo.") embed_1.set_image(url="http://image.foo") @@ -64,7 +65,7 @@ async def test_embed_eq_attr_image(bot) -> None: @pytest.mark.asyncio -async def test_embed_eq_attr_fields(bot) -> None: +async def test_embed_eq_attr_fields(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.foo") embed_1.set_footer(text="This is the footer for Foo.") embed_1.set_image(url="http://image.foo") @@ -75,7 +76,7 @@ async def test_embed_eq_attr_fields(bot) -> None: @pytest.mark.asyncio -async def test_embed_eq_attr_equal(bot): +async def test_embed_eq_attr_equal(bot: commands.Bot) -> None: embed_1: Embed = Embed(title="Foo", description="This is a Foo.", url="http://www.foo.foo") embed_1.set_footer(text="This is the footer for Foo.") embed_1.set_image(url="http://image.foo") diff --git a/tests/test_verify_embed.py b/tests/test_verify_embed.py index d68d5a4..7427f0f 100644 --- a/tests/test_verify_embed.py +++ b/tests/test_verify_embed.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_embed(bot): +async def test_embed(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -19,7 +19,7 @@ async def test_embed(bot): @pytest.mark.asyncio -async def test_embed_KO(bot): +async def test_embed_KO(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -34,12 +34,12 @@ async def test_embed_KO(bot): @pytest.mark.asyncio -async def test_embed_assert_nothing(bot): +async def test_embed_assert_nothing(bot: discord.Client) -> None: assert dpytest.verify().message().nothing() @pytest.mark.asyncio -async def test_embed_peek(bot): +async def test_embed_peek(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] diff --git a/tests/test_verify_file.py b/tests/test_verify_file.py index e8c8801..7be5c0b 100644 --- a/tests/test_verify_file.py +++ b/tests/test_verify_file.py @@ -5,7 +5,7 @@ @pytest.mark.asyncio -async def test_verify_file_text(bot): +async def test_verify_file_text(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -16,7 +16,7 @@ async def test_verify_file_text(bot): @pytest.mark.asyncio -async def test_verify_file_jpg(bot): +async def test_verify_file_jpg(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -27,7 +27,7 @@ async def test_verify_file_jpg(bot): @pytest.mark.asyncio -async def test_verify_file_KO(bot): +async def test_verify_file_KO(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] diff --git a/tests/test_verify_message.py b/tests/test_verify_message.py index 20b497f..7e171f4 100644 --- a/tests/test_verify_message.py +++ b/tests/test_verify_message.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio -async def test_message_equals(bot): +async def test_message_equals(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -13,7 +13,7 @@ async def test_message_equals(bot): @pytest.mark.asyncio -async def test_message_not_equals(bot): +async def test_message_not_equals(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -22,7 +22,7 @@ async def test_message_not_equals(bot): @pytest.mark.asyncio -async def test_message_contains_true(bot): +async def test_message_contains_true(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -31,7 +31,7 @@ async def test_message_contains_true(bot): @pytest.mark.asyncio -async def test_message_contains_false(bot): +async def test_message_contains_false(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] @@ -40,12 +40,12 @@ async def test_message_contains_false(bot): @pytest.mark.asyncio -async def test_message_assert_nothing(bot): +async def test_message_assert_nothing(bot: discord.Client) -> None: assert dpytest.verify().message().nothing() @pytest.mark.asyncio -async def test_message_peek(bot): +async def test_message_peek(bot: discord.Client) -> None: guild = bot.guilds[0] channel = guild.text_channels[0] diff --git a/tests/test_voice.py b/tests/test_voice.py index 986412a..6c24c8f 100644 --- a/tests/test_voice.py +++ b/tests/test_voice.py @@ -1,28 +1,30 @@ +import discord import pytest @pytest.mark.asyncio -async def test_bot_join_voice(bot): +async def test_bot_join_voice(bot: discord.Client) -> None: assert not bot.voice_clients await bot.guilds[0].voice_channels[0].connect() assert bot.voice_clients @pytest.mark.asyncio -async def test_bot_leave_voice(bot): - voice_client = await bot.guilds[0].voice_channels[0].connect() +async def test_bot_leave_voice(bot: discord.Client) -> None: + voice_client: discord.VoiceClient = await bot.guilds[0].voice_channels[0].connect() await voice_client.disconnect() assert not bot.voice_clients @pytest.mark.asyncio -async def test_move_member(bot): +async def test_move_member(bot: discord.Client) -> None: guild = bot.guilds[0] voice_channel = guild.voice_channels[0] member = guild.members[0] assert member.voice is None await member.move_to(voice_channel) + assert member.voice is not None assert member.voice.channel == voice_channel await member.move_to(None)