|
6 | 6 | import queue as stdlib_queue |
7 | 7 | import threading |
8 | 8 | from itertools import count |
9 | | -from typing import TYPE_CHECKING, Generic, TypeVar |
| 9 | +from typing import TYPE_CHECKING, Generic, TypeVar, Protocol, Final, NoReturn |
10 | 10 |
|
11 | 11 | import attrs |
12 | 12 | import outcome |
|
36 | 36 | Ts = TypeVarTuple("Ts") |
37 | 37 |
|
38 | 38 | RetT = TypeVar("RetT") |
| 39 | +T_co = TypeVar("T_co", covariant=True) |
39 | 40 |
|
40 | 41 |
|
41 | 42 | class _ParentTaskData(threading.local): |
@@ -253,6 +254,32 @@ def run_in_system_nursery(self, token: TrioToken) -> None: |
253 | 254 | token.run_sync_soon(self.run_sync) |
254 | 255 |
|
255 | 256 |
|
| 257 | +class _SupportsUnwrap(Protocol, Generic[T_co]): |
| 258 | + def unwrap(self) -> T_co: ... |
| 259 | + |
| 260 | + |
| 261 | +class _Value(_SupportsUnwrap[T_co]): |
| 262 | + def __init__(self, v: T_co) -> None: |
| 263 | + self._v: Final = v |
| 264 | + |
| 265 | + def unwrap(self) -> T_co: |
| 266 | + try: |
| 267 | + return self._v |
| 268 | + finally: |
| 269 | + del self._v |
| 270 | + |
| 271 | + |
| 272 | +class _Error(_SupportsUnwrap[NoReturn]): |
| 273 | + def __init__(self, e: BaseException) -> None: |
| 274 | + self._e: Final = e |
| 275 | + |
| 276 | + def unwrap(self) -> NoReturn: |
| 277 | + try: |
| 278 | + raise self._e |
| 279 | + finally: |
| 280 | + del self._e |
| 281 | + |
| 282 | + |
256 | 283 | @enable_ki_protection |
257 | 284 | async def to_thread_run_sync( |
258 | 285 | sync_fn: Callable[[Unpack[Ts]], RetT], |
@@ -372,11 +399,15 @@ def do_release_then_return_result() -> RetT: |
372 | 399 | try: |
373 | 400 | return result.unwrap() |
374 | 401 | finally: |
| 402 | + del result |
375 | 403 | limiter.release_on_behalf_of(placeholder) |
376 | 404 |
|
377 | 405 | result = outcome.capture(do_release_then_return_result) |
| 406 | + if isinstance(result, outcome.Error): |
| 407 | + result2: _SupportsUnwrap[RetT] = _Error(result.error) |
| 408 | + result2 = _Value(result.value) |
378 | 409 | if task_register[0] is not None: |
379 | | - trio.lowlevel.reschedule(task_register[0], outcome.Value(result)) |
| 410 | + trio.lowlevel.reschedule(task_register[0], outcome.Value(result2)) |
380 | 411 |
|
381 | 412 | current_trio_token = trio.lowlevel.current_trio_token() |
382 | 413 |
|
|
0 commit comments