14 changes: 12 additions & 2 deletions src/memos/llms/openai.py
@@ -28,7 +28,12 @@ def __init__(self, config: OpenAILLMConfig):
)
logger.info("OpenAI LLM instance initialized")

@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
@timed_with_status(
log_prefix="OpenAI LLM",
log_extra_args=lambda self, messages, **kwargs: {
"model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path)
},
)
def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
response = self.client.chat.completions.create(
@@ -55,7 +60,12 @@ def generate(self, messages: MessageList, **kwargs) -> str:
return reasoning_content + response_content
return response_content

@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
@timed_with_status(
log_prefix="OpenAI LLM",
log_extra_args=lambda self, messages, **kwargs: {
"model_name_or_path": self.config.model_name_or_path
},
)
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
"""Stream response from OpenAI LLM with optional reasoning support."""
if kwargs.get("tools"):
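For context, a minimal sketch of how these lambdas resolve the model name once `timed_with_status` invokes `log_extra_args(*args, **kwargs)` with the wrapped method's own arguments. This is not part of the diff; the class and values below are made up for illustration.

    # Hypothetical stand-ins for the real OpenAILLM and its config.
    class _Config:
        model_name_or_path = "gpt-4o-mini"

    class _FakeOpenAILLM:
        config = _Config()

    # Mirrors the lambda applied to generate() above.
    def resolve_model(self, messages, **kwargs):
        return {"model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path)}

    llm = _FakeOpenAILLM()
    msgs = [{"role": "user", "content": "hello"}]

    print(resolve_model(llm, msgs))                                # falls back to the configured model
    print(resolve_model(llm, msgs, model_name_or_path="gpt-4.1"))  # per-call kwarg override wins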
35 changes: 20 additions & 15 deletions src/memos/utils.py
@@ -22,10 +22,10 @@ def timed_with_status(
Parameters:
- log: enable timing logs (default True)
- log_prefix: prefix; falls back to function name
- log_args: names to include in logs (str or list/tuple of str).
- log_extra_args: extra arguments to include in logs (dict). If it contains
key "time_threshold", use its value (in seconds) as the logging threshold; otherwise
fall back to DEFAULT_TIME_BAR.
- log_args: names to include in logs (str or list/tuple of str); values are looked up in kwargs by name.
- log_extra_args: either
  - a dict: fixed contextual fields that are always attached to the log line; or
  - a callable with signature `fn(*args, **kwargs) -> dict`, invoked with the wrapped call's arguments to build contextual fields at runtime.
"""

if isinstance(log_args, str):
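As a usage sketch of the contract described in the docstring above (the decorated functions here are hypothetical; only `timed_with_status` and its parameters come from this file):

    from memos.utils import timed_with_status

    # Fixed fields: a plain dict is attached to every log line unchanged.
    @timed_with_status(log_prefix="Embedder", log_extra_args={"backend": "local"})
    def embed(texts, **kwargs):
        ...

    # Dynamic fields: the callable receives the wrapped call's *args/**kwargs
    # and returns a dict computed at call time.
    @timed_with_status(
        log_prefix="Embedder",
        log_args=["user_id"],  # looked up in kwargs by name
        log_extra_args=lambda texts, **kwargs: {"batch_size": len(texts)},
    )
    def embed_batch(texts, **kwargs):
        ...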
@@ -56,12 +56,24 @@ def wrapper(*args, **kwargs):
elapsed_ms = (time.perf_counter() - start) * 1000.0

ctx_parts = []
# 1) Collect parameters from kwargs by name
for key in effective_log_args:
val = kwargs.get(key)
ctx_parts.append(f"{key}={val}")

if log_extra_args:
ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())
# 2) Support log_extra_args as dict or callable, so we can dynamically
# extract values from self or other runtime context
extra_items = {}
try:
if callable(log_extra_args):
extra_items = log_extra_args(*args, **kwargs) or {}
elif isinstance(log_extra_args, dict):
extra_items = log_extra_args
except Exception as e:
logger.warning(f"[TIMER_WITH_STATUS] log_extra_args callback error: {e!r}")

if extra_items:
ctx_parts.extend(f"{key}={val}" for key, val in extra_items.items())

ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
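For illustration, the two collection steps above combine into a context string like the following (values are hypothetical):

    # Step 1 pulls named values from kwargs (log_args); step 2 adds the
    # dict/callable results (log_extra_args); both land in ctx_parts.
    ctx_parts = ["user_id=alice", "model_name_or_path=gpt-4o-mini"]
    ctx_str = f" [{', '.join(ctx_parts)}]"
    print(ctx_str)  # [user_id=alice, model_name_or_path=gpt-4o-mini]  (note the leading space)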

@@ -75,15 +87,8 @@ def wrapper(*args, **kwargs):
f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} "
f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}"
)
threshold_ms = DEFAULT_TIME_BAR * 1000.0
if log_extra_args and "time_threshold" in log_extra_args:
try:
threshold_ms = float(log_extra_args["time_threshold"]) * 1000.0
except Exception:
threshold_ms = DEFAULT_TIME_BAR * 1000.0

if elapsed_ms >= threshold_ms:
logger.info(msg)
logger.info(msg)

return wrapper

@@ -92,7 +97,7 @@ def wrapper(*args, **kwargs):
return decorator(func)


def timed(func=None, *, log=True, log_prefix=""):
def timed(func=None, *, log=False, log_prefix=""):
def decorator(fn):
def wrapper(*args, **kwargs):
start = time.perf_counter()
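Finally, with the `timed` default flipped from `log=True` to `log=False`, its timing logs become opt-in. A minimal sketch, assuming `timed` keeps the same optional-parentheses pattern as `timed_with_status` (the decorated functions are made up):

    from memos.utils import timed

    @timed                                   # silent by default after this change
    def parse(payload):
        ...

    @timed(log=True, log_prefix="Parser")    # opt back in explicitly
    def parse_verbose(payload):
        ...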