Skip to content

Commit 24518ca

Browse files
committed
support of multiple joins
1 parent adbe311 commit 24518ca

File tree

7 files changed

+161
-25
lines changed

7 files changed

+161
-25
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
.env*
88
.pypyenv*
99
.pytest_*
10-
.git
10+
.mypy_cache
1111
__pycache__
1212
local_settings.py
1313

README.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
11
# asyncio-pool
22

3-
TODO
3+
TODO: cancelled, timeouts, callbacks, features, tests, readme
4+
5+
Example (more in `tests/` and `examples/` dirs): # TODO
6+
7+
```python
8+
import asyncio as aio
9+
from asyncio_pool import AioPool
10+
11+
12+
async def worker(n):
13+
await aio.sleep(1 / n)
14+
15+
16+
async def run_in_pool():
17+
18+
async with AioPool(size=10) as pool: # no more than 10 concurrent coroutines
19+
results = await pool.map(worker, range(1, 100))
20+
21+
### OR
22+
23+
pool = AioPool(size=10)
24+
25+
# generator returning futures for each worker result
26+
futures = await pool.itermap(worker, range(1,100))
27+
# or spawning manually: list of futures for each worker result
28+
futures = [await pool.spawn(worker(i)) for i in range(1,100)]
29+
30+
await pool.join()
31+
print [fut.result() for fut in batch] # will re-raise exceptions
32+
33+
34+
### OR moar later
35+
```

asyncio_pool/pool.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,67 @@
11
# coding: utf8
2+
'''Pool of asyncio coroutines with familiar interface'''
23

34
import traceback
5+
import collections
46
import asyncio as aio
57

68

79
class AioPool(object):
810

9-
def __init__(self, size=1024, loop=None):
10-
self._size = size
11+
def __init__(self, size=1024, *, loop=None):
12+
self.loop = loop or aio.get_event_loop()
13+
14+
self.size = size
1115
self._waiting = 0
1216
self._executed = 0
13-
self.semaphore = aio.Semaphore(value=self._size)
14-
self.loop = loop or aio.get_event_loop()
15-
self._all_done = self.loop.create_future()
17+
self._joined = collections.deque()
18+
self.semaphore = aio.Semaphore(value=self.size, loop=self.loop)
19+
20+
async def __aenter__(self):
21+
return self
22+
23+
async def __aexit__(self, ext_type, exc, tb):
24+
await self.join()
25+
26+
@property
27+
def n_active(self):
28+
return self.size - self.semaphore._value
1629

1730
@property
1831
def is_empty(self):
19-
return 0 == self._waiting == (self._size - self.semaphore._value)
32+
return 0 == self._waiting == self.n_active
2033

2134
@property
2235
def is_full(self):
23-
return self._waiting + (self._size - self.semaphore._value) >= self._size
36+
return self.size <= self._waiting + self.n_active
2437

2538
async def join(self):
26-
await self._all_done
39+
if self.is_empty:
40+
return True
2741

28-
async def __aenter__(self):
29-
return self
42+
fut = self.loop.create_future()
43+
self._joined.append(fut)
44+
try:
45+
return await fut
46+
finally:
47+
self._joined.remove(fut)
3048

31-
async def __aexit__(self, ext_type, exc, tb):
32-
await self.join()
49+
def _release_joined(self):
50+
if not self.is_empty:
51+
raise RuntimeError() # TODO
52+
53+
for fut in self._joined:
54+
if not fut.done():
55+
fut.set_result(True)
3356

3457
async def _acquire(self):
35-
if self._all_done.done():
36-
self._all_done = self.loop.create_future()
3758
self._waiting += 1
3859
await self.semaphore.acquire()
3960
self._waiting -= 1
4061

4162
async def _wrap(self, coro, future, cb=None, ctx=None):
4263
res, exc, tb = None, None, None
64+
4365
try:
4466
res = await coro
4567
future.set_result(res)
@@ -50,10 +72,14 @@ async def _wrap(self, coro, future, cb=None, ctx=None):
5072
self.semaphore.release()
5173
self._executed += 1
5274

