Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ Homepage = "http://github.com/closeio/socketshark"
[dependency-groups]
lint = ["mypy", "ruff>=0.15"]
test = [
"aioresponses",
"pytest",
"pytest-asyncio",
"prometheus-async",
Expand All @@ -46,6 +45,7 @@ test = [
# requires `setuptools` and its subdependency of `distutils`, which is no
# longer present in Python 3.12+. We must thus bring it back explicitly.
"setuptools",
"aiointercept>=0.1.7",
]
dev = [{include-group = "lint"}, {include-group = "test"}]

Expand Down
89 changes: 41 additions & 48 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiohttp
import aioredis
import pytest
from aioresponses import aioresponses
from aiointercept import CallbackResult, aiointercept
from structlog.testing import capture_logs
from yarl import URL

Expand Down Expand Up @@ -117,20 +117,6 @@ def current_event_loop():
asyncio.set_event_loop(None)


class aioresponses_delayed(aioresponses): # noqa
"""
Just like aioresponses, but slightly delays POST requests.
"""

async def _request_mock(self, orig_self, method, url, *args, **kwargs):
result = await super()._request_mock(
orig_self, method, url, *args, **kwargs
)
if method == "POST":
await asyncio.sleep(0.2)
return result


class MockClient: # noqa SIM119
def __init__(self, shark):
self.log = []
Expand Down Expand Up @@ -158,7 +144,7 @@ class TestSession:
"""

async def _auth_session(self, session):
with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock auth endpoint
auth_url = "http://auth-service/auth/ticket/"

Expand Down Expand Up @@ -385,7 +371,7 @@ async def test_auth_ticket(self):
"error": "Must specify ticket.",
}

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Auth endpoint unreachable
await session.on_client_event(
{
Expand All @@ -400,7 +386,7 @@ async def test_auth_ticket(self):
"error": "Service unavailable.",
}

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock auth endpoint
auth_url = "http://auth-service/auth/ticket/"

Expand Down Expand Up @@ -856,7 +842,7 @@ async def test_subscription_authorizer(self):

await self._auth_session(session)

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Authorizer is unavailable.
await session.on_client_event(
{
Expand All @@ -871,7 +857,7 @@ async def test_subscription_authorizer(self):
"error": "Service unavailable.",
}

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock authorizer
authorizer_url = "http://auth-service/auth/authorizer/"

Expand Down Expand Up @@ -973,7 +959,7 @@ async def test_subscription_authorizer_data(self):

await self._auth_session(session)

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock authorizer
mock_responses.post(
conf["authorizer"],
Expand Down Expand Up @@ -1031,7 +1017,7 @@ async def test_subscription_periodic_authorizer(self):

await self._auth_session(session)

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock authorizer
authorizer_url = "http://auth-service/auth/authorizer/"

Expand Down Expand Up @@ -1129,7 +1115,7 @@ async def test_subscription_authorizer_data_periodic(self):

conf = TEST_CONFIG["SERVICES"]["periodic_authorizer_with_fields"]

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock authorizer
mock_responses.post(
conf["authorizer"],
Expand Down Expand Up @@ -1204,7 +1190,7 @@ async def test_subscription_periodic_heartbeat(self):
client = MockClient(shark)
session = client.session

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
# Mock responses
heartbeat_url = "http://my-service/heartbeat/"

Expand Down Expand Up @@ -1242,7 +1228,7 @@ async def test_subscription_complex(self):
conf = TEST_CONFIG["SERVICES"]["complex"]

# Test unsuccessful subscriptions
with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(conf["authorizer"], payload={"status": "ok"})
mock_responses.post(
conf["before_subscribe"], payload={"status": "error"}
Expand Down Expand Up @@ -1296,7 +1282,7 @@ async def test_subscription_complex(self):
}

# Test successful subscription with extra field and messages
with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(conf["authorizer"], payload={"status": "ok"})
mock_responses.post(
conf["before_subscribe"], payload={"status": "ok"}
Expand Down Expand Up @@ -1480,7 +1466,7 @@ async def test_subscription_complex(self):
}

# Test unsubscribe callbacks
with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(
conf["before_unsubscribe"], payload={"status": "error"}
)
Expand Down Expand Up @@ -1553,7 +1539,7 @@ async def test_subscription_complex(self):
}

# Test extra data in subscribe/unsubscribe callbacks
with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(conf["authorizer"], payload={"status": "ok"})
mock_responses.post(
conf["before_subscribe"],
Expand Down Expand Up @@ -1630,7 +1616,7 @@ async def test_unsubscribe_on_close(self):

conf = TEST_CONFIG["SERVICES"]["complex"]

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(conf["authorizer"], payload={"status": "ok"})
mock_responses.post(
conf["before_subscribe"], payload={"status": "ok"}
Expand Down Expand Up @@ -1672,7 +1658,7 @@ async def test_order_filter(self):

conf = TEST_CONFIG["SERVICES"]["simple_before_subscribe"]

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(
conf["before_subscribe"],
payload={
Expand Down Expand Up @@ -2063,7 +2049,7 @@ async def test_rate_limit(

conf = TEST_CONFIG["SERVICES"]["simple_before_subscribe"]

with aioresponses() as mock_responses:
async with aiointercept(mock_external_urls=True) as mock_responses:
mock_responses.post(
conf["before_subscribe"],
status=429,
Expand Down Expand Up @@ -2556,7 +2542,7 @@ async def task():
await asyncio.sleep(0.1)

aiosession = aiohttp.ClientSession()
mock_responses = aioresponses()
mock_responses = aiointercept(mock_external_urls=True)
conf = TEST_CONFIG["SERVICES"]["ws_test"]

async with aiosession.ws_connect(self.ws_url) as ws:
Expand All @@ -2577,16 +2563,14 @@ async def task():
"status": "ok",
}

# Start mocking here (if we started mocking earlier we wouldn't
# be able to use aiohttp to connect to the WebSocket).
mock_responses.start()
await mock_responses.start()
mock_responses.post(conf["on_unsubscribe"], payload={})

await aiosession.close()

# Wait until backend learns about the disconnected WebSocket.
await asyncio.sleep(0.1)
mock_responses.stop()
await mock_responses.stop()
requests = mock_responses.requests[
("POST", URL(conf["on_unsubscribe"]))
]
Expand All @@ -2609,7 +2593,7 @@ def test_shutdown(self):
"""
Make sure we call unsubscribe callbacks when shutting down.
"""
mock_responses = aioresponses()
mock_responses = aiointercept(mock_external_urls=True)
conf = TEST_CONFIG["SERVICES"]["ws_test"]

async def task():
Expand All @@ -2636,7 +2620,7 @@ async def task():
"status": "ok",
}

mock_responses.start()
await mock_responses.start()
mock_responses.post(conf["on_unsubscribe"], payload={})

asyncio.ensure_future(shark.shutdown())
Expand All @@ -2650,7 +2634,6 @@ async def task():
shark = SocketShark(TEST_CONFIG)
asyncio.ensure_future(task())
shark.start()
mock_responses.stop()

requests = mock_responses.requests[
("POST", URL(conf["on_unsubscribe"]))
Expand All @@ -2662,9 +2645,12 @@ def test_shutdown_connections(self):
"""
Make sure we don't allow new WebSocket connections when shutting down.
"""
# Pretend we have a subscription that takes a long time to close (so
# we can sneak in a connection attempt).
mock_responses = aioresponses_delayed()
mock_responses = aiointercept(mock_external_urls=True)

async def _delayed_ok(url, **kwargs):
await asyncio.sleep(0.2)
return CallbackResult(payload={})

conf = TEST_CONFIG["SERVICES"]["ws_test"]

async def task():
Expand All @@ -2691,25 +2677,32 @@ async def task():
"status": "ok",
}

mock_responses.start()
mock_responses.post(conf["on_unsubscribe"], payload={})
await mock_responses.start()
mock_responses.post(
conf["on_unsubscribe"],
# Pretend we have a subscription that takes a long time to
# close (so we can sneak in a connection attempt).
callback=_delayed_ok,
)

asyncio.ensure_future(shark.shutdown())

msg = await ws.receive()
assert msg.type == aiohttp.WSMsgType.CLOSE
await ws.close()

# Ensure we call the on_unsubscribe callback before the
# stopping the patcher.
# Ensure we call the `on_unsubscribe` callback before stopping
# the patcher.
await asyncio.sleep(0.1)

mock_responses.stop()
await mock_responses.stop()

# Attempt a new connection.
with pytest.raises(aiohttp.ClientConnectionError):
async with aiosession.ws_connect(self.ws_url) as ws:
raise AssertionError # Whoops!
# Whoops! This should never happen. The connection should
# fail before we get here.
raise AssertionError

await aiosession.close()

Expand Down
Loading