Patch summary: azure-eventhub 5.15.2, pyamqp transport-leak fix.

Before this patch, Connection._disconnect() returned early whenever the
connection state was already END. Per the changelog entry below,
Connection.close() sets the state to END in its exception handler and only
then reaches _disconnect() from its finally block, so on that path the
transport was never closed. The patch gates transport.close() on a new
_transport_closed flag instead of on the connection state, in both the sync
and the asyncio Connection, and logs (at debug level) rather than raises any
error coming out of transport.close().

Files touched:
  CHANGELOG.md                       5.15.2 "Bugs Fixed" entry
  _pyamqp/_connection.py             sync _disconnect() and new flag in __init__
  _pyamqp/aio/_connection_async.py   asyncio counterpart of the same change
  _version.py                        5.15.1 to 5.15.2
  tests/pyamqp_tests/asynctests/     new unit-test module (5 tests)
  tests/pyamqp_tests/synctests/      new unit-test module (5 tests)

NOTE(review): _transport_closed is set to True before _set_state(END) is
called. If _set_state() can raise, transport.close() is skipped and the flag
prevents any later retry. _set_state() is not part of this patch, so confirm
against it, or move the state change inside a try/finally.

diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md
index f287cb0b8fde..1253ff9b264a 100644
--- a/sdk/eventhub/azure-eventhub/CHANGELOG.md
+++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Release History
 
+## 5.15.2 (Unreleased)
+
+### Bugs Fixed
+
+- Fixed a bug where `Connection._disconnect()` early-returned when state was already `END`, so the underlying transport was never closed if `Connection.close()` entered its exception handler (e.g. network error, timeout, or already-closed peer during the AMQP close handshake). With `TransportType.AmqpOverWebsocket`, the leaked transport's `aiohttp.ClientSession` produced an `Unclosed client session` warning per affected partition.
+
 ## 5.15.1 (2025-11-11)
 
 ### Bugs Fixed
diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
index 1dfd1d852f6b..5b56605fc225 100644
--- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
+++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py
@@ -202,6 +202,7 @@ def __init__(  # pylint:disable=too-many-locals
         self._error: Optional[AMQPConnectionError] = None
         self._outgoing_endpoints: Dict[int, Session] = {}
         self._incoming_endpoints: Dict[int, Session] = {}
+        self._transport_closed: bool = False
 
     def __enter__(self) -> "Connection":
         self.open()
@@ -258,11 +259,26 @@ def _connect(self) -> None:
             ) from exc
 
     def _disconnect(self) -> None:
-        """Disconnect the transport and set state to END."""
-        if self.state == ConnectionState.END:
+        """Disconnect the transport and set state to END.
+
+        ``transport.close()`` is gated on ``self._transport_closed`` so that it
+        runs exactly once regardless of which code path drives the shutdown.
+        Without this, ``Connection.close()`` could enter its exception handler,
+        set the state to ``END`` without closing the transport, and the
+        subsequent ``_disconnect()`` call from the ``finally`` block would
+        early-return on the state check — leaking the underlying transport.
+        """
+        if self._transport_closed:
             return