53-
if cb:
54-
await self.spawn(cb(res, (exc, tb), ctx))
55-
elif self.is_empty:
56-
self._all_done.set_result('done')
75+
try:
76+
if cb:
77+
await self.spawn(cb(res, (exc, tb), ctx))
78+
except Exception as exc_cb:
79+
pass # TODO
80+
finally:
81+
if self.is_empty:
82+
self._release_joined()
5783

5884
async def spawn(self, coro, cb=None, ctx=None):
5985
await self._acquire()

reqs-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
asyncio
12
ipython
23
pytest
34
pytest-asyncio
5+
async-timeout

run_tests.sh

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
#! /bin/sh
22

3-
py35=python3.5
43
py36=python3.6
54
py37=python3.7
6-
pypy3=/opt/pypy3/pypy3/bin/pypy3
7-
5+
## 3.5 and pypy (which is also 3.5) are disalbed because of async generators
6+
# py35=python3.5
7+
# pypy3=/opt/pypy3/pypy3/bin/pypy3
88
default_env=$py36
99

10+
todo=${@:-"./tests"}
1011

1112
for py in $py35 $py36 $py37 $pypy3
1213
do
14+
echo ""
1315
if [ -x "$(command -v $py)" ]; then
1416
pyname="$(basename $py)"
1517
envname=".env_$pyname"
1618

1719
if ! [ -d $envname ]; then
1820
echo "$pyname: virtual env does not exist"
1921
else
20-
$envname/bin/pytest tests
22+
echo "$pyname: running for $todo"
23+
$envname/bin/pytest $todo
2124
fi
2225
else
2326
echo "$py: not found"

tests/old.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# coding: utf8
22

3-
from .asyncio_pool import AioPool
3+
import os
4+
import sys
5+
curr_dir = os.path.dirname(os.path.abspath(__file__))
6+
sys.path.insert(0, os.path.split(curr_dir)[0])
7+
8+
import asyncio as aio
9+
from asyncio_pool import AioPool
410

511

612
async def test(n):

tests/test_base.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# coding: utf8
2+
3+
import pytest
4+
import asyncio as aio
5+
from asyncio_pool import AioPool
6+
from async_timeout import timeout
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_timeout_cancel():
11+
async def wrk(sem):
12+
async with sem:
13+
await aio.sleep(1)
14+
15+
sem = aio.Semaphore(value=2)
16+
17+
async with timeout(0.2):
18+
with pytest.raises(aio.CancelledError):
19+
await aio.gather(*[wrk(sem) for _ in range(3)])
20+
21+
22+
@pytest.mark.asyncio
23+
async def test_outer_join():
24+
25+
todo, to_release = range(1,15), range(10)
26+
done, released = [], []
27+
28+
async def inner(n):
29+
nonlocal done
30+
await aio.sleep(1 / n)
31+
done.append(n)
32+
33+
async def outer(n, pool):
34+
nonlocal released
35+
await pool.join()
36+
released.append(n)
37+
38+
loop = aio.get_event_loop()
39+
pool = AioPool(size=100)
40+
41+
tasks = [await pool.spawn(inner(i)) for i in todo]
42+
joined = [loop.create_task(outer(j, pool)) for j in to_release]
43+
await pool.join()
44+
45+
assert len(released) < len(to_release)
46+
await aio.wait(joined)
47+
assert len(todo) == len(done) and len(released) == len(to_release)
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_internal_join():
52+
async def wrk(n, pool):
53+
aio.sleep(1 / n)
54+
if n == 3:
55+
await pool.join()
56+
else:
57+
await pool.spawn(wrk(n + 1, pool))
58+
return n
59+
60+
return True
61+
pool = AioPool(size=3)
62+
await pool.spawn(wrk(1, pool))
63+
64+
async with timeout(1.5) as tm:
65+
await pool.join()
66+
67+
assert tm.expired

0 commit comments

Comments
 (0)