44import traceback
55import collections
66import asyncio as aio
7- from .utils import _get_future_result
7+ from .utils import result_noraise
88
99
1010class BaseAioPool (object ):
11+ ''' BaseAioPool implements features, supposed to work in all supported
12+ python versions. Other features supposed to be implemented as mixins.'''
1113
1214 def __init__ (self , size = 1024 , * , loop = None ):
1315 self .loop = loop or aio .get_event_loop ()
1416
1517 self .size = size
1618 self ._executed = 0
17- self ._joined = collections .deque ()
18- self ._waiting = collections .deque ()
19+ self ._joined = set ()
20+ self ._waiting = {} # future -> task
21+ self ._spawned = {} # future -> task
1922 self .semaphore = aio .Semaphore (value = self .size , loop = self .loop )
2023
2124 async def __aenter__ (self ):
@@ -41,7 +44,7 @@ async def join(self):
4144 return True
4245
4346 fut = self .loop .create_future ()
44- self ._joined .append (fut )
47+ self ._joined .add (fut )
4548 try :
4649 return await fut
4750 finally :
@@ -72,33 +75,46 @@ async def _wrap(self, coro, future, cb=None, ctx=None):
7275 return
7376
7477 self .semaphore .release ()
75- if not exc :
76- future .set_result (res )
77- else :
78- future .set_exception (exc )
7978
79+ if not future .done ():
80+ if exc :
81+ future .set_exception (exc )
82+ else :
83+ future .set_result (res )
84+
85+ del self ._spawned [future ]
8086 if self .is_empty :
8187 self ._release_joined ()
8288
8389 async def _spawn (self , future , coro , cb = None , ctx = None ):
90+ acq_error = False
8491 try :
8592 await self .semaphore .acquire ()
8693 except Exception as e :
87- future .set_exception (e )
88- self ._waiting .remove (future )
89- wrapped = self ._wrap (coro , future , cb = cb , ctx = ctx )
90- self .loop .create_task (wrapped )
94+ acq_error = True
95+ if not future .done ():
96+ future .set_exception (e )
97+ finally :
98+ del self ._waiting [future ]
99+
100+ if future .done ():
101+ if not acq_error and future .cancelled (): # outside action
102+ self .semaphore .release ()
103+ else : # all good, can spawn now
104+ wrapped = self ._wrap (coro , future , cb = cb , ctx = ctx )
105+ task = self .loop .create_task (wrapped )
106+ self ._spawned [future ] = task
91107 return future
92108
93109 async def spawn_n (self , coro , cb = None , ctx = None ):
94110 future = self .loop .create_future ()
95- self ._waiting . append ( future )
96- self .loop . create_task ( self . _spawn ( future , coro , cb = cb , ctx = ctx ))
111+ task = self .loop . create_task ( self . _spawn ( future , coro , cb = cb , ctx = ctx ) )
112+ self ._waiting [ future ] = task
97113 return future
98114
99115 async def spawn (self , coro , cb = None , ctx = None ):
100116 future = self .loop .create_future ()
101- self ._waiting . append ( future )
117+ self ._waiting [ future ] = self . loop . create_future () # TODO omg ???
102118 return await self ._spawn (future , coro , cb = cb , ctx = ctx )
103119
104120 async def exec (self , coro , cb = None , ctx = None ):
@@ -113,16 +129,36 @@ async def map_n(self, fn, iterable):
113129
114130 async def map (self , fn , iterable , exc_as_result = True ):
115131 futures = await self .map_n (fn , iterable )
116- await self .join ()
117-
118- results = []
119- for fut in futures :
120- res = _get_future_result (fut , exc_as_result )
121- results .append (res )
122- return results
132+ await aio .wait (futures )
133+ return [result_noraise (fut , exc_as_result ) for fut in futures ]
123134
124135 async def iterwait (self , * arg , ** kw ): # TODO there's a way to support 3.5?
125136 raise NotImplementedError ('python3.6+ required' )
126137
127138 async def itermap (self , * arg , ** kw ): # TODO there's a way to support 3.5?
128139 raise NotImplementedError ('python3.6+ required' )
140+
141+ def _cancel (self , * futures ):
142+ tasks , _futures = [], []
143+
144+ if not len (futures ): # meaning cancel all
145+ tasks .extend (self ._waiting .values ())
146+ tasks .extend (self ._spawned .values ())
147+ _futures .extend (self ._waiting .keys ())
148+ _futures .extend (self ._spawned .keys ())
149+ else :
150+ for fut in futures :
151+ task = self ._spawned .get (fut , self ._waiting .get (fut ))
152+ if task :
153+ tasks .append (task )
154+ _futures .append (fut )
155+
156+ cancelled = sum ([1 for fut in tasks if fut .cancel ()])
157+ return cancelled , _futures
158+
159+ async def cancel (self , * futures , exc_as_result = True ):
160+ cancelled , _futures = self ._cancel (* futures )
161+ await aio .sleep (0 ) # let them actually cancel
162+ # need to collect them anyway, to supress warnings
163+ results = [result_noraise (fut , exc_as_result ) for fut in _futures ]
164+ return cancelled , results
0 commit comments