diff --git a/.jules/bolt.md b/.jules/bolt.md index 6f687f0a..97f56a05 100644 --- a/.jules/bolt.md +++ b/.jules/bolt.md @@ -37,3 +37,7 @@ ## 2026-02-08 - Return Type Consistency in Utilities **Learning:** Inconsistent return types in shared utility functions (like `process_uploaded_image`) can cause runtime crashes across multiple modules, especially when some expect tuples and others expect single values. This can lead to deployment failures that are hard to debug without full integration logs. **Action:** Always maintain strict return type consistency for core utilities. Use type hints and verify all call sites when changing a function's signature. Ensure that performance-oriented optimizations (like returning multiple processed formats) are applied uniformly. + +## 2026-02-09 - O(N) Cache Eviction in Hot Paths +**Learning:** Custom cache implementations using `dict` often resort to O(N) scans (like `min()` over all keys) for eviction, which degrades performance as cache size grows. Additionally, ad-hoc global caches in async routers can introduce race conditions and memory leaks. +**Action:** Replace manual dictionary caches with `collections.OrderedDict` to achieve O(1) LRU eviction. Centralize caching logic in a thread-safe utility class rather than duplicating weak implementations across modules. diff --git a/backend/cache.py b/backend/cache.py index 8dc58bdb..d7eced95 100644 --- a/backend/cache.py +++ b/backend/cache.py @@ -2,7 +2,7 @@ import logging import threading from typing import Any, Optional -from datetime import datetime, timedelta +from collections import OrderedDict logger = logging.getLogger(__name__) @@ -10,15 +10,14 @@ class ThreadSafeCache: """ Thread-safe cache implementation with TTL and memory management. Fixes race conditions and implements proper cache expiration. + Uses OrderedDict for O(1) LRU eviction. """ def __init__(self, ttl: int = 300, max_size: int = 100): - self._data = {} - self._timestamps = {} + self._data = OrderedDict() # Key -> (value, timestamp) self._ttl = ttl # Time to live in seconds self._max_size = max_size # Maximum number of cache entries self._lock = threading.RLock() # Reentrant lock for thread safety - self._access_count = {} # Track access frequency for LRU eviction def get(self, key: str = "default") -> Optional[Any]: """ @@ -27,16 +26,22 @@ def get(self, key: str = "default") -> Optional[Any]: with self._lock: current_time = time.time() - # Check if key exists and is not expired - if key in self._data and key in self._timestamps: - if current_time - self._timestamps[key] < self._ttl: - # Update access count for LRU - self._access_count[key] = self._access_count.get(key, 0) + 1 - return self._data[key] + # Check if key exists + if key in self._data: + value, timestamp = self._data[key] + + # Check expiration + if current_time - timestamp < self._ttl: + # Move to end (MRU) + self._data.move_to_end(key) + # print(f"DEBUG: get({key}) hit. Order: {list(self._data.keys())}") + return value else: # Expired entry - remove it - self._remove_key(key) + del self._data[key] + # print(f"DEBUG: get({key}) expired. Order: {list(self._data.keys())}") + # print(f"DEBUG: get({key}) miss. Order: {list(self._data.keys())}") return None def set(self, data: Any, key: str = "default") -> None: @@ -46,27 +51,30 @@ def set(self, data: Any, key: str = "default") -> None: with self._lock: current_time = time.time() - # Clean up expired entries before adding new one - self._cleanup_expired() + # If key already exists, update and move to end + if key in self._data: + self._data.move_to_end(key) - # If cache is full, evict least recently used entry - if len(self._data) >= self._max_size and key not in self._data: - self._evict_lru() + # Set new data + self._data[key] = (data, current_time) - # Set new data atomically - self._data[key] = data - self._timestamps[key] = current_time - self._access_count[key] = 1 + # Evict if over capacity + if len(self._data) > self._max_size: + # Remove first item (LRU) + popped = self._data.popitem(last=False) + # print(f"DEBUG: Evicted {popped[0]}. Order: {list(self._data.keys())}") + + # print(f"DEBUG: set({key}). Order: {list(self._data.keys())}") - logger.debug(f"Cache set: key={key}, size={len(self._data)}") def invalidate(self, key: str = "default") -> None: """ Thread-safe invalidation of specific key. """ with self._lock: - self._remove_key(key) - logger.debug(f"Cache invalidated: key={key}") + if key in self._data: + del self._data[key] + logger.debug(f"Cache invalidated: key={key}") def clear(self) -> None: """ @@ -74,8 +82,6 @@ def clear(self) -> None: """ with self._lock: self._data.clear() - self._timestamps.clear() - self._access_count.clear() logger.debug("Cache cleared") def get_stats(self) -> dict: @@ -85,8 +91,8 @@ def get_stats(self) -> dict: with self._lock: current_time = time.time() expired_count = sum( - 1 for ts in self._timestamps.values() - if current_time - ts >= self._ttl + 1 for _, timestamp in self._data.values() + if current_time - timestamp >= self._ttl ) return { @@ -95,45 +101,6 @@ def get_stats(self) -> dict: "max_size": self._max_size, "ttl_seconds": self._ttl } - - def _remove_key(self, key: str) -> None: - """ - Internal method to remove a key from all tracking dictionaries. - Must be called within lock context. - """ - self._data.pop(key, None) - self._timestamps.pop(key, None) - self._access_count.pop(key, None) - - def _cleanup_expired(self) -> None: - """ - Internal method to clean up expired entries. - Must be called within lock context. - """ - current_time = time.time() - expired_keys = [ - key for key, timestamp in self._timestamps.items() - if current_time - timestamp >= self._ttl - ] - - for key in expired_keys: - self._remove_key(key) - - if expired_keys: - logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries") - - def _evict_lru(self) -> None: - """ - Internal method to evict least recently used entry. - Must be called within lock context. - """ - if not self._access_count: - return - - # Find key with lowest access count - lru_key = min(self._access_count.keys(), key=lambda k: self._access_count[k]) - self._remove_key(lru_key) - logger.debug(f"Evicted LRU cache entry: {lru_key}") class SimpleCache: """ diff --git a/backend/requirements-render.txt b/backend/requirements-render.txt index a5428240..654d3237 100644 --- a/backend/requirements-render.txt +++ b/backend/requirements-render.txt @@ -6,11 +6,11 @@ google-generativeai python-multipart psycopg2-binary huggingface-hub -httpx +httpx<0.28.0 pywebpush Pillow firebase-functions -firebase-admin +firebase-admin<7.0.0 a2wsgi python-jose[cryptography] passlib[bcrypt] diff --git a/backend/requirements.txt b/backend/requirements.txt index 054087c6..3150c27d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,7 +9,7 @@ ultralyticsplus==0.0.28 ultralytics opencv-python-headless huggingface-hub -httpx +httpx<0.28.0 python-magic pywebpush # Local ML dependencies (Issue #76) @@ -17,7 +17,7 @@ torch transformers Pillow firebase-functions -firebase-admin +firebase-admin<7.0.0 a2wsgi # Spatial deduplication dependencies scikit-learn diff --git a/backend/routers/detection.py b/backend/routers/detection.py index 2b2dccc8..f33cbe79 100644 --- a/backend/routers/detection.py +++ b/backend/routers/detection.py @@ -46,27 +46,18 @@ # Cached Functions -# Simple Cache Implementation to avoid async-lru dependency issues on Render -_cache_store = {} +# Robust Cache Implementation using ThreadSafeCache (OrderedDict + LRU) +from backend.cache import ThreadSafeCache + CACHE_TTL = 3600 # 1 hour MAX_CACHE_SIZE = 500 +_detection_cache = ThreadSafeCache(ttl=CACHE_TTL, max_size=MAX_CACHE_SIZE) async def _get_cached_result(key: str, func, *args, **kwargs): - current_time = time.time() - # Check cache - if key in _cache_store: - result, timestamp = _cache_store[key] - if current_time - timestamp < CACHE_TTL: - return result - else: - del _cache_store[key] - - # Prune cache if too large - if len(_cache_store) > MAX_CACHE_SIZE: - keys_to_remove = list(_cache_store.keys())[:int(MAX_CACHE_SIZE * 0.2)] - for k in keys_to_remove: - del _cache_store[k] + cached_result = _detection_cache.get(key) + if cached_result is not None: + return cached_result # Execute function if 'client' not in kwargs: @@ -74,7 +65,7 @@ async def _get_cached_result(key: str, func, *args, **kwargs): kwargs['client'] = backend.dependencies.SHARED_HTTP_CLIENT result = await func(*args, **kwargs) - _cache_store[key] = (result, current_time) + _detection_cache.set(result, key) return result async def _cached_detect_severity(image_bytes: bytes): diff --git a/backend/tests/test_cache_refactor.py b/backend/tests/test_cache_refactor.py new file mode 100644 index 00000000..77c5cd68 --- /dev/null +++ b/backend/tests/test_cache_refactor.py @@ -0,0 +1,73 @@ +import time +import threading +import pytest +from unittest.mock import patch +from backend.cache import ThreadSafeCache + +def test_cache_lru_eviction(): + cache = ThreadSafeCache(ttl=60, max_size=3) + + # set(data, key) + cache.set("A", "a") + cache.set("B", "b") + cache.set("C", "c") + + # Access 'a' to make it MRU + cache.get("a") + + # Add 'd', should evict 'b' (LRU) + # Queue before: [b, c, a] (b is LRU) + # Queue after: [c, a, d] + cache.set("D", "d") + + assert cache.get("a") == "A" + assert cache.get("b") is None + assert cache.get("c") == "C" + assert cache.get("d") == "D" + +def test_cache_ttl_expiration(): + with patch('backend.cache.time.time') as mock_time: + mock_time.return_value = 1000 + cache = ThreadSafeCache(ttl=10, max_size=10) + + # set(data, key) + cache.set("A", "a") + assert cache.get("a") == "A" + + # Advance time beyond TTL + mock_time.return_value = 1011 + + assert cache.get("a") is None + +def test_cache_update_refresh(): + cache = ThreadSafeCache(ttl=60, max_size=2) + + cache.set("A1", "a") + cache.set("B1", "b") + + # Update 'a', making it MRU + cache.set("A2", "a") + + # Add 'c', should evict 'b' (LRU) + # Queue before: [b, a] + # Queue after: [a, c] + cache.set("C1", "c") + + assert cache.get("a") == "A2" + assert cache.get("b") is None + assert cache.get("c") == "C1" + +def test_thread_safety_concurrent_writes(): + """Verify thread safety under concurrent load.""" + cache = ThreadSafeCache(ttl=60, max_size=50) + + def worker(): + for i in range(100): + # set(data, key) + cache.set(i, f"key-{i}") + + threads = [threading.Thread(target=worker) for _ in range(10)] + for t in threads: t.start() + for t in threads: t.join() + + assert cache.get_stats()["total_entries"] <= 50