diff --git a/discord/ext/test/backend.py b/discord/ext/test/backend.py index 5afb4dc..d734460 100644 --- a/discord/ext/test/backend.py +++ b/discord/ext/test/backend.py @@ -221,15 +221,7 @@ async def edit_message(self, channel_id: int, message_id: int, **fields: dhttp.M await callbacks.dispatch_event("edit_message", message.channel, message, fields) - out = facts.dict_from_message(message) - payload = fields.get("params").payload - # TODO : do something for files and stuff. - # if params.files: - # return self.request(r, files=params.files, form=params.multipart) - # else: - # return self.request(r, json=params.payload) - out.update(payload) - return out + return edit_message(message, **fields) async def add_reaction(self, channel_id: int, message_id: int, emoji: str) -> None: locs = _get_higher_locs(1) @@ -853,6 +845,27 @@ def make_message( return state._get_message(data["id"]) +def edit_message( + message: discord.Message, **fields: dhttp.MultipartParameters +) -> _types.JsonDict: + data = facts.dict_from_message(message) + payload = fields.get("params").payload + # TODO : do something for files and stuff. + # if params.files: + # return self.request(r, files=params.files, form=params.multipart) + # else: + # return self.request(r, json=params.payload) + data.update(payload) + + i = 0 + while i < len(_cur_config.messages[message.channel.id]): + if _cur_config.messages[message.channel.id][i].get("id") == data.get("id"): + _cur_config.messages[message.channel.id][i] = data + i += 1 + + return data + + MEMBER_MENTION: typing.Pattern = re.compile(r"<@!?[0-9]{17,21}>", re.MULTILINE) ROLE_MENTION: typing.Pattern = re.compile(r"<@&([0-9]{17,21})>", re.MULTILINE) CHANNEL_MENTION: typing.Pattern = re.compile(r"<#[0-9]{17,21}>", re.MULTILINE) diff --git a/tests/test_edit.py b/tests/test_edit.py index b1dc888..49847a7 100644 --- a/tests/test_edit.py +++ b/tests/test_edit.py @@ -9,6 +9,10 @@ async def test_edit(bot): channel = guild.channels[0] mes = await channel.send("Test Message") + persisted_mes1 = await channel.fetch_message(mes.id) edited_mes = await mes.edit(content="New Message") + persisted_mes2 = await channel.fetch_message(mes.id) assert edited_mes.content == "New Message" + assert persisted_mes1.content == "Test Message" + assert persisted_mes2.content == "New Message"