From 1378b6ae255080ee16b24da247c2b14762b7ae59 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:29:34 +0000 Subject: [PATCH 1/4] Initial plan From b1415820b6664647126dd80421f68aa54c2b46fa Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:34:03 +0000 Subject: [PATCH 2/4] Add heat-based automod system with configuration commands Co-authored-by: LightSage <46062298+LightSage@users.noreply.github.com> --- lightning/cogs/automod/cog.py | 277 ++++++++++++++++++++++++- lightning/cogs/automod/models.py | 92 ++++++++ migrations/0021_heat_based_automod.sql | 21 ++ 3 files changed, 388 insertions(+), 2 deletions(-) create mode 100644 migrations/0021_heat_based_automod.sql diff --git a/lightning/cogs/automod/cog.py b/lightning/cogs/automod/cog.py index 14f96a93..3c1b295a 100644 --- a/lightning/cogs/automod/cog.py +++ b/lightning/cogs/automod/cog.py @@ -34,7 +34,7 @@ AutoModDurationResponse, IgnorableEntities) from lightning.cogs.automod.models import (AutomodConfig, GateKeeperConfig, - SpamConfig) + HeatConfig, SpamConfig) from lightning.constants import (AUTOMOD_ADVANCED_EVENT_NAMES_MAPPING, AUTOMOD_ALL_EVENT_NAMES_LITERAL, AUTOMOD_BASIC_EVENT_NAMES_MAPPING, @@ -67,6 +67,7 @@ class AutoMod(LightningCog, required=["Moderation"]): def __init__(self, bot: LightningBot): super().__init__(bot) self.gatekeepers: dict[int, GateKeeperConfig] = {} + self.heat_configs: dict[int, HeatConfig] = {} self.bot.loop.create_task(self.load_all_gatekeepers()) self.bot.add_dynamic_items(ui.GatekeeperVerificationButton, ui.GatekeeperVerificationHoneyPotButton) @@ -96,6 +97,39 @@ def invalidate_gatekeeper(self, guild_id: int): if gtkp := self.gatekeepers.pop(guild_id, None): gtkp.gtkp_loop.cancel() + @cache.cache() + async def get_heat_config(self, guild_id: int) -> Optional[HeatConfig]: + """Gets the heat configuration for a guild. + + Parameters + ---------- + guild_id : int + The guild ID + + Returns + ------- + Optional[HeatConfig] + The heat configuration, or None if not configured + """ + if guild_id in self.heat_configs: + return self.heat_configs[guild_id] + + query = "SELECT * FROM guild_automod_heat_config WHERE guild_id=$1;" + config = await self.bot.pool.fetchrow(query, guild_id) + if not config: + return None + + query = "SELECT * FROM guild_automod_heat_thresholds WHERE guild_id=$1 ORDER BY heat_threshold ASC;" + thresholds = await self.bot.pool.fetch(query, guild_id) + + self.heat_configs[guild_id] = heat_config = HeatConfig(self.bot, config, thresholds) + return heat_config + + def invalidate_heat_config(self, guild_id: int): + """Invalidates the cached heat config for a guild.""" + self.heat_configs.pop(guild_id, None) + self.get_heat_config.invalidate(self, guild_id) + async def cog_check(self, ctx: LightningContext) -> bool: if ctx.guild is None: raise commands.NoPrivateMessage() @@ -419,6 +453,196 @@ async def automod_warn_threshold_remove(self, ctx: GuildContext): await ctx.send("Removed warn threshold!") await self.get_automod_config.invalidate(ctx.guild.id) + # Heat-based automod commands + @automod.group(name='heat', level=CommandLevel.Admin) + async def automod_heat(self, ctx: GuildContext): + """Manage heat-based automod system""" + if ctx.invoked_subcommand is None: + await ctx.send_help('automod heat') + + @automod_heat.command(name='enable', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_enable(self, ctx: GuildContext, decay_seconds: commands.Range[int, 60, 86400] = 3600): + """Enables the heat-based automod system + + Parameters + ---------- + decay_seconds: int + How long (in seconds) it takes for heat to fully decay. Default is 1 hour (3600 seconds). + Minimum 60 seconds, maximum 24 hours (86400 seconds). + """ + query = """INSERT INTO guild_automod_heat_config (guild_id, enabled, decay_seconds) + VALUES ($1, true, $2) + ON CONFLICT (guild_id) + DO UPDATE SET enabled=true, decay_seconds=EXCLUDED.decay_seconds;""" + await self.bot.pool.execute(query, ctx.guild.id, decay_seconds) + self.invalidate_heat_config(ctx.guild.id) + await ctx.send(f"✅ Heat-based automod enabled! Heat will decay over {decay_seconds} seconds.") + + @automod_heat.command(name='disable', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_disable(self, ctx: GuildContext): + """Disables the heat-based automod system""" + query = "UPDATE guild_automod_heat_config SET enabled=false WHERE guild_id=$1;" + resp = await self.bot.pool.execute(query, ctx.guild.id) + + if resp == "UPDATE 0": + await ctx.send("Heat-based automod was not enabled!") + return + + self.invalidate_heat_config(ctx.guild.id) + await ctx.send("✅ Heat-based automod disabled!") + + @automod_heat.command(name='setheat', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_set_value(self, ctx: GuildContext, violation_type: str, heat_value: float): + """Sets the heat value for a specific violation type + + Parameters + ---------- + violation_type: str + The violation type (e.g., 'message-spam', 'invite-spam', 'url-spam', 'mass-mentions', 'message-content-spam') + heat_value: float + The amount of heat to add for this violation + """ + # Ensure heat config exists + query = """INSERT INTO guild_automod_heat_config (guild_id) + VALUES ($1) + ON CONFLICT (guild_id) DO NOTHING;""" + await self.bot.pool.execute(query, ctx.guild.id) + + # Update heat values + query = """UPDATE guild_automod_heat_config + SET heat_per_violation = jsonb_set( + COALESCE(heat_per_violation, '{}'::jsonb), + ARRAY[$2], + to_jsonb($3::text) + ) + WHERE guild_id=$1;""" + await self.bot.pool.execute(query, ctx.guild.id, violation_type, str(heat_value)) + self.invalidate_heat_config(ctx.guild.id) + await ctx.send(f"✅ Set heat value for `{violation_type}` to `{heat_value}`") + + @automod_heat.command(name='addthreshold', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_add_threshold(self, ctx: GuildContext, heat_threshold: int, + punishment: Literal['WARN', 'MUTE', 'KICK', 'BAN'], + duration: Optional[AutoModDuration] = None): + """Adds a heat threshold with a punishment + + Parameters + ---------- + heat_threshold: int + The heat level that triggers this punishment + punishment: Literal['WARN', 'MUTE', 'KICK', 'BAN'] + The punishment to apply + duration: Optional[AutoModDuration] + Duration for temporary punishments (MUTE/BAN only) + """ + # Ensure heat config exists + query = """INSERT INTO guild_automod_heat_config (guild_id) + VALUES ($1) + ON CONFLICT (guild_id) DO NOTHING;""" + await self.bot.pool.execute(query, ctx.guild.id) + + punishment_duration = None + if duration: + punishment_duration = duration.seconds + + query = """INSERT INTO guild_automod_heat_thresholds (guild_id, heat_threshold, punishment_type, punishment_duration) + VALUES ($1, $2, $3, $4);""" + await self.bot.pool.execute(query, ctx.guild.id, heat_threshold, punishment, punishment_duration) + self.invalidate_heat_config(ctx.guild.id) + + duration_str = f" for {duration.human_readable}" if duration else "" + await ctx.send(f"✅ Added threshold: {heat_threshold} heat → {punishment}{duration_str}") + + @automod_heat.command(name='removethreshold', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_remove_threshold(self, ctx: GuildContext, threshold_id: int): + """Removes a heat threshold + + Parameters + ---------- + threshold_id: int + The ID of the threshold to remove + """ + query = "DELETE FROM guild_automod_heat_thresholds WHERE id=$1 AND guild_id=$2;" + resp = await self.bot.pool.execute(query, threshold_id, ctx.guild.id) + + if resp == "DELETE 0": + await ctx.send("Threshold not found!") + return + + self.invalidate_heat_config(ctx.guild.id) + await ctx.send("✅ Threshold removed!") + + @automod_heat.command(name='view', level=CommandLevel.Admin) + async def automod_heat_view(self, ctx: GuildContext): + """Views the current heat configuration""" + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config: + await ctx.send("Heat-based automod is not configured for this server!") + return + + embed = discord.Embed(title="Heat-Based Automod Configuration", color=discord.Color.blue()) + embed.add_field(name="Enabled", value="✅ Yes" if heat_config.enabled else "❌ No", inline=False) + embed.add_field(name="Decay Time", value=f"{heat_config.decay_seconds} seconds", inline=False) + + if heat_config.heat_per_violation: + violations = "\n".join([f"`{k}`: {v}" for k, v in heat_config.heat_per_violation.items()]) + embed.add_field(name="Heat Values", value=violations or "None set", inline=False) + + if heat_config.thresholds: + thresholds = "\n".join([ + f"ID {i}: {t.threshold} heat → {t.punishment}" + + (f" ({t.duration}s)" if t.duration else "") + for i, t in enumerate(heat_config.thresholds, 1) + ]) + embed.add_field(name="Thresholds", value=thresholds, inline=False) + else: + embed.add_field(name="Thresholds", value="None configured", inline=False) + + await ctx.send(embed=embed) + + @automod_heat.command(name='check', level=CommandLevel.Mod) + async def automod_heat_check(self, ctx: GuildContext, member: discord.Member): + """Checks a member's current heat level + + Parameters + ---------- + member: discord.Member + The member to check + """ + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config or not heat_config.enabled: + await ctx.send("Heat-based automod is not enabled!") + return + + heat = await heat_config.get_user_heat(member.id) + await ctx.send(f"{member.mention} currently has **{heat:.1f}** heat.") + + @automod_heat.command(name='reset', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_reset(self, ctx: GuildContext, member: discord.Member): + """Resets a member's heat to 0 + + Parameters + ---------- + member: discord.Member + The member whose heat to reset + """ + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config: + await ctx.send("Heat-based automod is not configured!") + return + + await heat_config.reset_heat(member.id) + await ctx.send(f"✅ Reset heat for {member.mention}") + async def create_automod_config(self, guild: discord.Guild): query = """INSERT INTO guild_automod_config (guild_id) VALUES ($1) @@ -617,6 +841,48 @@ async def _handle_punishment(self, options: GuildAutoModRulePunishment, message: await meth(self, message, options.duration, reason=reason) + async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: str) -> bool: + """Handles heat-based punishment for a violation. + + Parameters + ---------- + message : AutoModMessage + The message that triggered the violation + violation_type : str + The type of violation (e.g., 'message-spam', 'invite-spam') + + Returns + ------- + bool + True if a heat-based punishment was applied, False otherwise + """ + heat_config = await self.get_heat_config(message.guild.id) + if not heat_config or not heat_config.enabled: + return False + + # Add heat and get new level + new_heat = await heat_config.add_heat(message.author.id, violation_type) + + # Check if we should apply a punishment + punishment = heat_config.get_punishment_for_heat(new_heat) + if not punishment: + return False + + # Apply the punishment + automod_rule_name = AUTOMOD_EVENT_NAMES_MAPPING.get(violation_type, "AutoMod rule") + reason = f"{automod_rule_name} triggered (Heat: {new_heat:.1f})" + + meth = self.punishments[punishment.punishment] + + if punishment.punishment not in ("MUTE", "BAN"): + await meth(self, message, reason=reason) + else: + await meth(self, message, punishment.duration, reason=reason) + + # Reset heat after punishment + await heat_config.reset_heat(message.author.id) + return True + async def _delete_tracked_messages(self, messages: set[str], guild: discord.Guild): # Deletes message IDs tracked in AutoMod tmp: Dict[str, List[discord.Object]] = {} @@ -652,7 +918,14 @@ async def handle_bucket(attr_name: str, increment: Optional[Callable[[discord.Me self.bot.dispatch("lightning_guild_automod_rule_triggered", attr_name, message.guild.id) messages = await obj.fetch_responsible_messages(message) await obj.reset_bucket(message) - await self._handle_punishment(obj.punishment, message, attr_name) + + # Try heat-based punishment first + heat_applied = await self._handle_heat_punishment(message, attr_name) + + # If no heat punishment was applied, use the configured rule punishment + if not heat_applied: + await self._handle_punishment(obj.punishment, message, attr_name) + if obj.punishment.type != "BAN": await self._delete_tracked_messages(messages, message.guild) diff --git a/lightning/cogs/automod/models.py b/lightning/cogs/automod/models.py index 6f53b6fa..c9a9121a 100644 --- a/lightning/cogs/automod/models.py +++ b/lightning/cogs/automod/models.py @@ -284,3 +284,95 @@ async def disable(self): if members: await self.bot.redis_pool.lpush(f"lightning:automod:gatekeeper:{self.guild_id}:remove", *members) self.members.clear() + + +class HeatThreshold: + """Represents a heat threshold with its associated punishment.""" + __slots__ = ("threshold", "punishment", "duration") + + def __init__(self, record: asyncpg.Record) -> None: + self.threshold: int = record['heat_threshold'] + self.punishment: str = record['punishment_type'] + self.duration: Optional[int] = record.get('punishment_duration') + + +class HeatConfig: + """Configuration for the heat-based automod system.""" + __slots__ = ("guild_id", "enabled", "decay_seconds", "heat_per_violation", "thresholds", "bot") + + def __init__(self, bot: LightningBot, config: asyncpg.Record, thresholds: list[asyncpg.Record]) -> None: + self.bot = bot + self.guild_id: int = config['guild_id'] + self.enabled: bool = config['enabled'] + self.decay_seconds: int = config['decay_seconds'] + self.heat_per_violation: dict = config.get('heat_per_violation', {}) + self.thresholds: list[HeatThreshold] = sorted( + [HeatThreshold(t) for t in thresholds], + key=lambda x: x.threshold + ) + + async def get_user_heat(self, user_id: int) -> float: + """Gets the current heat level for a user. + + Returns + ------- + float + The current heat level (0.0 if no heat or expired) + """ + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + heat = await self.bot.redis_pool.get(key) + return float(heat) if heat else 0.0 + + async def add_heat(self, user_id: int, violation_type: str) -> float: + """Adds heat to a user for a violation. + + Parameters + ---------- + user_id : int + The user ID + violation_type : str + The type of violation (e.g., 'message-spam', 'invite-spam') + + Returns + ------- + float + The new heat level + """ + heat_to_add = self.heat_per_violation.get(violation_type, 1.0) + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + + # Add heat and set expiry + pipe = self.bot.redis_pool.pipeline() + pipe.incrbyfloat(key, heat_to_add) + pipe.expire(key, self.decay_seconds) + result = await pipe.execute() + return float(result[0]) + + async def reset_heat(self, user_id: int) -> None: + """Resets a user's heat to 0. + + Parameters + ---------- + user_id : int + The user ID to reset + """ + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + await self.bot.redis_pool.delete(key) + + def get_punishment_for_heat(self, heat: float) -> Optional[HeatThreshold]: + """Gets the appropriate punishment for a heat level. + + Parameters + ---------- + heat : float + The current heat level + + Returns + ------- + Optional[HeatThreshold] + The punishment to apply, or None if no threshold is met + """ + for threshold in reversed(self.thresholds): + if heat >= threshold.threshold: + return threshold + return None diff --git a/migrations/0021_heat_based_automod.sql b/migrations/0021_heat_based_automod.sql new file mode 100644 index 00000000..b8bf9582 --- /dev/null +++ b/migrations/0021_heat_based_automod.sql @@ -0,0 +1,21 @@ +-- Heat-based automod system +-- Stores configuration for heat thresholds and punishments + +CREATE TABLE IF NOT EXISTS guild_automod_heat_config +( + guild_id BIGINT NOT NULL REFERENCES guilds (id) ON DELETE CASCADE PRIMARY KEY, + enabled BOOLEAN DEFAULT 'f', + decay_seconds INT DEFAULT 3600, -- How long it takes for heat to fully decay (1 hour default) + heat_per_violation JSONB DEFAULT '{}'::jsonb -- Maps violation types to heat values +); + +CREATE TABLE IF NOT EXISTS guild_automod_heat_thresholds +( + id BIGINT GENERATED BY DEFAULT AS IDENTITY UNIQUE PRIMARY KEY, + guild_id BIGINT NOT NULL REFERENCES guilds (id) ON DELETE CASCADE, + heat_threshold INT NOT NULL, -- Heat level that triggers this punishment + punishment_type automod_punishment_enum NOT NULL, + punishment_duration INT -- Duration in seconds for temporary punishments +); + +CREATE INDEX IF NOT EXISTS guild_automod_heat_thresholds_guild_id_idx ON guild_automod_heat_thresholds(guild_id); From d3e47bf1793edffaea845cfbfdd6dce245cddcb4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:35:56 +0000 Subject: [PATCH 3/4] Fix threshold ID tracking and add comprehensive tests Co-authored-by: LightSage <46062298+LightSage@users.noreply.github.com> --- lightning/cogs/automod/cog.py | 4 +- lightning/cogs/automod/models.py | 3 +- tests/test_heat_system.py | 137 +++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 tests/test_heat_system.py diff --git a/lightning/cogs/automod/cog.py b/lightning/cogs/automod/cog.py index 3c1b295a..ceb40617 100644 --- a/lightning/cogs/automod/cog.py +++ b/lightning/cogs/automod/cog.py @@ -596,9 +596,9 @@ async def automod_heat_view(self, ctx: GuildContext): if heat_config.thresholds: thresholds = "\n".join([ - f"ID {i}: {t.threshold} heat → {t.punishment}" + + f"ID {t.id}: {t.threshold} heat → {t.punishment}" + (f" ({t.duration}s)" if t.duration else "") - for i, t in enumerate(heat_config.thresholds, 1) + for t in heat_config.thresholds ]) embed.add_field(name="Thresholds", value=thresholds, inline=False) else: diff --git a/lightning/cogs/automod/models.py b/lightning/cogs/automod/models.py index c9a9121a..a11e7eaa 100644 --- a/lightning/cogs/automod/models.py +++ b/lightning/cogs/automod/models.py @@ -288,9 +288,10 @@ async def disable(self): class HeatThreshold: """Represents a heat threshold with its associated punishment.""" - __slots__ = ("threshold", "punishment", "duration") + __slots__ = ("id", "threshold", "punishment", "duration") def __init__(self, record: asyncpg.Record) -> None: + self.id: int = record['id'] self.threshold: int = record['heat_threshold'] self.punishment: str = record['punishment_type'] self.duration: Optional[int] = record.get('punishment_duration') diff --git a/tests/test_heat_system.py b/tests/test_heat_system.py new file mode 100644 index 00000000..d66dfaa0 --- /dev/null +++ b/tests/test_heat_system.py @@ -0,0 +1,137 @@ +""" +Unit tests for the heat-based automod system +""" +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock + + +class MockRecord: + """Mock asyncpg.Record for testing""" + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + def get(self, key, default=None): + return self._data.get(key, default) + + +class TestHeatConfig(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Set up test fixtures""" + self.mock_bot = Mock() + self.mock_bot.redis_pool = AsyncMock() + + # Create a sample heat config record + config_record = MockRecord({ + 'guild_id': 123456789, + 'enabled': True, + 'decay_seconds': 3600, + 'heat_per_violation': {'message-spam': 2.0, 'invite-spam': 5.0} + }) + + # Create sample threshold records + threshold_records = [ + MockRecord({ + 'id': 1, + 'heat_threshold': 10, + 'punishment_type': 'WARN', + 'punishment_duration': None + }), + MockRecord({ + 'id': 2, + 'heat_threshold': 20, + 'punishment_type': 'MUTE', + 'punishment_duration': 3600 + }), + MockRecord({ + 'id': 3, + 'heat_threshold': 30, + 'punishment_type': 'KICK', + 'punishment_duration': None + }) + ] + + # Import here to avoid dependency issues + import sys + import os + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + from lightning.cogs.automod.models import HeatConfig + + self.heat_config = HeatConfig(self.mock_bot, config_record, threshold_records) + + def test_heat_config_initialization(self): + """Test that HeatConfig initializes correctly""" + self.assertEqual(self.heat_config.guild_id, 123456789) + self.assertTrue(self.heat_config.enabled) + self.assertEqual(self.heat_config.decay_seconds, 3600) + self.assertEqual(self.heat_config.heat_per_violation['message-spam'], 2.0) + self.assertEqual(len(self.heat_config.thresholds), 3) + + def test_thresholds_sorted(self): + """Test that thresholds are sorted by heat level""" + thresholds = [t.threshold for t in self.heat_config.thresholds] + self.assertEqual(thresholds, [10, 20, 30]) + + async def test_add_heat(self): + """Test adding heat to a user""" + # Mock Redis response + self.mock_bot.redis_pool.pipeline.return_value.execute = AsyncMock(return_value=[15.0, True]) + + new_heat = await self.heat_config.add_heat(987654321, 'message-spam') + self.assertEqual(new_heat, 15.0) + + async def test_get_user_heat(self): + """Test getting user heat""" + # Mock Redis response + self.mock_bot.redis_pool.get = AsyncMock(return_value='12.5') + + heat = await self.heat_config.get_user_heat(987654321) + self.assertEqual(heat, 12.5) + + async def test_get_user_heat_no_heat(self): + """Test getting user heat when user has no heat""" + # Mock Redis response for no heat + self.mock_bot.redis_pool.get = AsyncMock(return_value=None) + + heat = await self.heat_config.get_user_heat(987654321) + self.assertEqual(heat, 0.0) + + def test_get_punishment_for_heat(self): + """Test getting the correct punishment for a heat level""" + # Test heat below all thresholds + punishment = self.heat_config.get_punishment_for_heat(5.0) + self.assertIsNone(punishment) + + # Test heat at first threshold + punishment = self.heat_config.get_punishment_for_heat(10.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'WARN') + + # Test heat between thresholds + punishment = self.heat_config.get_punishment_for_heat(15.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'WARN') + + # Test heat at second threshold + punishment = self.heat_config.get_punishment_for_heat(20.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'MUTE') + self.assertEqual(punishment.duration, 3600) + + # Test heat above all thresholds + punishment = self.heat_config.get_punishment_for_heat(50.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'KICK') + + async def test_reset_heat(self): + """Test resetting user heat""" + self.mock_bot.redis_pool.delete = AsyncMock(return_value=1) + + await self.heat_config.reset_heat(987654321) + self.mock_bot.redis_pool.delete.assert_called_once() + + +if __name__ == '__main__': + unittest.main() From 5ca33e5a7087a5aadc251a98a235480e0a8b2536 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:40:32 +0000 Subject: [PATCH 4/4] Address code review feedback: add validation and fix punishment tracking Co-authored-by: LightSage <46062298+LightSage@users.noreply.github.com> --- lightning/cogs/automod/cog.py | 69 ++++++++++++++++++++++++-------- lightning/cogs/automod/models.py | 3 ++ 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/lightning/cogs/automod/cog.py b/lightning/cogs/automod/cog.py index ceb40617..f76a5abd 100644 --- a/lightning/cogs/automod/cog.py +++ b/lightning/cogs/automod/cog.py @@ -97,7 +97,6 @@ def invalidate_gatekeeper(self, guild_id: int): if gtkp := self.gatekeepers.pop(guild_id, None): gtkp.gtkp_loop.cancel() - @cache.cache() async def get_heat_config(self, guild_id: int) -> Optional[HeatConfig]: """Gets the heat configuration for a guild. @@ -128,7 +127,6 @@ async def get_heat_config(self, guild_id: int) -> Optional[HeatConfig]: def invalidate_heat_config(self, guild_id: int): """Invalidates the cached heat config for a guild.""" self.heat_configs.pop(guild_id, None) - self.get_heat_config.invalidate(self, guild_id) async def cog_check(self, ctx: LightningContext) -> bool: if ctx.guild is None: @@ -503,8 +501,23 @@ async def automod_heat_set_value(self, ctx: GuildContext, violation_type: str, h violation_type: str The violation type (e.g., 'message-spam', 'invite-spam', 'url-spam', 'mass-mentions', 'message-content-spam') heat_value: float - The amount of heat to add for this violation + The amount of heat to add for this violation (must be positive) """ + # Validate heat value is positive + if heat_value <= 0: + await ctx.send("❌ Heat value must be a positive number!") + return + + # Validate violation type + valid_types = ['message-spam', 'invite-spam', 'url-spam', 'mass-mentions', 'message-content-spam', + 'message_spam', 'invite_spam', 'url_spam', 'mass_mentions', 'message_content_spam'] + if violation_type not in valid_types: + await ctx.send(f"❌ Invalid violation type! Valid types are: {', '.join(set(valid_types))}") + return + + # Normalize to hyphenated format for storage + violation_type = violation_type.replace('_', '-') + # Ensure heat config exists query = """INSERT INTO guild_automod_heat_config (guild_id) VALUES ($1) @@ -533,18 +546,36 @@ async def automod_heat_add_threshold(self, ctx: GuildContext, heat_threshold: in Parameters ---------- heat_threshold: int - The heat level that triggers this punishment + The heat level that triggers this punishment (must be positive) punishment: Literal['WARN', 'MUTE', 'KICK', 'BAN'] The punishment to apply duration: Optional[AutoModDuration] Duration for temporary punishments (MUTE/BAN only) """ + # Validate heat threshold is positive + if heat_threshold <= 0: + await ctx.send("❌ Heat threshold must be a positive number!") + return + + # Validate duration is only provided for MUTE/BAN + if duration and punishment not in ('MUTE', 'BAN'): + await ctx.send(f"❌ Duration can only be specified for MUTE or BAN punishments, not {punishment}!") + return + # Ensure heat config exists query = """INSERT INTO guild_automod_heat_config (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;""" await self.bot.pool.execute(query, ctx.guild.id) + # Check for duplicate threshold + query = """SELECT id FROM guild_automod_heat_thresholds + WHERE guild_id=$1 AND heat_threshold=$2;""" + existing = await self.bot.pool.fetchrow(query, ctx.guild.id, heat_threshold) + if existing: + await ctx.send(f"❌ A threshold already exists for {heat_threshold} heat (ID: {existing['id']}). Remove it first if you want to change it.") + return + punishment_duration = None if duration: punishment_duration = duration.seconds @@ -841,7 +872,7 @@ async def _handle_punishment(self, options: GuildAutoModRulePunishment, message: await meth(self, message, options.duration, reason=reason) - async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: str) -> bool: + async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: str) -> Optional[str]: """Handles heat-based punishment for a violation. Parameters @@ -849,16 +880,17 @@ async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: message : AutoModMessage The message that triggered the violation violation_type : str - The type of violation (e.g., 'message-spam', 'invite-spam') + The type of violation (e.g., 'message_spam', 'invite_spam', 'url_spam', 'mass_mentions', 'message_content_spam') + These are the internal attribute names used by the automod system. Returns ------- - bool - True if a heat-based punishment was applied, False otherwise + Optional[str] + The punishment type that was applied (e.g., 'WARN', 'MUTE', 'KICK', 'BAN'), or None if no punishment was applied """ heat_config = await self.get_heat_config(message.guild.id) if not heat_config or not heat_config.enabled: - return False + return None # Add heat and get new level new_heat = await heat_config.add_heat(message.author.id, violation_type) @@ -866,10 +898,10 @@ async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: # Check if we should apply a punishment punishment = heat_config.get_punishment_for_heat(new_heat) if not punishment: - return False + return None # Apply the punishment - automod_rule_name = AUTOMOD_EVENT_NAMES_MAPPING.get(violation_type, "AutoMod rule") + automod_rule_name = AUTOMOD_EVENT_NAMES_MAPPING.get(violation_type.replace('_', '-'), "AutoMod rule") reason = f"{automod_rule_name} triggered (Heat: {new_heat:.1f})" meth = self.punishments[punishment.punishment] @@ -881,7 +913,7 @@ async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: # Reset heat after punishment await heat_config.reset_heat(message.author.id) - return True + return punishment.punishment async def _delete_tracked_messages(self, messages: set[str], guild: discord.Guild): # Deletes message IDs tracked in AutoMod @@ -920,13 +952,18 @@ async def handle_bucket(attr_name: str, increment: Optional[Callable[[discord.Me await obj.reset_bucket(message) # Try heat-based punishment first - heat_applied = await self._handle_heat_punishment(message, attr_name) + heat_punishment_type = await self._handle_heat_punishment(message, attr_name) - # If no heat punishment was applied, use the configured rule punishment - if not heat_applied: + # Determine which punishment was actually applied + if heat_punishment_type: + applied_punishment_type = heat_punishment_type + else: + # If no heat punishment was applied, use the configured rule punishment await self._handle_punishment(obj.punishment, message, attr_name) + applied_punishment_type = str(obj.punishment.type) - if obj.punishment.type != "BAN": + # Only delete messages if the punishment wasn't a BAN + if applied_punishment_type != "BAN": await self._delete_tracked_messages(messages, message.guild) await handle_bucket('mass_mentions', lambda m: len(m.mentions) + len(m.role_mentions)) diff --git a/lightning/cogs/automod/models.py b/lightning/cogs/automod/models.py index a11e7eaa..45dd14a5 100644 --- a/lightning/cogs/automod/models.py +++ b/lightning/cogs/automod/models.py @@ -343,6 +343,9 @@ async def add_heat(self, user_id: int, violation_type: str) -> float: key = f"lightning:automod:heat:{self.guild_id}:{user_id}" # Add heat and set expiry + # Note: EXPIRE resets the TTL on each violation, meaning heat decays after + # a period of inactivity rather than gradually over time. This is intentional - + # active violators should maintain high heat, and it only decays when they stop. pipe = self.bot.redis_pool.pipeline() pipe.incrbyfloat(key, heat_to_add) pipe.expire(key, self.decay_seconds)