diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index f287cb0b8fde..f4f41bd9471a 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +## 5.15.2 (Unreleased) + +### Bugs Fixed + +- Fixed a bug where `WebSocketTransportAsync.close()` could leak an `aiohttp.ClientSession` when the underlying websocket close raised. Also fixed session leaks on reconnect and on non-`ClientConnectorError` failures in `connect()`. + ## 5.15.1 (2025-11-11) ### Bugs Fixed diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 0478f41324bf..af196e326336 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -448,6 +448,10 @@ async def connect(self): http_proxy_auth = BasicAuth(login=username, password=password) + # Close any session left over from a previous connect attempt so that + # the aiohttp ClientSession is not leaked on reconnect. + await self._close_session_safely() + self.session = ClientSession() if self._custom_endpoint: url = f"wss://{self._custom_endpoint}" if self._use_tls else f"ws://{self._custom_endpoint}" @@ -477,6 +481,7 @@ async def connect(self): ) except ClientConnectorError as exc: _LOGGER.info("Websocket connect failed: %r", exc, extra=self.network_trace_params) + await self._close_session_safely() if self._custom_endpoint: raise AuthenticationException( ErrorCondition.ClientError, @@ -484,8 +489,26 @@ async def connect(self): error=exc, ) from exc raise ConnectionError("Failed to establish websocket connection: " + str(exc)) from exc + except Exception: # pylint: disable=broad-except + # Any other failure during ws_connect must also clean up the new + # session so the aiohttp ClientSession is not leaked. + await self._close_session_safely() + raise self.connected = True + async def _close_session_safely(self): + """Close ``self.session`` if set, suppressing and logging any errors.""" + if self.session is not None: + try: + await self.session.close() + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug( + "Error closing aiohttp session: %r", + e, + extra=self.network_trace_params, + ) + self.session = None + async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused-argument """Read exactly n bytes from the peer. @@ -524,9 +547,16 @@ async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused- async def close(self): """Do any preliminary work in shutting down the connection.""" async with self.socket_lock: - await self.sock.close() - await self.session.close() - self.connected = False + try: + if self.sock is not None: + await self.sock.close() + except Exception as e: # pylint: disable=broad-except + _LOGGER.debug( + "Error closing websocket: %r", e, extra=self.network_trace_params + ) + self.sock = None + await self._close_session_safely() + self.connected = False async def _write(self, s): """Completely write a string (byte array) to the peer. diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 5bb5c30266da..5d660bbf8ea3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.15.1" +VERSION = "5.15.2" diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_websocket_transport_async_unit.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_websocket_transport_async_unit.py new file mode 100644 index 000000000000..5da0cebe2bc9 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_websocket_transport_async_unit.py @@ -0,0 +1,150 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +"""Unit tests for WebSocketTransportAsync covering close/connect cleanup paths. + +These tests do not require a live Event Hubs instance; aiohttp is mocked so that +the cleanup behavior of WebSocketTransportAsync can be exercised in isolation. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.eventhub._pyamqp.aio._transport_async import WebSocketTransportAsync + + +def _make_transport(): + transport = WebSocketTransportAsync("example.servicebus.windows.net") + transport.network_trace_params = {} + return transport + + +@pytest.mark.asyncio +async def test_close_calls_session_close_even_if_sock_close_raises(): + """If sock.close() raises, the aiohttp ClientSession must still be closed. + + Regression test for the leak where WebSocketTransportAsync.close() called + self.sock.close() and self.session.close() sequentially without try/except, + leaving the aiohttp ClientSession unclosed when sock.close() raised. + """ + transport = _make_transport() + sock = MagicMock() + sock.close = AsyncMock(side_effect=RuntimeError("ws already closed")) + session = MagicMock() + session.close = AsyncMock() + transport.sock = sock + transport.session = session + transport.connected = True + + await transport.close() + + sock.close.assert_awaited_once() + session.close.assert_awaited_once() + assert transport.connected is False + + +@pytest.mark.asyncio +async def test_close_handles_none_sock_and_session(): + """close() must not raise if sock/session were never assigned.""" + transport = _make_transport() + transport.sock = None + transport.session = None + transport.connected = True + + await transport.close() + + assert transport.connected is False + + +@pytest.mark.asyncio +async def test_close_swallows_session_close_errors(): + """Errors from session.close() must not propagate, mirroring the sibling + AsyncTransport.close() pattern which logs and continues.""" + transport = _make_transport() + sock = MagicMock() + sock.close = AsyncMock() + session = MagicMock() + session.close = AsyncMock(side_effect=RuntimeError("aiohttp boom")) + transport.sock = sock + transport.session = session + transport.connected = True + + await transport.close() + + session.close.assert_awaited_once() + assert transport.connected is False + + +@pytest.mark.asyncio +async def test_connect_closes_previous_session_on_reconnect(): + """When connect() is called and a previous session already exists (reconnect + path), the previous session must be closed before a new one is created. + """ + from aiohttp import ClientConnectorError + + transport = _make_transport() + previous_session = MagicMock() + previous_session.close = AsyncMock() + transport.session = previous_session + + # Force connect() to fail fast after the previous-session cleanup so we can + # assert the cleanup happened. ClientConnectorError is one of the existing + # handled exception types in connect(). + fake_session = MagicMock() + fake_session.ws_connect = AsyncMock( + side_effect=ClientConnectorError(MagicMock(), OSError("nope")) + ) + fake_session.close = AsyncMock() + + with patch( + "aiohttp.ClientSession", return_value=fake_session + ), pytest.raises(ConnectionError): + await transport.connect() + + previous_session.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_closes_new_session_on_client_connector_error(): + """When ws_connect raises ClientConnectorError, the newly created + ClientSession must be closed before the error is re-raised.""" + from aiohttp import ClientConnectorError + + transport = _make_transport() + transport.session = None + + fake_session = MagicMock() + fake_session.ws_connect = AsyncMock( + side_effect=ClientConnectorError(MagicMock(), OSError("nope")) + ) + fake_session.close = AsyncMock() + + with patch( + "aiohttp.ClientSession", return_value=fake_session + ), pytest.raises(ConnectionError): + await transport.connect() + + fake_session.close.assert_awaited_once() + assert transport.session is None + + +@pytest.mark.asyncio +async def test_connect_closes_new_session_on_unexpected_exception(): + """When ws_connect raises something other than ClientConnectorError, the + newly created session must still be closed before the exception + propagates.""" + transport = _make_transport() + transport.session = None + + fake_session = MagicMock() + fake_session.ws_connect = AsyncMock(side_effect=RuntimeError("unexpected")) + fake_session.close = AsyncMock() + + with patch( + "aiohttp.ClientSession", return_value=fake_session + ), pytest.raises(RuntimeError): + await transport.connect() + + fake_session.close.assert_awaited_once() + assert transport.session is None