-        self._set_state(ConnectionState.END)
-        self._transport.close()
+        self._transport_closed = True
+        if self.state != ConnectionState.END:
+            self._set_state(ConnectionState.END)
+        try:
+            self._transport.close()
+        except Exception as e:  # pylint: disable=broad-except
+            _LOGGER.debug(
+                "Error closing transport: %r", e, extra=self._network_trace_params
+            )
 
     def _can_read(self) -> bool:
         """Whether the connection is in a state where it is legal to read for incoming frames.
diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py
index 681cf72c8082..6d1f8bc4da3c 100644
--- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py
+++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py
@@ -183,6 +183,7 @@ def __init__(  # pylint:disable=too-many-locals
         self._error: Optional[AMQPConnectionError] = None
         self._outgoing_endpoints: Dict[int, Session] = {}
         self._incoming_endpoints: Dict[int, Session] = {}
+        self._transport_closed: bool = False
 
     async def __aenter__(self) -> "Connection":
         await self.open()
@@ -240,11 +241,28 @@ async def _connect(self) -> None:
             ) from exc
 
     async def _disconnect(self) -> None:
-        """Disconnect the transport and set state to END."""
-        if self.state == ConnectionState.END:
+        """Disconnect the transport and set state to END.
+
+        ``transport.close()`` is gated on ``self._transport_closed`` so that it
+        runs exactly once regardless of which code path drives the shutdown.
+        Without this, ``Connection.close()`` could enter its exception handler,
+        set the state to ``END`` without closing the transport, and the
+        subsequent ``_disconnect()`` call from the ``finally`` block would
+        early-return on the state check — leaking the underlying transport
+        (most notably the aiohttp ``ClientSession`` used by the websocket
+        transport).
+        """
+        if self._transport_closed:
             return
-        await self._set_state(ConnectionState.END)
-        await self._transport.close()
+        self._transport_closed = True
+        if self.state != ConnectionState.END:
+            await self._set_state(ConnectionState.END)
+        try:
+            await self._transport.close()
+        except Exception as e:  # pylint: disable=broad-except
+            _LOGGER.debug(
+                "Error closing transport: %r", e, extra=self._network_trace_params
+            )
 
     def _can_read(self) -> bool:
         """Whether the connection is in a state where it is legal to read for incoming frames.
diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
index 5bb5c30266da..5d660bbf8ea3 100644
--- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
+++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
@@ -3,4 +3,4 @@
 # Licensed under the MIT License.
 # ------------------------------------
 
