diff --git a/src/mcpcat/modules/constants.py b/src/mcpcat/modules/constants.py index 928660b..dd42f4d 100644 --- a/src/mcpcat/modules/constants.py +++ b/src/mcpcat/modules/constants.py @@ -4,3 +4,9 @@ EVENT_ID_PREFIX = "evt" MCPCAT_API_URL = "https://api.mcpcat.io" # Default API URL for MCPCat events DEFAULT_CONTEXT_DESCRIPTION = "Explain why you are calling this tool and how it fits into the user's overall goal. This parameter is used for analytics and user intent tracking. YOU MUST provide 15-25 words (count carefully). NEVER use first person ('I', 'we', 'you') - maintain third-person perspective. NEVER include sensitive information such as credentials, passwords, or personal data. Example (20 words): \"Searching across the organization's repositories to find all open issues related to performance complaints and latency issues for team prioritization.\"" + +# Maximum number of exceptions to capture in a cause chain +MAX_EXCEPTION_CHAIN_DEPTH = 10 + +# Maximum number of stack frames to capture per exception +MAX_STACK_FRAMES = 50 diff --git a/src/mcpcat/modules/exceptions.py b/src/mcpcat/modules/exceptions.py new file mode 100644 index 0000000..aec91b2 --- /dev/null +++ b/src/mcpcat/modules/exceptions.py @@ -0,0 +1,424 @@ +"""Exception tracking module for MCPCat.""" + +import contextvars +import linecache +import os +import re +import sys +import traceback +import types +from typing import Any + +from mcpcat.types import ChainedErrorData, ErrorData, StackFrame +from mcpcat.modules.constants import MAX_EXCEPTION_CHAIN_DEPTH, MAX_STACK_FRAMES + +_captured_error: contextvars.ContextVar[BaseException | None] = contextvars.ContextVar( + "_captured_error", default=None +) + + +def capture_exception(exc: BaseException | Any) -> ErrorData: + """ + Captures detailed exception information including stack traces and cause chains. + + This function extracts error metadata (type, message, stack trace) and + recursively unwraps __cause__ and __context__ chains. It parses Python + tracebacks into structured frames and detects whether each frame is user + code (in_app: True) or library code (in_app: False). + + Args: + exc: The error to capture (can be BaseException, string, object, or any value) + + Returns: + ErrorData dict with structured error information including platform="python" + """ + if is_call_tool_result(exc): + return capture_call_tool_result_error(exc) + + if not isinstance(exc, BaseException): + return { + "message": stringify_non_exception(exc), + "type": None, + "platform": "python", + } + + error_data: ErrorData = { + "message": str(exc), + "type": type(exc).__name__, + "platform": "python", + } + + if exc.__traceback__: + error_data["frames"] = parse_python_traceback(exc.__traceback__) + error_data["stack"] = format_exception_string(exc) + + chained_errors = unwrap_exception_chain(exc) + if chained_errors: + error_data["chained_errors"] = chained_errors + + return error_data + + +def parse_python_traceback(tb: types.TracebackType | None) -> list[StackFrame]: + """ + Parses Python traceback into structured StackFrame list. + + Iterates through the traceback chain, extracting module name, function name, + file paths, line numbers, and source context for each frame. + + Args: + tb: Traceback object from exception.__traceback__ + + Returns: + List of StackFrame dicts (limited to MAX_STACK_FRAMES) + """ + if tb is None: + return [] + + frames: list[StackFrame] = [] + current_tb = tb + count = 0 + + while current_tb is not None and count < MAX_STACK_FRAMES: + frame = current_tb.tb_frame + abs_path = os.path.abspath(frame.f_code.co_filename) + + try: + module = frame.f_globals.get("__name__") + except (AttributeError, KeyError): + module = None + + in_app = is_in_app(abs_path) + + frame_dict: StackFrame = { + "filename": filename_for_module(module, abs_path), + "abs_path": abs_path, + "function": frame.f_code.co_name or "", + "module": module or "", + "lineno": current_tb.tb_lineno, + "in_app": in_app, + } + + if in_app: + context = extract_context_line(abs_path, current_tb.tb_lineno) + if context: + frame_dict["context_line"] = context + + frames.append(frame_dict) + + current_tb = current_tb.tb_next + count += 1 + + return frames + + +def filename_for_module(module: str | None, abs_path: str) -> str: + """ + Creates module-relative filename from absolute path. + + Tries to extract path relative to the base module's location in sys.modules. + Falls back to absolute path if extraction fails. + + Examples: + module="myapp.views.admin", abs_path="/home/user/project/myapp/views/admin.py" + → Returns "myapp/views/admin.py" + + Args: + module: Python module name (e.g., "myapp.views.admin") + abs_path: Absolute file path + + Returns: + Module-relative filename or absolute path as fallback + """ + if not abs_path or not module: + return abs_path + + try: + # Convert compiled .pyc files to source .py paths + if abs_path.endswith(".pyc"): + abs_path = abs_path[:-1] + + # Extract root package name (e.g., "myapp" from "myapp.views.admin") + base_module = module.split(".", 1)[0] + + # Single-module case (no dots): just return the filename + if base_module == module: + return os.path.basename(abs_path) + + if base_module not in sys.modules: + return abs_path + + base_module_file = getattr(sys.modules[base_module], "__file__", None) + if not base_module_file: + return abs_path + + # Navigate up 2 levels from package's __init__.py to find project root + # e.g., /project/myapp/__init__.py → rsplit by separator twice → /project + base_module_dir = base_module_file.rsplit(os.sep, 2)[0] + + # Extract the path relative to the project root + if abs_path.startswith(base_module_dir): + return abs_path.split(base_module_dir, 1)[-1].lstrip(os.sep) + + return abs_path + except Exception: + return abs_path + + +def is_in_app(abs_path: str) -> bool: + """ + Determines if a file path represents user code (True) or library code (False). + + Library code is identified by: + - Paths containing site-packages or dist-packages + - Python stdlib paths + - Paths containing /lib/pythonX.Y/ + + Args: + abs_path: Absolute file path to check + + Returns: + True if user code, False if library code + """ + if not abs_path: + return False + + if re.search(r"[\\/](?:dist|site)-packages[\\/]", abs_path): + return False + + stdlib_paths = [sys.prefix, sys.base_prefix] + if hasattr(sys, "real_prefix"): # virtualenv + stdlib_paths.append(sys.real_prefix) + + for stdlib in stdlib_paths: + if not stdlib: + continue + stdlib_lib = os.path.join(stdlib, "lib") + normalized_stdlib = stdlib_lib.replace("\\", "/") + normalized_path = abs_path.replace("\\", "/") + if normalized_path.startswith(normalized_stdlib): + return False + + # Catches cases like Homebrew Python on macOS + python_version = f"python{sys.version_info.major}.{sys.version_info.minor}" + stdlib_pattern = f"/lib/{python_version}/" + if stdlib_pattern in abs_path.replace("\\", "/"): + return False + + return True + + +def extract_context_line(abs_path: str, lineno: int) -> str | None: + """ + Extracts the line of code at the specified line number. + + Uses linecache for efficient file reading and caching. + + Args: + abs_path: Absolute path to source file + lineno: Line number (1-indexed) + + Returns: + Source line as string, or None if unavailable + """ + if not abs_path or not lineno: + return None + + try: + line = linecache.getline(abs_path, lineno) + if line: + return line.rstrip("\n") + except Exception: + pass + + return None + + +def unwrap_exception_chain(exc: BaseException) -> list[ChainedErrorData]: + """ + Recursively unwraps __cause__ and __context__ chains. + + Checks __suppress_context__ to determine which chain to follow: + - If True, follows __cause__ (explicit: raise ... from ...) + - If False, follows __context__ (implicit context) + + Uses id() tracking to prevent circular references. + + Args: + exc: Base exception to unwrap + + Returns: + List of ChainedErrorData dicts representing the error chain + """ + chain: list[ChainedErrorData] = [] + seen_ids: set[int] = set() + current: BaseException | None = exc + depth = 0 + + seen_ids.add(id(exc)) + + while current is not None and depth < MAX_EXCEPTION_CHAIN_DEPTH: + if getattr(current, "__suppress_context__", False): + next_exc = getattr(current, "__cause__", None) + else: + next_exc = getattr(current, "__context__", None) + + if next_exc is None: + break + + exc_id = id(next_exc) + if exc_id in seen_ids: + break + seen_ids.add(exc_id) + + if not isinstance(next_exc, BaseException): + chain.append( + { + "message": stringify_non_exception(next_exc), + "type": None, + } + ) + break + + chained_data: ChainedErrorData = { + "message": str(next_exc), + "type": type(next_exc).__name__, + } + + if next_exc.__traceback__: + chained_data["frames"] = parse_python_traceback(next_exc.__traceback__) + chained_data["stack"] = format_exception_string(next_exc) + + chain.append(chained_data) + current = next_exc + depth += 1 + + # TODO: Add ExceptionGroup support for Python 3.11+ + # ExceptionGroups have .exceptions attribute with multiple exceptions + + return chain + + +def format_exception_string(exc: BaseException) -> str: + """ + Formats exception into full traceback string. + + Similar to error.stack in JavaScript - captures the complete formatted traceback + including exception type, message, and stack frames. + + Args: + exc: Exception to format + + Returns: + Formatted traceback string + """ + try: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + except Exception: + return f"{type(exc).__name__}: {exc}" + + +def is_call_tool_result(value: Any) -> bool: + """ + Detects if a value is a CallToolResult object. + + MCP SDK converts errors to CallToolResult format: + { content: [{ type: "text", text: "error message" }], isError: true } + + Args: + value: Value to check + + Returns: + True if value is a CallToolResult object + """ + return ( + value is not None + and hasattr(value, "isError") + and hasattr(value, "content") + and isinstance(getattr(value, "content", None), list) + ) + + +def capture_call_tool_result_error(result: Any) -> ErrorData: + """ + Extracts error information from CallToolResult objects. + + MCP SDK converts exceptions to CallToolResult, losing original stack traces. + This extracts the error message from the content array. + + Args: + result: CallToolResult object with error + + Returns: + ErrorData with extracted message (no stack trace available) + """ + message = "Unknown error" + + try: + if hasattr(result, "content"): + text_parts = [] + for item in result.content: + if ( + hasattr(item, "type") + and item.type == "text" + and hasattr(item, "text") + ): + text_parts.append(item.text) + if text_parts: + message = " ".join(text_parts).strip() + except Exception: + pass + + return { + "message": message, + "type": None, + "platform": "python", + } + + +def stringify_non_exception(value: Any) -> str: + """ + Converts non-exception objects to string representation for error messages. + + In Python, anything can be raised (though it should be BaseException subclass). + This handles edge cases by converting them to meaningful strings. + + Args: + value: Non-exception value that was raised + + Returns: + String representation of the value + """ + if value is None: + return "None" + + if isinstance(value, str): + return value + + if isinstance(value, (int, float, bool)): + return str(value) + + try: + import json + + return json.dumps(value) + except Exception: + return str(value) + + +def store_captured_error(exc: BaseException) -> None: + """Stores exception in context variable before MCP SDK processing.""" + _captured_error.set(exc) + + +def get_captured_error() -> BaseException | None: + """Retrieves and clears stored exception from context variable.""" + exc = _captured_error.get() + if exc is not None: + _captured_error.set(None) + return exc + + +def clear_captured_error() -> None: + """Clears any stored exception from context variable.""" + _captured_error.set(None) diff --git a/src/mcpcat/modules/overrides/community/monkey_patch.py b/src/mcpcat/modules/overrides/community/monkey_patch.py index edc7263..2a37075 100644 --- a/src/mcpcat/modules/overrides/community/monkey_patch.py +++ b/src/mcpcat/modules/overrides/community/monkey_patch.py @@ -12,6 +12,7 @@ from mcpcat.modules import event_queue from mcpcat.modules.compatibility import is_mcp_error_response +from mcpcat.modules.exceptions import capture_exception from mcpcat.modules.identify import identify_session from mcpcat.modules.internal import get_server_tracking_data from mcpcat.modules.logging import write_to_log @@ -19,7 +20,6 @@ get_client_info_from_request_context, get_server_session_id, ) -from mcpcat.modules.tools import handle_report_missing from mcpcat.types import EventType, UnredactedEvent from ..mcp_server import override_lowlevel_mcp_server_minimal, safe_request_context @@ -144,7 +144,11 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: # Check for errors is_error, error_message = is_mcp_error_response(result) event.is_error = is_error - event.error = {"message": error_message} if is_error else None + # Use full exception capture if there's an error + if is_error: + event.error = capture_exception(result) + else: + event.error = None event.response = result.model_dump() if result else None return result @@ -152,7 +156,17 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult: except Exception as e: write_to_log(f"Error in wrapped_call_tool_handler: {e}") event.is_error = True - event.error = {"message": str(e)} + # Use full exception capture with stack trace + try: + event.error = capture_exception(e) + except Exception as capture_err: + # Fallback to simple error if capture fails + write_to_log(f"Error capturing exception: {capture_err}") + event.error = { + "message": str(e), + "type": type(e).__name__, + "platform": "python", + } raise finally: # Always publish event if tracing is enabled diff --git a/src/mcpcat/modules/overrides/official/monkey_patch.py b/src/mcpcat/modules/overrides/official/monkey_patch.py index 20a0b46..eabe4a6 100644 --- a/src/mcpcat/modules/overrides/official/monkey_patch.py +++ b/src/mcpcat/modules/overrides/official/monkey_patch.py @@ -7,10 +7,16 @@ import inspect from collections.abc import Callable from datetime import datetime, timezone -from typing import Any, List, Optional +from typing import Any, List from mcpcat.modules import event_queue from mcpcat.modules.compatibility import is_official_fastmcp_server, is_mcp_error_response +from mcpcat.modules.exceptions import ( + capture_exception, + clear_captured_error, + get_captured_error, + store_captured_error, +) from mcpcat.modules.internal import ( get_original_method, get_server_tracking_data, @@ -300,26 +306,48 @@ async def patched_call_tool( write_to_log(f"Error preparing arguments: {e}") args_for_tool = arguments # Use original if modification fails + # Clear any previous captured error before execution + clear_captured_error() + # Call original method - THIS IS CRITICAL, must not fail if not callable(original_call_tool): write_to_log("Critical: original_call_tool is not callable") raise ValueError("Original call_tool method is not callable") - result = await original_call_tool( - name, args_for_tool, context=context, **kwargs - ) + # Wrap execution to preserve exceptions before MCP SDK processes them + try: + result = await original_call_tool( + name, args_for_tool, context=context, **kwargs + ) + except Exception as tool_exc: + # Preserve original exception before MCP SDK converts it + store_captured_error(tool_exc) + raise # Re-raise so MCP SDK can handle it normally # Try to capture response in event (non-critical) if event: try: + # Check if result indicates an error (CallToolResult with isError=True) + is_error_result = False + if hasattr(result, "model_dump"): + is_error_result, error_message = is_mcp_error_response(result) + + if is_error_result: + # MCP SDK converted an exception to CallToolResult + event.is_error = True + + # Try to use preserved exception first (has full traceback) + captured = get_captured_error() + if captured: + event.error = capture_exception(captured) + else: + # Fallback: extract from CallToolResult + event.error = capture_exception(result) + + # Capture response data if isinstance(result, tuple): event.response = result[1] if len(result) > 1 else None elif hasattr(result, "model_dump"): - is_error, error_message = is_mcp_error_response(result) - event.is_error = is_error - event.error = ( - {"message": error_message} if is_error else None - ) event.response = result.model_dump() elif isinstance(result, dict): event.response = result @@ -347,9 +375,19 @@ async def patched_call_tool( if event: try: event.is_error = True - event.error = {"message": str(e)} - except: - pass + # Use full exception capture with stack trace + event.error = capture_exception(e) + except Exception as capture_err: + # Fallback to simple error if capture fails + write_to_log(f"Error capturing exception: {capture_err}") + try: + event.error = { + "message": str(e), + "type": type(e).__name__, + "platform": "python", + } + except: + pass # Re-raise to preserve original error behavior raise diff --git a/src/mcpcat/types.py b/src/mcpcat/types.py index 413e58a..dc36743 100644 --- a/src/mcpcat/types.py +++ b/src/mcpcat/types.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, Optional, Set, TypedDict, Literal, Union +from typing import Any, Dict, Optional, Set, TypedDict, Literal, Union, NotRequired from mcpcat_api import PublishEventRequest from pydantic import BaseModel @@ -44,6 +44,37 @@ class Event(PublishEventRequest): pass +# Error tracking types + +class StackFrame(TypedDict, total=False): + """Stack frame information for error tracking.""" + filename: str + abs_path: str + function: str # Function name or "" + module: str + lineno: int + in_app: bool + context_line: NotRequired[str] + + +class ChainedErrorData(TypedDict, total=False): + """Chained exception data (from __cause__ or __context__).""" + message: str + type: NotRequired[str | None] + stack: NotRequired[str] + frames: NotRequired[list[StackFrame]] + + +class ErrorData(TypedDict, total=False): + """Complete error information for an exception.""" + message: str + type: NotRequired[str | None] # Exception class name (e.g., "ValueError", "TypeError") + stack: NotRequired[str] + frames: NotRequired[list[StackFrame]] + chained_errors: NotRequired[list[ChainedErrorData]] + platform: str # Platform identifier (always "python") + + class EventType(str, Enum): """MCP event types.""" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..eb5c5df --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,672 @@ +"""Tests for exception tracking functionality.""" + +import os +import tempfile +import time +from unittest.mock import MagicMock + +import pytest + +from mcpcat import MCPCatOptions, track +from mcpcat.modules.event_queue import EventQueue, set_event_queue +from mcpcat.modules.exceptions import ( + capture_exception, + extract_context_line, + filename_for_module, + format_exception_string, + is_in_app, + parse_python_traceback, + stringify_non_exception, +) + +from .test_utils.client import create_test_client +from .test_utils.todo_server import create_todo_server + + +class TestBasicExceptionCapture: + """Tests for basic exception capture functionality.""" + + def test_capture_simple_exception(self): + """Test capturing a simple ValueError.""" + try: + raise ValueError("test error") + except ValueError as e: + error_data = capture_exception(e) + + assert error_data["message"] == "test error" + assert error_data["type"] == "ValueError" + assert error_data["platform"] == "python" + assert "frames" in error_data + assert len(error_data["frames"]) > 0 + assert "stack" in error_data + + def test_capture_exception_with_stack_trace(self): + """Test that stack trace is properly captured.""" + try: + raise RuntimeError("runtime error") + except RuntimeError as e: + error_data = capture_exception(e) + + assert "frames" in error_data + frames = error_data["frames"] + assert len(frames) > 0 + + # Check frame structure + first_frame = frames[0] + assert "filename" in first_frame + assert "abs_path" in first_frame + assert "function" in first_frame + assert "module" in first_frame + assert "lineno" in first_frame + assert "in_app" in first_frame + + def test_capture_exception_without_traceback(self): + """Test capturing exception without traceback.""" + exc = ValueError("no traceback") + # Create exception without raising it (no __traceback__) + error_data = capture_exception(exc) + + assert error_data["message"] == "no traceback" + assert error_data["type"] == "ValueError" + assert error_data["platform"] == "python" + # No traceback means no frames or stack + assert "frames" not in error_data or len(error_data.get("frames", [])) == 0 + + def test_module_extraction(self): + """Test that module names are properly extracted.""" + try: + raise TypeError("type error") + except TypeError as e: + error_data = capture_exception(e) + + frames = error_data["frames"] + assert len(frames) > 0 + + # At least one frame should have __name__ from current module + has_module = any(frame.get("module") for frame in frames) + assert has_module + + +class TestErrorChainUnwrapping: + """Tests for exception chain unwrapping.""" + + def test_explicit_chaining_with_from(self): + """Test explicit exception chaining (raise ... from ...).""" + try: + try: + raise ValueError("root cause") + except ValueError as e: + raise RuntimeError("wrapper error") from e + except RuntimeError as e: + error_data = capture_exception(e) + + assert error_data["message"] == "wrapper error" + assert error_data["type"] == "RuntimeError" + assert "chained_errors" in error_data + + chained = error_data["chained_errors"] + assert len(chained) == 1 + assert chained[0]["message"] == "root cause" + assert chained[0]["type"] == "ValueError" + + def test_implicit_chaining_context(self): + """Test implicit exception chaining (__context__).""" + try: + try: + raise ValueError("first error") + except ValueError: + # Implicit chaining - new exception during except block + raise TypeError("second error") + except TypeError as e: + error_data = capture_exception(e) + + assert error_data["type"] == "TypeError" + assert "chained_errors" in error_data + + chained = error_data["chained_errors"] + assert len(chained) == 1 + assert chained[0]["type"] == "ValueError" + + def test_circular_reference_prevention(self): + """Test that circular exception chains are handled.""" + # Create circular reference manually + exc1 = ValueError("error 1") + exc2 = RuntimeError("error 2") + + # Create circular chain (this shouldn't happen normally but we handle it) + exc1.__cause__ = exc2 + exc2.__cause__ = exc1 + exc1.__suppress_context__ = True + exc2.__suppress_context__ = True + + # Should not infinite loop + error_data = capture_exception(exc1) + + assert error_data["type"] == "ValueError" + # Should have stopped due to circular detection + assert "chained_errors" in error_data + chained = error_data["chained_errors"] + # Should have one (exc2) before detecting the circle back to exc1 + assert len(chained) == 1 + + def test_max_depth_limiting(self): + """Test that deep exception chains are limited.""" + # Create a chain deeper than MAX_EXCEPTION_CHAIN_DEPTH (10) + exc = ValueError("root") + current = exc + + for i in range(15): + new_exc = RuntimeError(f"error {i}") + new_exc.__cause__ = current + new_exc.__suppress_context__ = True + current = new_exc + + error_data = capture_exception(current) + + # Should have limited the chain to 10 + assert "chained_errors" in error_data + assert len(error_data["chained_errors"]) <= 10 + + +class TestInAppDetection: + """Tests for in_app detection.""" + + def test_user_code_is_in_app(self): + """Test that user code is marked as in_app=True.""" + # Current test file should be user code + current_file = os.path.abspath(__file__) + assert is_in_app(current_file) is True + + def test_site_packages_not_in_app(self): + """Test that site-packages code is marked as in_app=False.""" + # Create a fake site-packages path + fake_path = "/usr/local/lib/python3.10/site-packages/requests/api.py" + assert is_in_app(fake_path) is False + + fake_path2 = "/home/user/.local/lib/python3.9/dist-packages/numpy/core.py" + assert is_in_app(fake_path2) is False + + def test_stdlib_not_in_app(self): + """Test that Python stdlib is marked as in_app=False.""" + # Check actual stdlib module + import json + + json_file = json.__file__ + if json_file: + assert is_in_app(json_file) is False + + def test_empty_path_not_in_app(self): + """Test that empty path returns False.""" + assert is_in_app("") is False + assert is_in_app(None) is False # type: ignore + + +class TestPathNormalization: + """Tests for path normalization.""" + + def test_filename_for_module_with_package(self): + """Test filename_for_module with package module.""" + # Test with this test module + test_module = __name__ + test_file = __file__ + + result = filename_for_module(test_module, test_file) + + # Should be more relative than absolute + assert not result.startswith("/home/") and not result.startswith("/Users/") + # Should contain the filename + assert "test_exceptions.py" in result + + def test_filename_for_module_strips_pyc(self): + """Test that .pyc extension is stripped.""" + module = "mymodule" + path = "/path/to/mymodule.pyc" + + result = filename_for_module(module, path) + + # Should have stripped .pyc + assert ".pyc" not in result + + def test_filename_for_module_simple_module(self): + """Test simple module returns basename.""" + module = "simple" + path = "/path/to/simple.py" + + result = filename_for_module(module, path) + + # Simple module should return basename + assert result == "simple.py" + + def test_filename_for_module_fallback(self): + """Test that it falls back to abs_path on error.""" + module = "nonexistent.module.name" + path = "/some/path/file.py" + + result = filename_for_module(module, path) + + # Should fall back to original path + assert result == path + + +class TestContextExtraction: + """Tests for source context extraction.""" + + def test_extract_context_line(self): + """Test extracting context line from source file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write("line 1\n") + f.write("line 2\n") + f.write("line 3 with error\n") + f.write("line 4\n") + temp_path = f.name + + try: + assert extract_context_line(temp_path, 3) == "line 3 with error" + assert extract_context_line(temp_path, 1) == "line 1" + finally: + os.unlink(temp_path) + + def test_extract_context_line_handles_missing_file(self): + """Test that missing files are handled gracefully.""" + assert extract_context_line("/nonexistent/file.py", 1) is None + + +class TestNonExceptionHandling: + """Tests for non-exception object handling.""" + + def test_capture_string_error(self): + """Test capturing a string that was raised.""" + error_data = capture_exception("string error") + + assert error_data["message"] == "string error" + assert error_data["type"] is None # Unknown type for non-exceptions + assert error_data["platform"] == "python" + assert "frames" not in error_data + + def test_capture_none(self): + """Test capturing None.""" + error_data = capture_exception(None) + + assert error_data["message"] == "None" + assert error_data["type"] is None # Unknown type for non-exceptions + + def test_capture_dict(self): + """Test capturing a dict.""" + error_data = capture_exception({"code": 404, "message": "not found"}) + + assert "404" in error_data["message"] + assert "not found" in error_data["message"] + assert error_data["type"] is None # Unknown type for non-exceptions + + def test_stringify_non_exception(self): + """Test stringify_non_exception helper.""" + assert stringify_non_exception(None) == "None" + assert stringify_non_exception("test") == "test" + assert stringify_non_exception(42) == "42" + assert stringify_non_exception(True) == "True" + + +class TestStackFrameParsing: + """Tests for stack frame parsing.""" + + def test_parse_traceback_with_frames(self): + """Test parsing traceback with multiple frames.""" + + def inner_function(): + raise ValueError("inner error") + + def outer_function(): + inner_function() + + try: + outer_function() + except ValueError as e: + frames = parse_python_traceback(e.__traceback__) + + assert len(frames) > 0 + + # Check that we have frames from both functions + function_names = [f["function"] for f in frames] + assert "inner_function" in function_names + assert "outer_function" in function_names + + def test_parse_none_traceback(self): + """Test parsing None traceback.""" + frames = parse_python_traceback(None) + assert frames == [] + + def test_frame_limit(self): + """Test that frames are limited to MAX_STACK_FRAMES.""" + + def recursive_function(depth): + if depth <= 0: + raise ValueError("deep error") + recursive_function(depth - 1) + + try: + # Create a very deep stack (more than 50 frames) + recursive_function(60) + except ValueError as e: + frames = parse_python_traceback(e.__traceback__) + + # Should be limited to 50 frames + assert len(frames) <= 50 + + +class TestFormatExceptionString: + """Tests for exception string formatting.""" + + def test_format_exception_with_traceback(self): + """Test formatting exception with traceback.""" + try: + raise ValueError("format test") + except ValueError as e: + formatted = format_exception_string(e) + + assert "ValueError" in formatted + assert "format test" in formatted + assert "Traceback" in formatted + + def test_format_exception_without_traceback(self): + """Test formatting exception without traceback.""" + exc = RuntimeError("no tb") + formatted = format_exception_string(exc) + + # Should still format even without traceback + assert "RuntimeError" in formatted + assert "no tb" in formatted + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_very_deep_stack(self): + """Test handling very deep stacks.""" + + def deep_recursion(n): + if n <= 0: + raise ValueError("deep") + return deep_recursion(n - 1) + + try: + deep_recursion(100) + except ValueError as e: + error_data = capture_exception(e) + + # Should handle deep stack without crashing + assert error_data["type"] == "ValueError" + assert "frames" in error_data + assert len(error_data["frames"]) <= 50 + + def test_exception_with_special_characters(self): + """Test exception with special characters in message.""" + try: + raise ValueError("Error with émojis 🔥 and\nnewhlines\ttabs") + except ValueError as e: + error_data = capture_exception(e) + + assert "émojis" in error_data["message"] + assert "🔥" in error_data["message"] + + def test_capture_preserves_all_fields(self): + """Test that all important fields are captured.""" + try: + raise KeyError("missing key") + except KeyError as e: + error_data = capture_exception(e) + + # Check all expected fields + assert "message" in error_data + assert "type" in error_data + assert "platform" in error_data + assert error_data["platform"] == "python" + + if e.__traceback__: + assert "frames" in error_data + assert "stack" in error_data + + +class TestExceptionIntegration: + """Integration tests for exception capture with real MCP server calls.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Set up and tear down for each test.""" + from mcpcat.modules.event_queue import event_queue as original_queue + + yield + set_event_queue(original_queue) + + def _create_mock_event_capture(self): + """Helper to create mock API client and event capture list.""" + mock_api_client = MagicMock() + captured_events = [] + + def capture_event(publish_event_request): + captured_events.append(publish_event_request) + + mock_api_client.publish_event = MagicMock(side_effect=capture_event) + + test_queue = EventQueue(api_client=mock_api_client) + set_event_queue(test_queue) + + return captured_events + + @pytest.mark.asyncio + async def test_tool_raises_value_error(self): + """Test that ValueError from tools is properly captured.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "value"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0, "No tool_that_raises event captured" + + event = tool_events[0] + assert event.is_error is True + assert event.error is not None + # MCP SDK wraps tool exceptions - check the message contains the original error + assert "Test value error from tool" in event.error["message"] + + @pytest.mark.asyncio + async def test_tool_raises_runtime_error(self): + """Test that RuntimeError from tools is properly captured.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "runtime"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0 + + event = tool_events[0] + assert event.is_error is True + assert event.error is not None + # MCP SDK wraps tool exceptions - check the message contains the original error + assert "Test runtime error from tool" in event.error["message"] + + @pytest.mark.asyncio + async def test_tool_raises_custom_error(self): + """Test that custom exception types are properly captured.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "custom"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0 + + event = tool_events[0] + assert event.is_error is True + assert event.error is not None + # MCP SDK wraps tool exceptions - check the message contains the original error + assert "Test custom error from tool" in event.error["message"] + + @pytest.mark.asyncio + async def test_tool_raises_captures_stack_frames(self): + """Test that stack frames are properly captured with correct structure.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "value"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0 + + event = tool_events[0] + assert event.error is not None + + # Verify frames are captured + frames = event.error.get("frames", []) + assert len(frames) > 0, "No stack frames captured" + + # Verify frame structure + for frame in frames: + assert "filename" in frame + assert "abs_path" in frame + assert "function" in frame + assert "module" in frame + assert "lineno" in frame + assert "in_app" in frame + assert isinstance(frame["lineno"], int) + assert isinstance(frame["in_app"], bool) + + @pytest.mark.asyncio + async def test_tool_raises_has_in_app_frames(self): + """Test that stack frames include in_app detection.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "value"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0 + + event = tool_events[0] + frames = event.error.get("frames", []) + assert len(frames) > 0, "Should have stack frames" + + # All frames should have in_app field + for frame in frames: + assert "in_app" in frame, "Frame should have in_app field" + assert isinstance(frame["in_app"], bool) + + # Verify we have a mix of in_app and not in_app (sdk code is not in_app) + # Note: MCP SDK wraps the error, so the original tool function may not appear + # but we still verify the in_app detection logic works + + @pytest.mark.asyncio + async def test_tool_raises_captures_context_lines(self): + """Test that context lines are captured for in_app frames.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_that_raises", {"error_type": "value"}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_that_raises" + ] + assert len(tool_events) > 0 + + event = tool_events[0] + frames = event.error.get("frames", []) + in_app_frames = [f for f in frames if f.get("in_app") is True] + + # In-app frames should have context_line + frames_with_context = [f for f in in_app_frames if f.get("context_line")] + assert len(frames_with_context) > 0, "No context lines for in_app frames" + + # Context line should contain actual code + for frame in frames_with_context: + context = frame["context_line"] + assert len(context) > 0 + assert context.strip() != "" + + @pytest.mark.asyncio + async def test_mcp_protocol_error(self): + """Test that MCP protocol errors (McpError) are properly handled.""" + captured_events = self._create_mock_event_capture() + + server = create_todo_server() + options = MCPCatOptions(enable_tracing=True) + track(server, "test_project", options) + + async with create_test_client(server) as client: + await client.call_tool("tool_with_mcp_error", {}) + time.sleep(1.0) + + tool_events = [ + e + for e in captured_events + if e.event_type == "mcp:tools/call" + and e.resource_name == "tool_with_mcp_error" + ] + assert len(tool_events) > 0, "No tool_with_mcp_error event captured" + + event = tool_events[0] + assert event.is_error is True + assert event.error is not None + assert "Invalid parameters" in event.error["message"] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_utils/todo_server.py b/tests/test_utils/todo_server.py index 239567c..7383b59 100644 --- a/tests/test_utils/todo_server.py +++ b/tests/test_utils/todo_server.py @@ -1,6 +1,7 @@ """Todo server implementation for testing.""" -from mcp.server import Server +from mcp.shared.exceptions import McpError +from mcp.types import ErrorData try: from mcp.server import FastMCP @@ -11,6 +12,16 @@ HAS_FASTMCP = False +# Standard JSON-RPC error codes +INVALID_PARAMS = -32602 + + +class CustomTestError(Exception): + """Custom exception type for testing exception capture.""" + + pass + + class Todo: """Todo item.""" @@ -64,11 +75,30 @@ def complete_todo(id: int) -> str: raise ValueError(f"Todo with ID {id} not found") + @server.tool() + def tool_that_raises(error_type: str = "value") -> str: + """A tool that raises Python exceptions for testing.""" + if error_type == "value": + raise ValueError("Test value error from tool") + elif error_type == "runtime": + raise RuntimeError("Test runtime error from tool") + elif error_type == "custom": + raise CustomTestError("Test custom error from tool") + return "Should not reach here" + + @server.tool() + def tool_with_mcp_error() -> str: + """A tool that returns an MCP protocol error.""" + error = ErrorData(code=INVALID_PARAMS, message="Invalid parameters") + raise McpError(error) + # Store original handlers for testing server._original_handlers = { "add_todo": add_todo, "list_todos": list_todos, "complete_todo": complete_todo, + "tool_that_raises": tool_that_raises, + "tool_with_mcp_error": tool_with_mcp_error, } return server