diff --git a/lightning/cogs/automod/cog.py b/lightning/cogs/automod/cog.py index 14f96a93..f76a5abd 100644 --- a/lightning/cogs/automod/cog.py +++ b/lightning/cogs/automod/cog.py @@ -34,7 +34,7 @@ AutoModDurationResponse, IgnorableEntities) from lightning.cogs.automod.models import (AutomodConfig, GateKeeperConfig, - SpamConfig) + HeatConfig, SpamConfig) from lightning.constants import (AUTOMOD_ADVANCED_EVENT_NAMES_MAPPING, AUTOMOD_ALL_EVENT_NAMES_LITERAL, AUTOMOD_BASIC_EVENT_NAMES_MAPPING, @@ -67,6 +67,7 @@ class AutoMod(LightningCog, required=["Moderation"]): def __init__(self, bot: LightningBot): super().__init__(bot) self.gatekeepers: dict[int, GateKeeperConfig] = {} + self.heat_configs: dict[int, HeatConfig] = {} self.bot.loop.create_task(self.load_all_gatekeepers()) self.bot.add_dynamic_items(ui.GatekeeperVerificationButton, ui.GatekeeperVerificationHoneyPotButton) @@ -96,6 +97,37 @@ def invalidate_gatekeeper(self, guild_id: int): if gtkp := self.gatekeepers.pop(guild_id, None): gtkp.gtkp_loop.cancel() + async def get_heat_config(self, guild_id: int) -> Optional[HeatConfig]: + """Gets the heat configuration for a guild. + + Parameters + ---------- + guild_id : int + The guild ID + + Returns + ------- + Optional[HeatConfig] + The heat configuration, or None if not configured + """ + if guild_id in self.heat_configs: + return self.heat_configs[guild_id] + + query = "SELECT * FROM guild_automod_heat_config WHERE guild_id=$1;" + config = await self.bot.pool.fetchrow(query, guild_id) + if not config: + return None + + query = "SELECT * FROM guild_automod_heat_thresholds WHERE guild_id=$1 ORDER BY heat_threshold ASC;" + thresholds = await self.bot.pool.fetch(query, guild_id) + + self.heat_configs[guild_id] = heat_config = HeatConfig(self.bot, config, thresholds) + return heat_config + + def invalidate_heat_config(self, guild_id: int): + """Invalidates the cached heat config for a guild.""" + self.heat_configs.pop(guild_id, None) + async def cog_check(self, ctx: LightningContext) -> bool: if ctx.guild is None: raise commands.NoPrivateMessage() @@ -419,6 +451,229 @@ async def automod_warn_threshold_remove(self, ctx: GuildContext): await ctx.send("Removed warn threshold!") await self.get_automod_config.invalidate(ctx.guild.id) + # Heat-based automod commands + @automod.group(name='heat', level=CommandLevel.Admin) + async def automod_heat(self, ctx: GuildContext): + """Manage heat-based automod system""" + if ctx.invoked_subcommand is None: + await ctx.send_help('automod heat') + + @automod_heat.command(name='enable', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_enable(self, ctx: GuildContext, decay_seconds: commands.Range[int, 60, 86400] = 3600): + """Enables the heat-based automod system + + Parameters + ---------- + decay_seconds: int + How long (in seconds) it takes for heat to fully decay. Default is 1 hour (3600 seconds). + Minimum 60 seconds, maximum 24 hours (86400 seconds). + """ + query = """INSERT INTO guild_automod_heat_config (guild_id, enabled, decay_seconds) + VALUES ($1, true, $2) + ON CONFLICT (guild_id) + DO UPDATE SET enabled=true, decay_seconds=EXCLUDED.decay_seconds;""" + await self.bot.pool.execute(query, ctx.guild.id, decay_seconds) + self.invalidate_heat_config(ctx.guild.id) + await ctx.send(f"✅ Heat-based automod enabled! Heat will decay over {decay_seconds} seconds.") + + @automod_heat.command(name='disable', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_disable(self, ctx: GuildContext): + """Disables the heat-based automod system""" + query = "UPDATE guild_automod_heat_config SET enabled=false WHERE guild_id=$1;" + resp = await self.bot.pool.execute(query, ctx.guild.id) + + if resp == "UPDATE 0": + await ctx.send("Heat-based automod was not enabled!") + return + + self.invalidate_heat_config(ctx.guild.id) + await ctx.send("✅ Heat-based automod disabled!") + + @automod_heat.command(name='setheat', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_set_value(self, ctx: GuildContext, violation_type: str, heat_value: float): + """Sets the heat value for a specific violation type + + Parameters + ---------- + violation_type: str + The violation type (e.g., 'message-spam', 'invite-spam', 'url-spam', 'mass-mentions', 'message-content-spam') + heat_value: float + The amount of heat to add for this violation (must be positive) + """ + # Validate heat value is positive + if heat_value <= 0: + await ctx.send("❌ Heat value must be a positive number!") + return + + # Validate violation type + valid_types = ['message-spam', 'invite-spam', 'url-spam', 'mass-mentions', 'message-content-spam', + 'message_spam', 'invite_spam', 'url_spam', 'mass_mentions', 'message_content_spam'] + if violation_type not in valid_types: + await ctx.send(f"❌ Invalid violation type! Valid types are: {', '.join(set(valid_types))}") + return + + # Normalize to hyphenated format for storage + violation_type = violation_type.replace('_', '-') + + # Ensure heat config exists + query = """INSERT INTO guild_automod_heat_config (guild_id) + VALUES ($1) + ON CONFLICT (guild_id) DO NOTHING;""" + await self.bot.pool.execute(query, ctx.guild.id) + + # Update heat values + query = """UPDATE guild_automod_heat_config + SET heat_per_violation = jsonb_set( + COALESCE(heat_per_violation, '{}'::jsonb), + ARRAY[$2], + to_jsonb($3::text) + ) + WHERE guild_id=$1;""" + await self.bot.pool.execute(query, ctx.guild.id, violation_type, str(heat_value)) + self.invalidate_heat_config(ctx.guild.id) + await ctx.send(f"✅ Set heat value for `{violation_type}` to `{heat_value}`") + + @automod_heat.command(name='addthreshold', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_add_threshold(self, ctx: GuildContext, heat_threshold: int, + punishment: Literal['WARN', 'MUTE', 'KICK', 'BAN'], + duration: Optional[AutoModDuration] = None): + """Adds a heat threshold with a punishment + + Parameters + ---------- + heat_threshold: int + The heat level that triggers this punishment (must be positive) + punishment: Literal['WARN', 'MUTE', 'KICK', 'BAN'] + The punishment to apply + duration: Optional[AutoModDuration] + Duration for temporary punishments (MUTE/BAN only) + """ + # Validate heat threshold is positive + if heat_threshold <= 0: + await ctx.send("❌ Heat threshold must be a positive number!") + return + + # Validate duration is only provided for MUTE/BAN + if duration and punishment not in ('MUTE', 'BAN'): + await ctx.send(f"❌ Duration can only be specified for MUTE or BAN punishments, not {punishment}!") + return + + # Ensure heat config exists + query = """INSERT INTO guild_automod_heat_config (guild_id) + VALUES ($1) + ON CONFLICT (guild_id) DO NOTHING;""" + await self.bot.pool.execute(query, ctx.guild.id) + + # Check for duplicate threshold + query = """SELECT id FROM guild_automod_heat_thresholds + WHERE guild_id=$1 AND heat_threshold=$2;""" + existing = await self.bot.pool.fetchrow(query, ctx.guild.id, heat_threshold) + if existing: + await ctx.send(f"❌ A threshold already exists for {heat_threshold} heat (ID: {existing['id']}). Remove it first if you want to change it.") + return + + punishment_duration = None + if duration: + punishment_duration = duration.seconds + + query = """INSERT INTO guild_automod_heat_thresholds (guild_id, heat_threshold, punishment_type, punishment_duration) + VALUES ($1, $2, $3, $4);""" + await self.bot.pool.execute(query, ctx.guild.id, heat_threshold, punishment, punishment_duration) + self.invalidate_heat_config(ctx.guild.id) + + duration_str = f" for {duration.human_readable}" if duration else "" + await ctx.send(f"✅ Added threshold: {heat_threshold} heat → {punishment}{duration_str}") + + @automod_heat.command(name='removethreshold', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_remove_threshold(self, ctx: GuildContext, threshold_id: int): + """Removes a heat threshold + + Parameters + ---------- + threshold_id: int + The ID of the threshold to remove + """ + query = "DELETE FROM guild_automod_heat_thresholds WHERE id=$1 AND guild_id=$2;" + resp = await self.bot.pool.execute(query, threshold_id, ctx.guild.id) + + if resp == "DELETE 0": + await ctx.send("Threshold not found!") + return + + self.invalidate_heat_config(ctx.guild.id) + await ctx.send("✅ Threshold removed!") + + @automod_heat.command(name='view', level=CommandLevel.Admin) + async def automod_heat_view(self, ctx: GuildContext): + """Views the current heat configuration""" + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config: + await ctx.send("Heat-based automod is not configured for this server!") + return + + embed = discord.Embed(title="Heat-Based Automod Configuration", color=discord.Color.blue()) + embed.add_field(name="Enabled", value="✅ Yes" if heat_config.enabled else "❌ No", inline=False) + embed.add_field(name="Decay Time", value=f"{heat_config.decay_seconds} seconds", inline=False) + + if heat_config.heat_per_violation: + violations = "\n".join([f"`{k}`: {v}" for k, v in heat_config.heat_per_violation.items()]) + embed.add_field(name="Heat Values", value=violations or "None set", inline=False) + + if heat_config.thresholds: + thresholds = "\n".join([ + f"ID {t.id}: {t.threshold} heat → {t.punishment}" + + (f" ({t.duration}s)" if t.duration else "") + for t in heat_config.thresholds + ]) + embed.add_field(name="Thresholds", value=thresholds, inline=False) + else: + embed.add_field(name="Thresholds", value="None configured", inline=False) + + await ctx.send(embed=embed) + + @automod_heat.command(name='check', level=CommandLevel.Mod) + async def automod_heat_check(self, ctx: GuildContext, member: discord.Member): + """Checks a member's current heat level + + Parameters + ---------- + member: discord.Member + The member to check + """ + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config or not heat_config.enabled: + await ctx.send("Heat-based automod is not enabled!") + return + + heat = await heat_config.get_user_heat(member.id) + await ctx.send(f"{member.mention} currently has **{heat:.1f}** heat.") + + @automod_heat.command(name='reset', level=CommandLevel.Admin) + @is_server_manager() + async def automod_heat_reset(self, ctx: GuildContext, member: discord.Member): + """Resets a member's heat to 0 + + Parameters + ---------- + member: discord.Member + The member whose heat to reset + """ + heat_config = await self.get_heat_config(ctx.guild.id) + + if not heat_config: + await ctx.send("Heat-based automod is not configured!") + return + + await heat_config.reset_heat(member.id) + await ctx.send(f"✅ Reset heat for {member.mention}") + async def create_automod_config(self, guild: discord.Guild): query = """INSERT INTO guild_automod_config (guild_id) VALUES ($1) @@ -617,6 +872,49 @@ async def _handle_punishment(self, options: GuildAutoModRulePunishment, message: await meth(self, message, options.duration, reason=reason) + async def _handle_heat_punishment(self, message: AutoModMessage, violation_type: str) -> Optional[str]: + """Handles heat-based punishment for a violation. + + Parameters + ---------- + message : AutoModMessage + The message that triggered the violation + violation_type : str + The type of violation (e.g., 'message_spam', 'invite_spam', 'url_spam', 'mass_mentions', 'message_content_spam') + These are the internal attribute names used by the automod system. + + Returns + ------- + Optional[str] + The punishment type that was applied (e.g., 'WARN', 'MUTE', 'KICK', 'BAN'), or None if no punishment was applied + """ + heat_config = await self.get_heat_config(message.guild.id) + if not heat_config or not heat_config.enabled: + return None + + # Add heat and get new level + new_heat = await heat_config.add_heat(message.author.id, violation_type) + + # Check if we should apply a punishment + punishment = heat_config.get_punishment_for_heat(new_heat) + if not punishment: + return None + + # Apply the punishment + automod_rule_name = AUTOMOD_EVENT_NAMES_MAPPING.get(violation_type.replace('_', '-'), "AutoMod rule") + reason = f"{automod_rule_name} triggered (Heat: {new_heat:.1f})" + + meth = self.punishments[punishment.punishment] + + if punishment.punishment not in ("MUTE", "BAN"): + await meth(self, message, reason=reason) + else: + await meth(self, message, punishment.duration, reason=reason) + + # Reset heat after punishment + await heat_config.reset_heat(message.author.id) + return punishment.punishment + async def _delete_tracked_messages(self, messages: set[str], guild: discord.Guild): # Deletes message IDs tracked in AutoMod tmp: Dict[str, List[discord.Object]] = {} @@ -652,8 +950,20 @@ async def handle_bucket(attr_name: str, increment: Optional[Callable[[discord.Me self.bot.dispatch("lightning_guild_automod_rule_triggered", attr_name, message.guild.id) messages = await obj.fetch_responsible_messages(message) await obj.reset_bucket(message) - await self._handle_punishment(obj.punishment, message, attr_name) - if obj.punishment.type != "BAN": + + # Try heat-based punishment first + heat_punishment_type = await self._handle_heat_punishment(message, attr_name) + + # Determine which punishment was actually applied + if heat_punishment_type: + applied_punishment_type = heat_punishment_type + else: + # If no heat punishment was applied, use the configured rule punishment + await self._handle_punishment(obj.punishment, message, attr_name) + applied_punishment_type = str(obj.punishment.type) + + # Only delete messages if the punishment wasn't a BAN + if applied_punishment_type != "BAN": await self._delete_tracked_messages(messages, message.guild) await handle_bucket('mass_mentions', lambda m: len(m.mentions) + len(m.role_mentions)) diff --git a/lightning/cogs/automod/models.py b/lightning/cogs/automod/models.py index 6f53b6fa..45dd14a5 100644 --- a/lightning/cogs/automod/models.py +++ b/lightning/cogs/automod/models.py @@ -284,3 +284,99 @@ async def disable(self): if members: await self.bot.redis_pool.lpush(f"lightning:automod:gatekeeper:{self.guild_id}:remove", *members) self.members.clear() + + +class HeatThreshold: + """Represents a heat threshold with its associated punishment.""" + __slots__ = ("id", "threshold", "punishment", "duration") + + def __init__(self, record: asyncpg.Record) -> None: + self.id: int = record['id'] + self.threshold: int = record['heat_threshold'] + self.punishment: str = record['punishment_type'] + self.duration: Optional[int] = record.get('punishment_duration') + + +class HeatConfig: + """Configuration for the heat-based automod system.""" + __slots__ = ("guild_id", "enabled", "decay_seconds", "heat_per_violation", "thresholds", "bot") + + def __init__(self, bot: LightningBot, config: asyncpg.Record, thresholds: list[asyncpg.Record]) -> None: + self.bot = bot + self.guild_id: int = config['guild_id'] + self.enabled: bool = config['enabled'] + self.decay_seconds: int = config['decay_seconds'] + self.heat_per_violation: dict = config.get('heat_per_violation', {}) + self.thresholds: list[HeatThreshold] = sorted( + [HeatThreshold(t) for t in thresholds], + key=lambda x: x.threshold + ) + + async def get_user_heat(self, user_id: int) -> float: + """Gets the current heat level for a user. + + Returns + ------- + float + The current heat level (0.0 if no heat or expired) + """ + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + heat = await self.bot.redis_pool.get(key) + return float(heat) if heat else 0.0 + + async def add_heat(self, user_id: int, violation_type: str) -> float: + """Adds heat to a user for a violation. + + Parameters + ---------- + user_id : int + The user ID + violation_type : str + The type of violation (e.g., 'message-spam', 'invite-spam') + + Returns + ------- + float + The new heat level + """ + heat_to_add = self.heat_per_violation.get(violation_type, 1.0) + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + + # Add heat and set expiry + # Note: EXPIRE resets the TTL on each violation, meaning heat decays after + # a period of inactivity rather than gradually over time. This is intentional - + # active violators should maintain high heat, and it only decays when they stop. + pipe = self.bot.redis_pool.pipeline() + pipe.incrbyfloat(key, heat_to_add) + pipe.expire(key, self.decay_seconds) + result = await pipe.execute() + return float(result[0]) + + async def reset_heat(self, user_id: int) -> None: + """Resets a user's heat to 0. + + Parameters + ---------- + user_id : int + The user ID to reset + """ + key = f"lightning:automod:heat:{self.guild_id}:{user_id}" + await self.bot.redis_pool.delete(key) + + def get_punishment_for_heat(self, heat: float) -> Optional[HeatThreshold]: + """Gets the appropriate punishment for a heat level. + + Parameters + ---------- + heat : float + The current heat level + + Returns + ------- + Optional[HeatThreshold] + The punishment to apply, or None if no threshold is met + """ + for threshold in reversed(self.thresholds): + if heat >= threshold.threshold: + return threshold + return None diff --git a/migrations/0021_heat_based_automod.sql b/migrations/0021_heat_based_automod.sql new file mode 100644 index 00000000..b8bf9582 --- /dev/null +++ b/migrations/0021_heat_based_automod.sql @@ -0,0 +1,21 @@ +-- Heat-based automod system +-- Stores configuration for heat thresholds and punishments + +CREATE TABLE IF NOT EXISTS guild_automod_heat_config +( + guild_id BIGINT NOT NULL REFERENCES guilds (id) ON DELETE CASCADE PRIMARY KEY, + enabled BOOLEAN DEFAULT 'f', + decay_seconds INT DEFAULT 3600, -- How long it takes for heat to fully decay (1 hour default) + heat_per_violation JSONB DEFAULT '{}'::jsonb -- Maps violation types to heat values +); + +CREATE TABLE IF NOT EXISTS guild_automod_heat_thresholds +( + id BIGINT GENERATED BY DEFAULT AS IDENTITY UNIQUE PRIMARY KEY, + guild_id BIGINT NOT NULL REFERENCES guilds (id) ON DELETE CASCADE, + heat_threshold INT NOT NULL, -- Heat level that triggers this punishment + punishment_type automod_punishment_enum NOT NULL, + punishment_duration INT -- Duration in seconds for temporary punishments +); + +CREATE INDEX IF NOT EXISTS guild_automod_heat_thresholds_guild_id_idx ON guild_automod_heat_thresholds(guild_id); diff --git a/tests/test_heat_system.py b/tests/test_heat_system.py new file mode 100644 index 00000000..d66dfaa0 --- /dev/null +++ b/tests/test_heat_system.py @@ -0,0 +1,137 @@ +""" +Unit tests for the heat-based automod system +""" +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock + + +class MockRecord: + """Mock asyncpg.Record for testing""" + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + return self._data[key] + + def get(self, key, default=None): + return self._data.get(key, default) + + +class TestHeatConfig(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Set up test fixtures""" + self.mock_bot = Mock() + self.mock_bot.redis_pool = AsyncMock() + + # Create a sample heat config record + config_record = MockRecord({ + 'guild_id': 123456789, + 'enabled': True, + 'decay_seconds': 3600, + 'heat_per_violation': {'message-spam': 2.0, 'invite-spam': 5.0} + }) + + # Create sample threshold records + threshold_records = [ + MockRecord({ + 'id': 1, + 'heat_threshold': 10, + 'punishment_type': 'WARN', + 'punishment_duration': None + }), + MockRecord({ + 'id': 2, + 'heat_threshold': 20, + 'punishment_type': 'MUTE', + 'punishment_duration': 3600 + }), + MockRecord({ + 'id': 3, + 'heat_threshold': 30, + 'punishment_type': 'KICK', + 'punishment_duration': None + }) + ] + + # Import here to avoid dependency issues + import sys + import os + sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + from lightning.cogs.automod.models import HeatConfig + + self.heat_config = HeatConfig(self.mock_bot, config_record, threshold_records) + + def test_heat_config_initialization(self): + """Test that HeatConfig initializes correctly""" + self.assertEqual(self.heat_config.guild_id, 123456789) + self.assertTrue(self.heat_config.enabled) + self.assertEqual(self.heat_config.decay_seconds, 3600) + self.assertEqual(self.heat_config.heat_per_violation['message-spam'], 2.0) + self.assertEqual(len(self.heat_config.thresholds), 3) + + def test_thresholds_sorted(self): + """Test that thresholds are sorted by heat level""" + thresholds = [t.threshold for t in self.heat_config.thresholds] + self.assertEqual(thresholds, [10, 20, 30]) + + async def test_add_heat(self): + """Test adding heat to a user""" + # Mock Redis response + self.mock_bot.redis_pool.pipeline.return_value.execute = AsyncMock(return_value=[15.0, True]) + + new_heat = await self.heat_config.add_heat(987654321, 'message-spam') + self.assertEqual(new_heat, 15.0) + + async def test_get_user_heat(self): + """Test getting user heat""" + # Mock Redis response + self.mock_bot.redis_pool.get = AsyncMock(return_value='12.5') + + heat = await self.heat_config.get_user_heat(987654321) + self.assertEqual(heat, 12.5) + + async def test_get_user_heat_no_heat(self): + """Test getting user heat when user has no heat""" + # Mock Redis response for no heat + self.mock_bot.redis_pool.get = AsyncMock(return_value=None) + + heat = await self.heat_config.get_user_heat(987654321) + self.assertEqual(heat, 0.0) + + def test_get_punishment_for_heat(self): + """Test getting the correct punishment for a heat level""" + # Test heat below all thresholds + punishment = self.heat_config.get_punishment_for_heat(5.0) + self.assertIsNone(punishment) + + # Test heat at first threshold + punishment = self.heat_config.get_punishment_for_heat(10.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'WARN') + + # Test heat between thresholds + punishment = self.heat_config.get_punishment_for_heat(15.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'WARN') + + # Test heat at second threshold + punishment = self.heat_config.get_punishment_for_heat(20.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'MUTE') + self.assertEqual(punishment.duration, 3600) + + # Test heat above all thresholds + punishment = self.heat_config.get_punishment_for_heat(50.0) + self.assertIsNotNone(punishment) + self.assertEqual(punishment.punishment, 'KICK') + + async def test_reset_heat(self): + """Test resetting user heat""" + self.mock_bot.redis_pool.delete = AsyncMock(return_value=1) + + await self.heat_config.reset_heat(987654321) + self.mock_bot.redis_pool.delete.assert_called_once() + + +if __name__ == '__main__': + unittest.main()