Skip to content

Commit 7150694

Browse files
authored
Add Session API methods - create, get, list, messages, feed, SSE listen, complete (#44)
- 8 new public methods: create_session, get_session, list_sessions, post_session_message, list_session_messages, get_session_feed, listen_session, complete_session - SSE streaming with auto-reconnect on timeout - 13 new tests, 87/87 total green
1 parent 659f006 commit 7150694

2 files changed

Lines changed: 434 additions & 0 deletions

File tree

axme_sdk/client.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,225 @@ def mcp_call_tool(
16191619
should_retry = retryable if retryable is not None else bool(idempotency_key)
16201620
return self._mcp_request(payload=payload, trace_id=trace_id, retryable=should_retry)
16211621

1622+
# ------------------------------------------------------------------
1623+
# Session API
1624+
# ------------------------------------------------------------------
1625+
1626+
def create_session(
1627+
self,
1628+
*,
1629+
type: str = "task",
1630+
project_id: str | None = None,
1631+
parent_session_id: str | None = None,
1632+
depends_on: list[str] | None = None,
1633+
metadata: dict[str, Any] | None = None,
1634+
trace_id: str | None = None,
1635+
) -> dict[str, Any]:
1636+
body: dict[str, Any] = {"type": type}
1637+
if project_id is not None:
1638+
body["project_id"] = project_id
1639+
if parent_session_id is not None:
1640+
body["parent_session_id"] = parent_session_id
1641+
if depends_on is not None:
1642+
body["depends_on"] = depends_on
1643+
if metadata is not None:
1644+
body["metadata"] = metadata
1645+
return self._request_json("POST", "/v1/sessions", json_body=body, trace_id=trace_id, retryable=False)
1646+
1647+
def get_session(
1648+
self,
1649+
session_id: str,
1650+
*,
1651+
trace_id: str | None = None,
1652+
) -> dict[str, Any]:
1653+
return self._request_json("GET", f"/v1/sessions/{session_id}", trace_id=trace_id, retryable=True)
1654+
1655+
def list_sessions(
1656+
self,
1657+
*,
1658+
status: str | None = None,
1659+
parent_session_id: str | None = None,
1660+
limit: int | None = None,
1661+
trace_id: str | None = None,
1662+
) -> dict[str, Any]:
1663+
params: dict[str, str] = {}
1664+
if status is not None:
1665+
params["status"] = status
1666+
if parent_session_id is not None:
1667+
params["parent_session_id"] = parent_session_id
1668+
if limit is not None:
1669+
params["limit"] = str(limit)
1670+
return self._request_json("GET", "/v1/sessions", params=params or None, trace_id=trace_id, retryable=True)
1671+
1672+
def post_session_message(
1673+
self,
1674+
session_id: str,
1675+
*,
1676+
role: str,
1677+
content: Any,
1678+
content_type: str = "text",
1679+
trace_id: str | None = None,
1680+
) -> dict[str, Any]:
1681+
body: dict[str, Any] = {"role": role, "content": content, "content_type": content_type}
1682+
return self._request_json("POST", f"/v1/sessions/{session_id}/messages", json_body=body, trace_id=trace_id, retryable=False)
1683+
1684+
def list_session_messages(
1685+
self,
1686+
session_id: str,
1687+
*,
1688+
since: int = 0,
1689+
limit: int | None = None,
1690+
trace_id: str | None = None,
1691+
) -> dict[str, Any]:
1692+
params: dict[str, str] = {}
1693+
if since > 0:
1694+
params["since"] = str(since)
1695+
if limit is not None:
1696+
params["limit"] = str(limit)
1697+
return self._request_json(
1698+
"GET", f"/v1/sessions/{session_id}/messages", params=params or None, trace_id=trace_id, retryable=True,
1699+
)
1700+
1701+
def get_session_feed(
1702+
self,
1703+
session_id: str,
1704+
*,
1705+
limit: int | None = None,
1706+
trace_id: str | None = None,
1707+
) -> dict[str, Any]:
1708+
params: dict[str, str] = {}
1709+
if limit is not None:
1710+
params["limit"] = str(limit)
1711+
return self._request_json(
1712+
"GET", f"/v1/sessions/{session_id}/feed", params=params or None, trace_id=trace_id, retryable=True,
1713+
)
1714+
1715+
def listen_session(
1716+
self,
1717+
session_id: str,
1718+
*,
1719+
since: int = 0,
1720+
wait_seconds: int = 30,
1721+
poll_interval_seconds: float = 1.0,
1722+
timeout_seconds: float | None = None,
1723+
trace_id: str | None = None,
1724+
) -> Iterator[dict[str, Any]]:
1725+
"""Stream session feed events via SSE. Yields dicts for each message/intent event.
1726+
1727+
Reconnects automatically on stream timeout. Stops on session.completed event
1728+
or when timeout_seconds is exceeded.
1729+
"""
1730+
if since < 0:
1731+
raise ValueError("since must be >= 0")
1732+
if wait_seconds < 1:
1733+
raise ValueError("wait_seconds must be >= 1")
1734+
if timeout_seconds is not None and timeout_seconds <= 0:
1735+
raise ValueError("timeout_seconds must be > 0 when provided")
1736+
1737+
deadline = (time.monotonic() + timeout_seconds) if timeout_seconds is not None else None
1738+
next_since = since
1739+
1740+
while True:
1741+
if deadline is not None and time.monotonic() >= deadline:
1742+
return
1743+
1744+
stream_wait = wait_seconds
1745+
if deadline is not None:
1746+
seconds_left = max(0.0, deadline - time.monotonic())
1747+
if seconds_left <= 0:
1748+
return
1749+
stream_wait = max(1, min(wait_seconds, int(seconds_left)))
1750+
1751+
try:
1752+
for event in self._iter_session_feed_stream(
1753+
session_id=session_id,
1754+
since=next_since,
1755+
wait_seconds=stream_wait,
1756+
trace_id=trace_id,
1757+
):
1758+
seq = event.get("seq")
1759+
if isinstance(seq, int) and seq > next_since:
1760+
next_since = seq
1761+
yield event
1762+
if event.get("type") == "session.completed":
1763+
return
1764+
except AxmeHttpError as exc:
1765+
if exc.status_code not in {404, 405, 501}:
1766+
raise
1767+
return
1768+
1769+
# Stream ended (timeout), reconnect
1770+
if deadline is not None and time.monotonic() >= deadline:
1771+
return
1772+
time.sleep(poll_interval_seconds)
1773+
1774+
def complete_session(
1775+
self,
1776+
session_id: str,
1777+
*,
1778+
result: dict[str, Any] | None = None,
1779+
trace_id: str | None = None,
1780+
) -> dict[str, Any]:
1781+
body: dict[str, Any] = {}
1782+
if result is not None:
1783+
body["result"] = result
1784+
return self._request_json("POST", f"/v1/sessions/{session_id}/complete", json_body=body, trace_id=trace_id, retryable=False)
1785+
1786+
def _iter_session_feed_stream(
1787+
self,
1788+
*,
1789+
session_id: str,
1790+
since: int,
1791+
wait_seconds: int,
1792+
trace_id: str | None,
1793+
) -> Iterator[dict[str, Any]]:
1794+
headers: dict[str, str] | None = None
1795+
normalized_trace_id = self._normalize_trace_id(trace_id)
1796+
if normalized_trace_id is not None:
1797+
headers = {"X-Trace-Id": normalized_trace_id}
1798+
1799+
stream_timeout = httpx.Timeout(
1800+
connect=10.0,
1801+
read=float(wait_seconds) + 15.0,
1802+
write=10.0,
1803+
pool=10.0,
1804+
)
1805+
with self._http.stream(
1806+
"GET",
1807+
f"/v1/sessions/{session_id}/feed/stream",
1808+
params={"since": str(since), "wait_seconds": str(wait_seconds)},
1809+
headers=headers,
1810+
timeout=stream_timeout,
1811+
) as response:
1812+
if response.status_code >= 400:
1813+
self._raise_http_error(response)
1814+
1815+
current_event: str | None = None
1816+
data_lines: list[str] = []
1817+
for line in response.iter_lines():
1818+
if line == "":
1819+
if current_event == "stream.timeout":
1820+
return
1821+
if current_event and data_lines:
1822+
try:
1823+
payload = json.loads("\n".join(data_lines))
1824+
except ValueError:
1825+
payload = None
1826+
if isinstance(payload, dict):
1827+
payload["type"] = current_event
1828+
yield payload
1829+
current_event = None
1830+
data_lines = []
1831+
continue
1832+
if line.startswith(":"):
1833+
continue
1834+
if line.startswith("event:"):
1835+
current_event = line.partition(":")[2].strip()
1836+
continue
1837+
if line.startswith("data:"):
1838+
data_lines.append(line.partition(":")[2].lstrip())
1839+
continue
1840+
16221841
def _request_json(
16231842
self,
16241843
method: str,

0 commit comments

Comments
 (0)