Skip to content

Commit 652e98e

Browse files
committed
add tests
1 parent ba5d27a commit 652e98e

File tree

2 files changed

+141
-70
lines changed

2 files changed

+141
-70
lines changed

codeflash/verification/pytest_plugin.py

Lines changed: 54 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
import importlib
3+
import contextlib
44
import inspect
55

66
# System Imports
77
import logging
88
import os
99
import re
10+
import sys
1011
import time
1112
import warnings
1213
from typing import TYPE_CHECKING, Any, Callable
@@ -136,9 +137,6 @@ def pytest_runtestloop(self, session: Session) -> bool:
136137
item._nodeid = self._set_nodeid(item._nodeid, count)
137138

138139
next_item: pytest.Item = session.items[index + 1] if index + 1 < len(session.items) else None
139-
140-
self._clear_lru_caches(next_item)
141-
142140
item.config.hook.pytest_runtest_protocol(item=item, nextitem=next_item)
143141
if session.shouldfail:
144142
raise session.Failed(session.shouldfail)
@@ -149,6 +147,58 @@ def pytest_runtestloop(self, session: Session) -> bool:
149147
time.sleep(self._get_delay_time(session))
150148
return True
151149

150+
def _clear_lru_caches(self, item: pytest.Item) -> None:
151+
processed_functions: set[Callable] = set()
152+
protected_modules = {
153+
"gc",
154+
"inspect",
155+
"os",
156+
"sys",
157+
"time",
158+
"functools",
159+
"pathlib",
160+
"typing",
161+
"dill",
162+
"pytest",
163+
"importlib",
164+
}
165+
166+
def _clear_cache_for_object(obj: Any) -> None: # noqa: ANN401
167+
if obj in processed_functions:
168+
return
169+
processed_functions.add(obj)
170+
171+
if hasattr(obj, "__wrapped__"):
172+
module_name = obj.__wrapped__.__module__
173+
else:
174+
try:
175+
obj_module = inspect.getmodule(obj)
176+
module_name = obj_module.__name__.split(".")[0] if obj_module is not None else None
177+
except Exception: # noqa: BLE001
178+
module_name = None
179+
180+
if module_name in protected_modules:
181+
return
182+
183+
if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
184+
with contextlib.suppress(Exception):
185+
obj.cache_clear()
186+
187+
_clear_cache_for_object(item.function) # type: ignore[attr-defined]
188+
189+
try:
190+
if hasattr(item.function, "__module__"): # type: ignore[attr-defined]
191+
module_name = item.function.__module__ # type: ignore[attr-defined]
192+
try:
193+
module = sys.modules.get(module_name)
194+
if module:
195+
for _, obj in inspect.getmembers(module):
196+
if callable(obj):
197+
_clear_cache_for_object(obj)
198+
except Exception: # noqa: BLE001, S110
199+
pass
200+
except Exception: # noqa: BLE001, S110
201+
pass
152202
def _set_nodeid(self, nodeid: str, count: int) -> str:
153203
"""Set loop count when using duration.
154204
@@ -194,72 +244,6 @@ def _timed_out(self, session: Session, start_time: float, count: int) -> bool:
194244
and time.time() - start_time > self._get_total_time(session)
195245
)
196246

197-
def _clear_lru_caches(self, item: pytest.Item) -> None:
198-
processed_functions: set[Callable] = set()
199-
cleared_caches = 0
200-
protected_modules = {
201-
"gc",
202-
"inspect",
203-
"os",
204-
"sys",
205-
"time",
206-
"functools",
207-
"pathlib",
208-
"typing",
209-
"dill",
210-
"pytest",
211-
"importlib",
212-
}
213-
214-
def _clear_cache_for_object(obj: Any) -> None:
215-
if obj in processed_functions:
216-
return
217-
processed_functions.add(obj)
218-
219-
try:
220-
obj_module = inspect.getmodule(obj)
221-
module_name = obj_module.__name__.split(".")[0] if obj_module is not None else None
222-
except AttributeError:
223-
module_name = None
224-
225-
if module_name in protected_modules:
226-
return
227-
228-
if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
229-
try:
230-
obj.cache_clear()
231-
nonlocal cleared_caches
232-
cleared_caches += 1
233-
print(f"Cleared cache for: ^{obj.__name__}^")
234-
except Exception: # noqa: BLE001, S110
235-
pass
236-
237-
if callable(obj):
238-
try:
239-
for name in obj.__globals__:
240-
nested_obj = obj.__globals__[name]
241-
if callable(nested_obj):
242-
_clear_cache_for_object(nested_obj)
243-
except Exception: # noqa: BLE001, S110
244-
pass
245-
246-
if hasattr(item, "function") and callable(item.function):
247-
module_name = item.function.__module__
248-
full_file_path = inspect.getfile(item.function)
249-
try:
250-
module = importlib.import_module(module_name)
251-
252-
for _, obj in inspect.getmembers(module):
253-
if callable(obj):
254-
_clear_cache_for_object(obj)
255-
256-
if isinstance(obj, type):
257-
for _, class_method in inspect.getmembers(obj, predicate=inspect.isfunction):
258-
_clear_cache_for_object(class_method)
259-
260-
except Exception: # noqa: BLE001, S110
261-
pass
262-
263247
@pytest.fixture
264248
def __pytest_loop_step_number(self, request: pytest.FixtureRequest) -> int:
265249
"""Set step number for loop.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import types
2+
from typing import NoReturn
3+
4+
import pytest
5+
from _pytest.config import Config
6+
7+
from codeflash.verification.pytest_plugin import PyTest_Loops
8+
9+
10+
@pytest.fixture
11+
def pytest_loops_instance(pytestconfig: Config) -> PyTest_Loops:
12+
return PyTest_Loops(pytestconfig)
13+
14+
15+
@pytest.fixture
16+
def mock_item() -> type:
17+
class MockItem:
18+
def __init__(self, function: types.FunctionType) -> None:
19+
self.function = function
20+
21+
return MockItem
22+
23+
24+
def create_mock_module(module_name: str, source_code: str) -> types.ModuleType:
25+
module = types.ModuleType(module_name)
26+
exec(source_code, module.__dict__) # noqa: S102
27+
return module
28+
29+
30+
def test_clear_lru_caches_function(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
31+
source_code = """
32+
import functools
33+
34+
@functools.lru_cache(maxsize=None)
35+
def my_func(x):
36+
return x * 2
37+
38+
my_func(10) # miss the cache
39+
my_func(10) # hit the cache
40+
"""
41+
mock_module = create_mock_module("test_module_func", source_code)
42+
item = mock_item(mock_module.my_func)
43+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
44+
assert mock_module.my_func.cache_info().hits == 0
45+
assert mock_module.my_func.cache_info().misses == 0
46+
assert mock_module.my_func.cache_info().currsize == 0
47+
48+
49+
def test_clear_lru_caches_class_method(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
50+
source_code = """
51+
import functools
52+
53+
class MyClass:
54+
@functools.lru_cache(maxsize=None)
55+
def my_method(self, x):
56+
return x * 3
57+
58+
obj = MyClass()
59+
obj.my_method(5) # Pre-populate the cache
60+
obj.my_method(5) # Hit the cache
61+
# """
62+
mock_module = create_mock_module("test_module_class", source_code)
63+
item = mock_item(mock_module.MyClass.my_method)
64+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
65+
assert mock_module.MyClass.my_method.cache_info().hits == 0
66+
assert mock_module.MyClass.my_method.cache_info().misses == 0
67+
assert mock_module.MyClass.my_method.cache_info().currsize == 0
68+
69+
70+
def test_clear_lru_caches_exception_handling(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
71+
"""Test that exceptions during clearing are handled."""
72+
73+
class BrokenCache:
74+
def cache_clear(self) -> NoReturn:
75+
msg = "Cache clearing failed!"
76+
raise ValueError(msg)
77+
78+
item = mock_item(BrokenCache())
79+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
80+
81+
82+
def test_clear_lru_caches_no_cache(pytest_loops_instance: PyTest_Loops, mock_item: type) -> None:
83+
def no_cache_func(x: int) -> int:
84+
return x
85+
86+
item = mock_item(no_cache_func)
87+
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001

0 commit comments

Comments
 (0)