@@ -1619,6 +1619,225 @@ def mcp_call_tool(
16191619 should_retry = retryable if retryable is not None else bool (idempotency_key )
16201620 return self ._mcp_request (payload = payload , trace_id = trace_id , retryable = should_retry )
16211621
1622+ # ------------------------------------------------------------------
1623+ # Session API
1624+ # ------------------------------------------------------------------
1625+
1626+ def create_session (
1627+ self ,
1628+ * ,
1629+ type : str = "task" ,
1630+ project_id : str | None = None ,
1631+ parent_session_id : str | None = None ,
1632+ depends_on : list [str ] | None = None ,
1633+ metadata : dict [str , Any ] | None = None ,
1634+ trace_id : str | None = None ,
1635+ ) -> dict [str , Any ]:
1636+ body : dict [str , Any ] = {"type" : type }
1637+ if project_id is not None :
1638+ body ["project_id" ] = project_id
1639+ if parent_session_id is not None :
1640+ body ["parent_session_id" ] = parent_session_id
1641+ if depends_on is not None :
1642+ body ["depends_on" ] = depends_on
1643+ if metadata is not None :
1644+ body ["metadata" ] = metadata
1645+ return self ._request_json ("POST" , "/v1/sessions" , json_body = body , trace_id = trace_id , retryable = False )
1646+
1647+ def get_session (
1648+ self ,
1649+ session_id : str ,
1650+ * ,
1651+ trace_id : str | None = None ,
1652+ ) -> dict [str , Any ]:
1653+ return self ._request_json ("GET" , f"/v1/sessions/{ session_id } " , trace_id = trace_id , retryable = True )
1654+
1655+ def list_sessions (
1656+ self ,
1657+ * ,
1658+ status : str | None = None ,
1659+ parent_session_id : str | None = None ,
1660+ limit : int | None = None ,
1661+ trace_id : str | None = None ,
1662+ ) -> dict [str , Any ]:
1663+ params : dict [str , str ] = {}
1664+ if status is not None :
1665+ params ["status" ] = status
1666+ if parent_session_id is not None :
1667+ params ["parent_session_id" ] = parent_session_id
1668+ if limit is not None :
1669+ params ["limit" ] = str (limit )
1670+ return self ._request_json ("GET" , "/v1/sessions" , params = params or None , trace_id = trace_id , retryable = True )
1671+
1672+ def post_session_message (
1673+ self ,
1674+ session_id : str ,
1675+ * ,
1676+ role : str ,
1677+ content : Any ,
1678+ content_type : str = "text" ,
1679+ trace_id : str | None = None ,
1680+ ) -> dict [str , Any ]:
1681+ body : dict [str , Any ] = {"role" : role , "content" : content , "content_type" : content_type }
1682+ return self ._request_json ("POST" , f"/v1/sessions/{ session_id } /messages" , json_body = body , trace_id = trace_id , retryable = False )
1683+
1684+ def list_session_messages (
1685+ self ,
1686+ session_id : str ,
1687+ * ,
1688+ since : int = 0 ,
1689+ limit : int | None = None ,
1690+ trace_id : str | None = None ,
1691+ ) -> dict [str , Any ]:
1692+ params : dict [str , str ] = {}
1693+ if since > 0 :
1694+ params ["since" ] = str (since )
1695+ if limit is not None :
1696+ params ["limit" ] = str (limit )
1697+ return self ._request_json (
1698+ "GET" , f"/v1/sessions/{ session_id } /messages" , params = params or None , trace_id = trace_id , retryable = True ,
1699+ )
1700+
1701+ def get_session_feed (
1702+ self ,
1703+ session_id : str ,
1704+ * ,
1705+ limit : int | None = None ,
1706+ trace_id : str | None = None ,
1707+ ) -> dict [str , Any ]:
1708+ params : dict [str , str ] = {}
1709+ if limit is not None :
1710+ params ["limit" ] = str (limit )
1711+ return self ._request_json (
1712+ "GET" , f"/v1/sessions/{ session_id } /feed" , params = params or None , trace_id = trace_id , retryable = True ,
1713+ )
1714+
1715+ def listen_session (
1716+ self ,
1717+ session_id : str ,
1718+ * ,
1719+ since : int = 0 ,
1720+ wait_seconds : int = 30 ,
1721+ poll_interval_seconds : float = 1.0 ,
1722+ timeout_seconds : float | None = None ,
1723+ trace_id : str | None = None ,
1724+ ) -> Iterator [dict [str , Any ]]:
1725+ """Stream session feed events via SSE. Yields dicts for each message/intent event.
1726+
1727+ Reconnects automatically on stream timeout. Stops on session.completed event
1728+ or when timeout_seconds is exceeded.
1729+ """
1730+ if since < 0 :
1731+ raise ValueError ("since must be >= 0" )
1732+ if wait_seconds < 1 :
1733+ raise ValueError ("wait_seconds must be >= 1" )
1734+ if timeout_seconds is not None and timeout_seconds <= 0 :
1735+ raise ValueError ("timeout_seconds must be > 0 when provided" )
1736+
1737+ deadline = (time .monotonic () + timeout_seconds ) if timeout_seconds is not None else None
1738+ next_since = since
1739+
1740+ while True :
1741+ if deadline is not None and time .monotonic () >= deadline :
1742+ return
1743+
1744+ stream_wait = wait_seconds
1745+ if deadline is not None :
1746+ seconds_left = max (0.0 , deadline - time .monotonic ())
1747+ if seconds_left <= 0 :
1748+ return
1749+ stream_wait = max (1 , min (wait_seconds , int (seconds_left )))
1750+
1751+ try :
1752+ for event in self ._iter_session_feed_stream (
1753+ session_id = session_id ,
1754+ since = next_since ,
1755+ wait_seconds = stream_wait ,
1756+ trace_id = trace_id ,
1757+ ):
1758+ seq = event .get ("seq" )
1759+ if isinstance (seq , int ) and seq > next_since :
1760+ next_since = seq
1761+ yield event
1762+ if event .get ("type" ) == "session.completed" :
1763+ return
1764+ except AxmeHttpError as exc :
1765+ if exc .status_code not in {404 , 405 , 501 }:
1766+ raise
1767+ return
1768+
1769+ # Stream ended (timeout), reconnect
1770+ if deadline is not None and time .monotonic () >= deadline :
1771+ return
1772+ time .sleep (poll_interval_seconds )
1773+
1774+ def complete_session (
1775+ self ,
1776+ session_id : str ,
1777+ * ,
1778+ result : dict [str , Any ] | None = None ,
1779+ trace_id : str | None = None ,
1780+ ) -> dict [str , Any ]:
1781+ body : dict [str , Any ] = {}
1782+ if result is not None :
1783+ body ["result" ] = result
1784+ return self ._request_json ("POST" , f"/v1/sessions/{ session_id } /complete" , json_body = body , trace_id = trace_id , retryable = False )
1785+
1786+ def _iter_session_feed_stream (
1787+ self ,
1788+ * ,
1789+ session_id : str ,
1790+ since : int ,
1791+ wait_seconds : int ,
1792+ trace_id : str | None ,
1793+ ) -> Iterator [dict [str , Any ]]:
1794+ headers : dict [str , str ] | None = None
1795+ normalized_trace_id = self ._normalize_trace_id (trace_id )
1796+ if normalized_trace_id is not None :
1797+ headers = {"X-Trace-Id" : normalized_trace_id }
1798+
1799+ stream_timeout = httpx .Timeout (
1800+ connect = 10.0 ,
1801+ read = float (wait_seconds ) + 15.0 ,
1802+ write = 10.0 ,
1803+ pool = 10.0 ,
1804+ )
1805+ with self ._http .stream (
1806+ "GET" ,
1807+ f"/v1/sessions/{ session_id } /feed/stream" ,
1808+ params = {"since" : str (since ), "wait_seconds" : str (wait_seconds )},
1809+ headers = headers ,
1810+ timeout = stream_timeout ,
1811+ ) as response :
1812+ if response .status_code >= 400 :
1813+ self ._raise_http_error (response )
1814+
1815+ current_event : str | None = None
1816+ data_lines : list [str ] = []
1817+ for line in response .iter_lines ():
1818+ if line == "" :
1819+ if current_event == "stream.timeout" :
1820+ return
1821+ if current_event and data_lines :
1822+ try :
1823+ payload = json .loads ("\n " .join (data_lines ))
1824+ except ValueError :
1825+ payload = None
1826+ if isinstance (payload , dict ):
1827+ payload ["type" ] = current_event
1828+ yield payload
1829+ current_event = None
1830+ data_lines = []
1831+ continue
1832+ if line .startswith (":" ):
1833+ continue
1834+ if line .startswith ("event:" ):
1835+ current_event = line .partition (":" )[2 ].strip ()
1836+ continue
1837+ if line .startswith ("data:" ):
1838+ data_lines .append (line .partition (":" )[2 ].lstrip ())
1839+ continue
1840+
16221841 def _request_json (
16231842 self ,
16241843 method : str ,
0 commit comments