|
6 | 6 | import asyncio |
7 | 7 | import json |
8 | 8 | import logging |
9 | | -from typing import Dict |
| 9 | +from typing import Dict, Optional |
10 | 10 |
|
11 | 11 | from common.config.app_config import config |
12 | 12 | from common.models.messages_kernel import TeamConfiguration |
@@ -86,10 +86,159 @@ def __init__(self): |
86 | 86 | 20 # Maximum number of replanning rounds 20 needed to accommodate complex tasks |
87 | 87 | ) |
88 | 88 |
|
| 89 | + # Event-driven notification system for approvals and clarifications |
| 90 | + self._approval_events: Dict[str, asyncio.Event] = {} |
| 91 | + self._clarification_events: Dict[str, asyncio.Event] = {} |
| 92 | + |
| 93 | + # Default timeout for waiting operations (5 minutes) |
| 94 | + self.default_timeout: float = 300.0 |
| 95 | + |
89 | 96 | def get_current_orchestration(self, user_id: str) -> MagenticOrchestration: |
90 | 97 | """get existing orchestration instance.""" |
91 | 98 | return self.orchestrations.get(user_id, None) |
92 | 99 |
|
| 100 | + def set_approval_pending(self, plan_id: str) -> None: |
| 101 | + """Set an approval as pending and create an event for it.""" |
| 102 | + self.approvals[plan_id] = None |
| 103 | + if plan_id not in self._approval_events: |
| 104 | + self._approval_events[plan_id] = asyncio.Event() |
| 105 | + else: |
| 106 | + # Clear existing event to reset state |
| 107 | + self._approval_events[plan_id].clear() |
| 108 | + |
| 109 | + def set_approval_result(self, plan_id: str, approved: bool) -> None: |
| 110 | + """Set the approval result and trigger the event.""" |
| 111 | + self.approvals[plan_id] = approved |
| 112 | + if plan_id in self._approval_events: |
| 113 | + self._approval_events[plan_id].set() |
| 114 | + |
| 115 | + async def wait_for_approval(self, plan_id: str, timeout: Optional[float] = None) -> bool: |
| 116 | + """ |
| 117 | + Wait for an approval decision with timeout. |
| 118 | +
|
| 119 | + Args: |
| 120 | + plan_id: The plan ID to wait for |
| 121 | + timeout: Timeout in seconds (defaults to default_timeout) |
| 122 | +
|
| 123 | + Returns: |
| 124 | + The approval decision (True/False) |
| 125 | +
|
| 126 | + Raises: |
| 127 | + asyncio.TimeoutError: If timeout is exceeded |
| 128 | + KeyError: If plan_id is not found in approvals |
| 129 | + """ |
| 130 | + if timeout is None: |
| 131 | + timeout = self.default_timeout |
| 132 | + |
| 133 | + if plan_id not in self.approvals: |
| 134 | + raise KeyError(f"Plan ID {plan_id} not found in approvals") |
| 135 | + |
| 136 | + if self.approvals[plan_id] is not None: |
| 137 | + # Already has a result |
| 138 | + return self.approvals[plan_id] |
| 139 | + |
| 140 | + if plan_id not in self._approval_events: |
| 141 | + self._approval_events[plan_id] = asyncio.Event() |
| 142 | + |
| 143 | + try: |
| 144 | + await asyncio.wait_for(self._approval_events[plan_id].wait(), timeout=timeout) |
| 145 | + return self.approvals[plan_id] |
| 146 | + except asyncio.TimeoutError: |
| 147 | + # Clean up on timeout |
| 148 | + self.cleanup_approval(plan_id) |
| 149 | + raise |
| 150 | + except asyncio.CancelledError: |
| 151 | + # Handle task cancellation gracefully |
| 152 | + logger.debug(f"Approval request {plan_id} was cancelled") |
| 153 | + raise |
| 154 | + except Exception as e: |
| 155 | + # Handle any other unexpected errors |
| 156 | + logger.error(f"Unexpected error waiting for approval {plan_id}: {e}") |
| 157 | + raise |
| 158 | + finally: |
| 159 | + # Ensure cleanup happens regardless of how the try block exits |
| 160 | + # Only cleanup if the approval is still pending (None) to avoid |
| 161 | + # cleaning up successful approvals |
| 162 | + if plan_id in self.approvals and self.approvals[plan_id] is None: |
| 163 | + self.cleanup_approval(plan_id) |
| 164 | + |
| 165 | + def set_clarification_pending(self, request_id: str) -> None: |
| 166 | + """Set a clarification as pending and create an event for it.""" |
| 167 | + self.clarifications[request_id] = None |
| 168 | + if request_id not in self._clarification_events: |
| 169 | + self._clarification_events[request_id] = asyncio.Event() |
| 170 | + else: |
| 171 | + # Clear existing event to reset state |
| 172 | + self._clarification_events[request_id].clear() |
| 173 | + |
| 174 | + def set_clarification_result(self, request_id: str, answer: str) -> None: |
| 175 | + """Set the clarification response and trigger the event.""" |
| 176 | + self.clarifications[request_id] = answer |
| 177 | + if request_id in self._clarification_events: |
| 178 | + self._clarification_events[request_id].set() |
| 179 | + |
| 180 | + async def wait_for_clarification(self, request_id: str, timeout: Optional[float] = None) -> str: |
| 181 | + """ |
| 182 | + Wait for a clarification response with timeout. |
| 183 | +
|
| 184 | + Args: |
| 185 | + request_id: The request ID to wait for |
| 186 | + timeout: Timeout in seconds (defaults to default_timeout) |
| 187 | +
|
| 188 | + Returns: |
| 189 | + The clarification response |
| 190 | +
|
| 191 | + Raises: |
| 192 | + asyncio.TimeoutError: If timeout is exceeded |
| 193 | + KeyError: If request_id is not found in clarifications |
| 194 | + """ |
| 195 | + if timeout is None: |
| 196 | + timeout = self.default_timeout |
| 197 | + |
| 198 | + if request_id not in self.clarifications: |
| 199 | + raise KeyError(f"Request ID {request_id} not found in clarifications") |
| 200 | + |
| 201 | + if self.clarifications[request_id] is not None: |
| 202 | + # Already has a result |
| 203 | + return self.clarifications[request_id] |
| 204 | + |
| 205 | + if request_id not in self._clarification_events: |
| 206 | + self._clarification_events[request_id] = asyncio.Event() |
| 207 | + |
| 208 | + try: |
| 209 | + await asyncio.wait_for(self._clarification_events[request_id].wait(), timeout=timeout) |
| 210 | + return self.clarifications[request_id] |
| 211 | + except asyncio.TimeoutError: |
| 212 | + # Clean up on timeout |
| 213 | + self.cleanup_clarification(request_id) |
| 214 | + raise |
| 215 | + except asyncio.CancelledError: |
| 216 | + # Handle task cancellation gracefully |
| 217 | + logger.debug(f"Clarification request {request_id} was cancelled") |
| 218 | + raise |
| 219 | + except Exception as e: |
| 220 | + # Handle any other unexpected errors |
| 221 | + logger.error(f"Unexpected error waiting for clarification {request_id}: {e}") |
| 222 | + raise |
| 223 | + finally: |
| 224 | + # Ensure cleanup happens regardless of how the try block exits |
| 225 | + # Only cleanup if the clarification is still pending (None) to avoid |
| 226 | + # cleaning up successful clarifications |
| 227 | + if request_id in self.clarifications and self.clarifications[request_id] is None: |
| 228 | + self.cleanup_clarification(request_id) |
| 229 | + |
| 230 | + def cleanup_approval(self, plan_id: str) -> None: |
| 231 | + """Clean up approval resources.""" |
| 232 | + self.approvals.pop(plan_id, None) |
| 233 | + if plan_id in self._approval_events: |
| 234 | + del self._approval_events[plan_id] |
| 235 | + |
| 236 | + def cleanup_clarification(self, request_id: str) -> None: |
| 237 | + """Clean up clarification resources.""" |
| 238 | + self.clarifications.pop(request_id, None) |
| 239 | + if request_id in self._clarification_events: |
| 240 | + del self._clarification_events[request_id] |
| 241 | + |
93 | 242 |
|
94 | 243 | class ConnectionConfig: |
95 | 244 | """Connection manager for WebSocket connections.""" |
|
0 commit comments