Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions test/common/capture_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import dataclasses
import functools
from collections.abc import Mapping
from typing import Any, Dict, List

from common.db_utils import write_to_db
Expand Down Expand Up @@ -42,14 +45,65 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
return []


def _ensure_list(obj):
"""
Ensure the object is returned as a list.
"""
if isinstance(obj, list):
return obj
if isinstance(obj, (str, bytes, Mapping)):
return [obj]
if hasattr(obj, "__iter__") and not hasattr(obj, "__len__"): # 如 generator
return list(obj)
return [obj]


def _to_dict(obj: Any) -> Dict[str, Any]:
"""
Convert various object types to a dictionary for DB writing.
"""
if isinstance(obj, Mapping):
return dict(obj)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
if hasattr(obj, "_asdict"): # namedtuple
return obj._asdict()
if hasattr(obj, "__dict__"):
return vars(obj)
raise TypeError(f"Cannot convert {type(obj)} to dict for DB writing")


def proj_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
if "_proj" not in kwargs:
return []

name = kwargs.get("_name", table_name)
raw_input = kwargs["_proj"]
raw_results = _ensure_list(raw_input)

processed_results = []
for result in raw_results:
try:
dict_result = _to_dict(result)
write_to_db(name, dict_result)
processed_results.append(dict_result)
except Exception as e:
raise ValueError(f"Failed to process item in _proj: {e}") from e

return processed_results


# ---------------- decorator ----------------
def export_vars(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# If the function returns a dict containing '_data' or 'data', post-process it
# If the function returns a dict containing '_data' or '_proj', post-process it
if isinstance(result, dict):
if "_data" in result or "data" in result:
if "_data" in result:
return post_process(func.__name__, **result)
if "_proj" in result:
return proj_process(func.__name__, **result)
# Otherwise return unchanged
return result

Expand All @@ -63,33 +117,6 @@ def capture():
return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}}


@export_vars
def capture_list():
"""All lists via '_name' + '_data'"""
return {
"_name": "demo",
"_data": {
"accuracy": [0.1, 0.2, 0.3],
"loss": [0.1, 0.2, 0.3],
},
}


@export_vars
def capture_mix():
"""Mixed single + lists via '_name' + '_data'"""
return {
"_name": "demo",
"_data": {
"length": 10086, # single value
"accuracy": [0.1, 0.2, 0.3], # list
"loss": [0.1, 0.2, 0.3], # list
},
}


# quick test
if __name__ == "__main__":
print("capture(): ", capture())
print("capture_list(): ", capture_list())
print("capture_mix(): ", capture_mix())
8 changes: 8 additions & 0 deletions test/common/db.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE test_results (
id INT AUTO_INCREMENT PRIMARY KEY,
test_case VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL,
error TEXT,
test_build_id VARCHAR(100) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
6 changes: 3 additions & 3 deletions test/common/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _get_db() -> Optional[MySQLDatabase]:
backup_str = db_config.get("backup", "results/")
_backup_path = Path(backup_str).resolve()
_backup_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Backup directory set to: {_backup_path}")
# logger.info(f"Backup directory set to: {_backup_path}")

if not _db_enabled:
return None
Expand Down Expand Up @@ -94,7 +94,7 @@ def _backup_to_file(table_name: str, data: Dict[str, Any]) -> None:
with file_path.open("a", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False)
f.write("\n")
logger.info(f"Data backed up to {file_path}")
# logger.info(f"Data backed up to {file_path}")
except Exception as e:
logger.error(f"Failed to write backup file {file_path}: {e}")

Expand Down Expand Up @@ -140,7 +140,7 @@ def write_to_db(table_name: str, data: Dict[str, Any]) -> bool:

with db.atomic():
DynamicEntity.insert(filtered_data).execute()
logger.info(f"Successfully inserted data into table '{table_name}'.")
# logger.info(f"Successfully inserted data into table '{table_name}'.")
return True

except peewee.PeeweeException as e:
Expand Down
6 changes: 3 additions & 3 deletions test/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ reports:
use_timestamp: true
directory_prefix: "pytest"
html: # pytest-html
enabled: true
enabled: false
filename: "report.html"
title: "UCM Pytest Test Report"

database:
backup: "results/"
enabled: true
enabled: false
host: "127.0.0.1"
port: 3306
name: "ucm_pytest"
name: "ucm_test"
user: "root"
password: "123456"
charset: "utf8mb4"
5 changes: 2 additions & 3 deletions test/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ python_functions = test_*

addopts =
-ra
--strict-markers
--capture=no
filterwarnings =
ignore::pytest.PytestReturnNotNoneWarning
Expand All @@ -18,8 +17,8 @@ norecursedirs = .git venv env __pycache__ *.egg

markers =
# -------- Levels (Required) --------
stage(n): Unit/Smoke/Regression/Release (0=Unit 1=Smoke 2=Regression 3=Release)
stage: Unit/Smoke/Regression/Release (0=Unit 1=Smoke 2=Regression 3=Release)
# -------- Features (Recommended) --------
feature: Feature tag
platform(name): Platform tag(gpu/npu)
platform: Platform tag(gpu/npu)
# end of markers
79 changes: 67 additions & 12 deletions test/suites/E2E/test_demo_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,73 @@ def test_divide_by_zero(self, calc):


@pytest.mark.feature("capture") # pytest must be the top
@export_vars
def test_capture_mix():
"""Mixed single + lists via '_name' + '_data'"""
assert 1 == 1
return {
"_name": "demo",
"_data": {
"length": 10086, # single value
"accuracy": [0.1, 0.2, 0.3], # list
"loss": [0.1, 0.2, 0.3], # list
},
}
class TestCapture:
@export_vars
def test_capture_mix(self):
"""Mixed single + lists via '_name' + '_data'"""
assert 1 == 1
return {
"_name": "capture_demo",
"_data": {
"length": 1, # single value
"accuracy": [0.1, 0.2, 0.3], # list
"loss": [0.1, 0.2, 0.3], # list
},
}

@export_vars
def test_capture_dict(self):
"""Mixed single + lists via '_name' + '_proj'"""
return {
"_name": "capture_demo",
"_proj": {"length": 2, "accuracy": 0.1, "loss": 0.1},
}

@export_vars
def test_capture_list_dict(self):
"""Mixed single + lists via '_name' + '_proj'"""
return {
"_name": "capture_demo",
"_proj": [
{"length": 3, "accuracy": 0.1, "loss": 0.1},
{"length": 3, "accuracy": 0.2, "loss": 0.2},
{"length": 3, "accuracy": 0.3, "loss": 0.3},
],
}

@export_vars
def test_capture_proj(self):
"""Mixed single + lists via '_name' + '_proj'"""

class Result:
def __init__(self, length, accuracy, loss):
self.length = length
self.accuracy = accuracy
self.loss = loss

return {
"_name": "capture_demo",
"_proj": Result(4, 0.1, 0.1),
}

@export_vars
def test_capture_list_proj(self):
"""Mixed single + lists via '_name' + '_proj'"""

class Result:
def __init__(self, length, accuracy, loss):
self.length = length
self.accuracy = accuracy
self.loss = loss

return {
"_name": "capture_demo",
"_proj": [
Result(5, 0.1, 0.1),
Result(5, 0.2, 0.2),
Result(5, 0.3, 0.3),
],
}


# ---------------- Read Config Example ----------------
Expand Down