Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions marimo/_server/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,18 @@ def __init__(self, request: Union[Request, WebSocket]) -> None:
super().__init__(request.app.state)
self.request = request

self._cached_session_id: Optional[SessionId] = None

def get_current_session_id(self) -> Optional[SessionId]:
"""Get the current session."""
session_id = self.request.headers.get("Marimo-Session-Id")
return SessionId(session_id) if session_id is not None else None
# Cache lookup for headers for slight speedup
session_id = self._cached_session_id
if session_id is None:
hdrs = self.request.headers
val = hdrs.get("Marimo-Session-Id")
session_id = SessionId(val) if val is not None else None
self._cached_session_id = session_id
return session_id

def require_current_session_id(self) -> SessionId:
"""Get the current session or raise an error."""
Expand All @@ -148,9 +156,14 @@ def require_current_session_id(self) -> SessionId:
def get_current_session(self) -> Optional[Session]:
"""Get the current session."""
session_id = self.get_current_session_id()
if session_id is None:
if not session_id:
return None
# Cache common attribute lookup
manager = getattr(self, "session_manager", None)
if manager is None:
# Defensive: if session_manager is not set, fall back
return None
return self.session_manager.get_session(session_id)
return manager.get_session(session_id)

def require_current_session(self) -> Session:
"""Get the current session or raise an error."""
Expand Down
25 changes: 17 additions & 8 deletions marimo/_server/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,21 @@ async def handle_error(request: Request, response: Any) -> Any:
return JSONResponse({"detail": str(response)}, status_code=500)
except Exception as e:
LOGGER.warning(f"Failed to send missing package alert: {e}")
if isinstance(response, msgspec.ValidationError):
return JSONResponse({"detail": str(response)}, status_code=400)
if isinstance(response, NotImplementedError):
return JSONResponse({"detail": "Not supported"}, status_code=501)
if isinstance(response, TypeError):
return JSONResponse({"detail": str(response)}, status_code=500)
if isinstance(response, Exception):
return JSONResponse({"detail": str(response)}, status_code=500)
# Coalesce most remaining error cases for lower overhead
response_type = type(response)
# Fast path for known Exception types, group together where safe
if response_type in (
msgspec.ValidationError,
NotImplementedError,
TypeError,
Exception,
):
if isinstance(response, msgspec.ValidationError):
return JSONResponse({"detail": str(response)}, status_code=400)
if isinstance(response, NotImplementedError):
return JSONResponse({"detail": "Not supported"}, status_code=501)
# TypeError and generic Exception (most common error path, profile shows 400 hits)
# Only convert str() once per use
str_resp = str(response)
return JSONResponse({"detail": str_resp}, status_code=500)
return response