-VERSION = "5.15.1"
+VERSION = "5.15.2"
diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_connection_disconnect_async_unit.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_connection_disconnect_async_unit.py
new file mode 100644
index 000000000000..e974dcef9e67
--- /dev/null
+++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/asynctests/test_connection_disconnect_async_unit.py
@@ -0,0 +1,110 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+"""Unit tests for Connection._disconnect() cleanup paths.
+
+Covers the regression where Connection.close() entering its exception handler
+set state to END without closing the transport, and the subsequent
+_disconnect() call (from the finally block) early-returned and never closed
+the transport, leaking the aiohttp ClientSession.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from azure.eventhub._pyamqp.aio._connection_async import Connection
+from azure.eventhub._pyamqp.constants import ConnectionState
+
+
+def _make_connection():
+    """Build a Connection without going through __init__ (which opens a real
+    transport). Only the attributes touched by _disconnect/close are set."""
+    connection = Connection.__new__(Connection)
+    connection.state = ConnectionState.START
+    connection._transport = MagicMock()
+    connection._transport.close = AsyncMock()
+    connection._network_trace_params = {
+        "amqpConnection": "test",
+        "amqpSession": "",
+        "amqpLink": "",
+    }
+    connection._outgoing_endpoints = {}
+    connection._transport_closed = False
+    return connection
+
+
+@pytest.mark.asyncio
+async def test_disconnect_closes_transport_when_state_already_end():
+    """When Connection.close() enters its exception handler it sets state to
+    END before calling _disconnect() in the finally block. The previous
+    implementation early-returned in that case and never closed the transport.
+    """
+    connection = _make_connection()
+    connection.state = ConnectionState.END
+
+    await connection._disconnect()
+
+    connection._transport.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_disconnect_is_idempotent():
+    """_disconnect() may be called more than once (e.g. once from
+    _incoming_close and again from Connection.close()'s finally). The transport
+    must only be closed once."""
+    connection = _make_connection()
+
+    await connection._disconnect()
+    await connection._disconnect()
+
+    connection._transport.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_disconnect_sets_state_to_end_when_not_already():
+    """The normal _disconnect() path still transitions state to END."""
+    connection = _make_connection()
+    assert connection.state != ConnectionState.END
+
+    await connection._disconnect()
+
+    assert connection.state == ConnectionState.END
+    connection._transport.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_disconnect_swallows_transport_close_errors():
+    """Errors from transport.close() must not propagate out of _disconnect() —
+    the connection is shutting down and any leaked resource will be GC'd."""
+    connection = _make_connection()
+    connection._transport.close = AsyncMock(side_effect=RuntimeError("boom"))
+
+    await connection._disconnect()
+
+    connection._transport.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_close_closes_transport_when_outgoing_close_raises():
+    """Regression test at the public API surface.
+
+    When Connection.close() hits its exception path (here, _outgoing_close
+    raising), it sets state to END and falls through to _disconnect() in the
+    finally block. The transport (the aiohttp ClientSession for the websocket
+    transport) must still be closed exactly once rather than leaked."""
+    connection = _make_connection()
+    connection.state = ConnectionState.OPENED
+    connection._error = None
+    connection._outgoing_close = AsyncMock(side_effect=RuntimeError("boom"))
+
+    async def _set_state(new_state):
+        connection.state = new_state
+
+    connection._set_state = AsyncMock(side_effect=_set_state)
+
+    await connection.close()
+
+    connection._outgoing_close.assert_awaited_once()
+    assert connection.state == ConnectionState.END
+    connection._transport.close.assert_awaited_once()
diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_connection_disconnect_unit.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_connection_disconnect_unit.py
new file mode 100644
index 000000000000..640f2c7130e6
--- /dev/null
+++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_connection_disconnect_unit.py
@@ -0,0 +1,89 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+"""Unit tests for Connection._disconnect() cleanup paths (sync).
+
+Mirrors the async tests in tests/pyamqp_tests/asynctests/. See that file for
+the full bug description.
+"""
+
+from unittest.mock import MagicMock
+
+from azure.eventhub._pyamqp._connection import Connection
+from azure.eventhub._pyamqp.constants import ConnectionState
+
+
+def _make_connection():
+    """Build a Connection without going through __init__ (which opens a real
+    transport). Only the attributes touched by _disconnect/close are set."""
+    connection = Connection.__new__(Connection)
+    connection.state = ConnectionState.START
+    connection._transport = MagicMock()
+    connection._network_trace_params = {
+        "amqpConnection": "test",
+        "amqpSession": "",
+        "amqpLink": "",
+    }
+    connection._outgoing_endpoints = {}
+    connection._transport_closed = False
+    return connection
+
+
+def test_disconnect_closes_transport_when_state_already_end():
+    connection = _make_connection()
+    connection.state = ConnectionState.END
+
+    connection._disconnect()
+
+    connection._transport.close.assert_called_once()
+
+
+def test_disconnect_is_idempotent():
+    connection = _make_connection()
+
+    connection._disconnect()
+    connection._disconnect()
+
+    connection._transport.close.assert_called_once()
+
+
+def test_disconnect_sets_state_to_end_when_not_already():
+    connection = _make_connection()
+    assert connection.state != ConnectionState.END
+
+    connection._disconnect()
+
+    assert connection.state == ConnectionState.END
+    connection._transport.close.assert_called_once()
+
+
+def test_disconnect_swallows_transport_close_errors():
+    connection = _make_connection()
+    connection._transport.close.side_effect = RuntimeError("boom")
+
+    connection._disconnect()
+
+    connection._transport.close.assert_called_once()
+
+
+def test_close_closes_transport_when_outgoing_close_raises():
+    """Regression test at the public API surface.
+
+    When Connection.close() hits its exception path (here, _outgoing_close
+    raising), it sets state to END and falls through to _disconnect() in the
+    finally block. The transport must still be closed exactly once rather than
+    leaked."""
+    connection = _make_connection()
+    connection.state = ConnectionState.OPENED
+    connection._error = None
+    connection._outgoing_close = MagicMock(side_effect=RuntimeError("boom"))
+    connection._set_state = MagicMock(
+        side_effect=lambda new_state: setattr(connection, "state", new_state)
+    )
+
+    connection.close()
+
+    connection._outgoing_close.assert_called_once()
+    assert connection.state == ConnectionState.END
+    connection._transport.close.assert_called_once()