Skip to content

Commit 655e960

Browse files
authored
Merge pull request #1465 from codeflash-ai/brainiac
clear lru cache between runs
2 parents 8b90050 + 99d6e11 commit 655e960

File tree

2 files changed

+150
-4
lines changed

2 files changed

+150
-4
lines changed

codeflash/verification/pytest_plugin.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3+
import contextlib
4+
import inspect
5+
36
# System Imports
47
import logging
58
import os
69
import re
10+
import sys
711
import time
812
import warnings
9-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, Any, Callable
1014
from unittest import TestCase
1115

1216
# PyTest Imports
@@ -133,6 +137,9 @@ def pytest_runtestloop(self, session: Session) -> bool:
133137
item._nodeid = self._set_nodeid(item._nodeid, count)
134138

135139
next_item: pytest.Item = session.items[index + 1] if index + 1 < len(session.items) else None
140+
141+
self._clear_lru_caches(item)
142+
136143
item.config.hook.pytest_runtest_protocol(item=item, nextitem=next_item)
137144
if session.shouldfail:
138145
raise session.Failed(session.shouldfail)
@@ -143,6 +150,59 @@ def pytest_runtestloop(self, session: Session) -> bool:
143150
time.sleep(self._get_delay_time(session))
144151
return True
145152

153+
def _clear_lru_caches(self, item: pytest.Item) -> None:
154+
processed_functions: set[Callable] = set()
155+
protected_modules = {
156+
"gc",
157+
"inspect",
158+
"os",
159+
"sys",
160+
"time",
161+
"functools",
162+
"pathlib",
163+
"typing",
164+
"dill",
165+
"pytest",
166+
"importlib",
167+
}
168+
169+
def _clear_cache_for_object(obj: Any) -> None: # noqa: ANN401
170+
if obj in processed_functions:
171+
return
172+
processed_functions.add(obj)
173+
174+
if hasattr(obj, "__wrapped__"):
175+
module_name = obj.__wrapped__.__module__
176+
else:
177+
try:
178+
obj_module = inspect.getmodule(obj)
179+
module_name = obj_module.__name__.split(".")[0] if obj_module is not None else None
180+
except Exception: # noqa: BLE001
181+
module_name = None
182+
183+
if module_name in protected_modules:
184+
return
185+
186+
if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
187+
with contextlib.suppress(Exception):
188+
obj.cache_clear()
189+
190+
_clear_cache_for_object(item.function) # type: ignore[attr-defined]
191+
192+
try:
193+
if hasattr(item.function, "__module__"): # type: ignore[attr-defined]
194+
module_name = item.function.__module__ # type: ignore[attr-defined]
195+
try:
196+
module = sys.modules.get(module_name)
197+
if module:
198+
for _, obj in inspect.getmembers(module):
199+
if callable(obj):
200+
_clear_cache_for_object(obj)
201+
except Exception: # noqa: BLE001, S110
202+
pass
203+
except Exception: # noqa: BLE001, S110
204+
pass
205+
146206
def _set_nodeid(self, nodeid: str, count: int) -> str:
147207
"""Set loop count when using duration.
148208
@@ -205,8 +265,7 @@ def __pytest_loop_step_number(self, request: pytest.FixtureRequest) -> int:
205265
warnings.warn("Repeating unittest class tests not supported")
206266
else:
207267
raise UnexpectedError(
208-
"This call couldn't work with pytest-loops. "
209-
"Please consider raising an issue with your usage."
268+
"This call couldn't work with pytest-loops. Please consider raising an issue with your usage."
210269
)
211270
return count
212271

@@ -226,7 +285,7 @@ def pytest_generate_tests(self, metafunc: Metafunc) -> None:
226285
metafunc.fixturenames.append("__pytest_loop_step_number")
227286

228287
def make_progress_id(i: int, n: int = count) -> str:
229-
return f"{n}/{i+1}"
288+
return f"{n}/{i + 1}"
230289

231290
scope = metafunc.config.option.codeflash_loops_scope
232291
metafunc.parametrize(

tests/test_lru_cache_clear.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import types
2+
from typing import NoReturn
3+
4+
import pytest
5+
from _pytest.config import Config
6+
7+
from codeflash.verification.pytest_plugin import PyTest_Loops
8+
9+
10+
@pytest.fixture
11+
def pytest_loops_instance(pytestconfig: Config) -> PyTest_Loops:
12+
return PyTest_Loops(pytestconfig)
13+
14+
15+
@pytest.fixture
16+
def mock_item() -> type:
17+
class MockItem:
18+
def __init__(self, function: types.FunctionType) -> None:
19+
self.function = function
20+
21+
return MockItem
22+
23+
24+
def create_mock_module(module_name: str, source_code: str) -> types.ModuleType:
25+
module = types.ModuleType(module_name)
26+
exec(source_code, module.__dict__) # noqa: S102
27+
return module
28+
29+
30+
def test_clear_lru_caches_function(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
31+
source_code = """
32+
import functools
33+
34+
@functools.lru_cache(maxsize=None)
35+
def my_func(x):
36+
return x * 2
37+
38+
my_func(10) # miss the cache
39+
my_func(10) # hit the cache
40+
"""
41+
mock_module = create_mock_module("test_module_func", source_code)
42+
item = mock_item(mock_module.my_func)
43+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
44+
assert mock_module.my_func.cache_info().hits == 0
45+
assert mock_module.my_func.cache_info().misses == 0
46+
assert mock_module.my_func.cache_info().currsize == 0
47+
48+
49+
def test_clear_lru_caches_class_method(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
50+
source_code = """
51+
import functools
52+
53+
class MyClass:
54+
@functools.lru_cache(maxsize=None)
55+
def my_method(self, x):
56+
return x * 3
57+
58+
obj = MyClass()
59+
obj.my_method(5) # Pre-populate the cache
60+
obj.my_method(5) # Hit the cache
61+
# """
62+
mock_module = create_mock_module("test_module_class", source_code)
63+
item = mock_item(mock_module.MyClass.my_method)
64+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
65+
assert mock_module.MyClass.my_method.cache_info().hits == 0
66+
assert mock_module.MyClass.my_method.cache_info().misses == 0
67+
assert mock_module.MyClass.my_method.cache_info().currsize == 0
68+
69+
70+
def test_clear_lru_caches_exception_handling(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
71+
"""Test that exceptions during clearing are handled."""
72+
73+
class BrokenCache:
74+
def cache_clear(self) -> NoReturn:
75+
msg = "Cache clearing failed!"
76+
raise ValueError(msg)
77+
78+
item = mock_item(BrokenCache())
79+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
80+
81+
82+
def test_clear_lru_caches_no_cache(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
83+
def no_cache_func(x: int) -> int:
84+
return x
85+
86+
item = mock_item(no_cache_func)
87+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001

0 commit comments

Comments
 (0)