Skip to content

Commit cbcf33b

Browse files
CarltonXiangharvey_xiangfridayL
authored
Fix/timer log (#677)
* feat: timer false * feat: timer false * feat: add model log * feat: add model_name * feat: add model_name * feat: add model_name --------- Co-authored-by: harvey_xiang <harvey_xiang22@163.com> Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com>
1 parent cb64336 commit cbcf33b

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

src/memos/llms/openai.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def __init__(self, config: OpenAILLMConfig):
2828
)
2929
logger.info("OpenAI LLM instance initialized")
3030

31-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
31+
@timed_with_status(
32+
log_prefix="OpenAI LLM",
33+
log_extra_args=lambda self, messages, **kwargs: {
34+
"model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path)
35+
},
36+
)
3237
def generate(self, messages: MessageList, **kwargs) -> str:
3338
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
3439
response = self.client.chat.completions.create(
@@ -55,7 +60,12 @@ def generate(self, messages: MessageList, **kwargs) -> str:
5560
return reasoning_content + response_content
5661
return response_content
5762

58-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
63+
@timed_with_status(
64+
log_prefix="OpenAI LLM",
65+
log_extra_args=lambda self, messages, **kwargs: {
66+
"model_name_or_path": self.config.model_name_or_path
67+
},
68+
)
5969
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
6070
"""Stream response from OpenAI LLM with optional reasoning support."""
6171
if kwargs.get("tools"):

src/memos/utils.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def timed_with_status(
2222
Parameters:
2323
- log: enable timing logs (default True)
2424
- log_prefix: prefix; falls back to function name
25-
- log_args: names to include in logs (str or list/tuple of str).
26-
- log_extra_args: extra arguments to include in logs (dict). If it contains
27-
key "time_threshold", use its value (in seconds) as the logging threshold; otherwise
28-
fall back to DEFAULT_TIME_BAR.
25+
- log_args: names to include in logs (str or list/tuple of str), values are taken from kwargs by name.
26+
- log_extra_args:
27+
- can be a dict: fixed contextual fields that are always attached to logs;
28+
- or a callable: like `fn(*args, **kwargs) -> dict`, used to dynamically generate contextual fields at runtime.
2929
"""
3030

3131
if isinstance(log_args, str):
@@ -56,12 +56,24 @@ def wrapper(*args, **kwargs):
5656
elapsed_ms = (time.perf_counter() - start) * 1000.0
5757

5858
ctx_parts = []
59+
# 1) Collect parameters from kwargs by name
5960
for key in effective_log_args:
6061
val = kwargs.get(key)
6162
ctx_parts.append(f"{key}={val}")
6263

63-
if log_extra_args:
64-
ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
64+
# 2) Support log_extra_args as dict or callable, so we can dynamically
65+
# extract values from self or other runtime context
66+
extra_items = {}
67+
try:
68+
if callable(log_extra_args):
69+
extra_items = log_extra_args(*args, **kwargs) or {}
70+
elif isinstance(log_extra_args, dict):
71+
extra_items = log_extra_args
72+
except Exception as e:
73+
logger.warning(f"[TIMER_WITH_STATUS] log_extra_args callback error: {e!r}")
74+
75+
if extra_items:
76+
ctx_parts.extend(f"{key}={val}" for key, val in extra_items.items())
6577

6678
ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
6779

@@ -75,15 +87,8 @@ def wrapper(*args, **kwargs):
7587
f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} "
7688
f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}"
7789
)
78-
threshold_ms = DEFAULT_TIME_BAR * 1000.0
79-
if log_extra_args and "time_threshold" in log_extra_args:
80-
try:
81-
threshold_ms = float(log_extra_args["time_threshold"]) * 1000.0
82-
except Exception:
83-
threshold_ms = DEFAULT_TIME_BAR * 1000.0
8490

85-
if elapsed_ms >= threshold_ms:
86-
logger.info(msg)
91+
logger.info(msg)
8792

8893
return wrapper
8994

@@ -92,7 +97,7 @@ def wrapper(*args, **kwargs):
9297
return decorator(func)
9398

9499

95-
def timed(func=None, *, log=True, log_prefix=""):
100+
def timed(func=None, *, log=False, log_prefix=""):
96101
def decorator(fn):
97102
def wrapper(*args, **kwargs):
98103
start = time.perf_counter()

0 commit comments

Comments
 (0)