Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Release History

## 5.15.2 (Unreleased)

### Bugs Fixed

- Fixed a bug where `WebSocketTransportAsync.close()` could leak an `aiohttp.ClientSession` when the underlying websocket close raised. Also fixed session leaks on reconnect and on non-`ClientConnectorError` failures in `connect()`.

## 5.15.1 (2025-11-11)

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ async def connect(self):

http_proxy_auth = BasicAuth(login=username, password=password)

# Close any session left over from a previous connect attempt so that
# the aiohttp ClientSession is not leaked on reconnect.
await self._close_session_safely()

self.session = ClientSession()
if self._custom_endpoint:
url = f"wss://{self._custom_endpoint}" if self._use_tls else f"ws://{self._custom_endpoint}"
Expand Down Expand Up @@ -477,15 +481,34 @@ async def connect(self):
)
except ClientConnectorError as exc:
_LOGGER.info("Websocket connect failed: %r", exc, extra=self.network_trace_params)
await self._close_session_safely()
if self._custom_endpoint:
raise AuthenticationException(
ErrorCondition.ClientError,
description="Failed to authenticate the connection due to exception: " + str(exc),
error=exc,
) from exc
raise ConnectionError("Failed to establish websocket connection: " + str(exc)) from exc
except Exception: # pylint: disable=broad-except
# Any other failure during ws_connect must also clean up the new
# session so the aiohttp ClientSession is not leaked.
await self._close_session_safely()
raise
self.connected = True

async def _close_session_safely(self):
"""Close ``self.session`` if set, suppressing and logging any errors."""
if self.session is not None:
try:
await self.session.close()
except Exception as e: # pylint: disable=broad-except
_LOGGER.debug(
"Error closing aiohttp session: %r",
e,
extra=self.network_trace_params,
)
self.session = None

async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused-argument
"""Read exactly n bytes from the peer.

Expand Down Expand Up @@ -524,9 +547,16 @@ async def _read(self, toread, buffer=None, **kwargs): # pylint: disable=unused-
async def close(self):
"""Do any preliminary work in shutting down the connection."""
async with self.socket_lock:
await self.sock.close()
await self.session.close()
self.connected = False
try:
if self.sock is not None:
await self.sock.close()
except Exception as e: # pylint: disable=broad-except
_LOGGER.debug(
"Error closing websocket: %r", e, extra=self.network_trace_params
)
self.sock = None
await self._close_session_safely()
self.connected = False

async def _write(self, s):
"""Completely write a string (byte array) to the peer.
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# ------------------------------------

VERSION = "5.15.1"
VERSION = "5.15.2"
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""Unit tests for WebSocketTransportAsync covering close/connect cleanup paths.

These tests do not require a live Event Hubs instance; aiohttp is mocked so that
the cleanup behavior of WebSocketTransportAsync can be exercised in isolation.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from azure.eventhub._pyamqp.aio._transport_async import WebSocketTransportAsync


def _make_transport():
transport = WebSocketTransportAsync("example.servicebus.windows.net")
transport.network_trace_params = {}
return transport


@pytest.mark.asyncio
async def test_close_calls_session_close_even_if_sock_close_raises():
"""If sock.close() raises, the aiohttp ClientSession must still be closed.

Regression test for the leak where WebSocketTransportAsync.close() called
self.sock.close() and self.session.close() sequentially without try/except,
leaving the aiohttp ClientSession unclosed when sock.close() raised.
"""
transport = _make_transport()
sock = MagicMock()
sock.close = AsyncMock(side_effect=RuntimeError("ws already closed"))
session = MagicMock()
session.close = AsyncMock()
transport.sock = sock
transport.session = session
transport.connected = True

await transport.close()

sock.close.assert_awaited_once()
session.close.assert_awaited_once()
assert transport.connected is False


@pytest.mark.asyncio
async def test_close_handles_none_sock_and_session():
"""close() must not raise if sock/session were never assigned."""
transport = _make_transport()
transport.sock = None
transport.session = None
transport.connected = True

await transport.close()

assert transport.connected is False


@pytest.mark.asyncio
async def test_close_swallows_session_close_errors():
"""Errors from session.close() must not propagate, mirroring the sibling
AsyncTransport.close() pattern which logs and continues."""
transport = _make_transport()
sock = MagicMock()
sock.close = AsyncMock()
session = MagicMock()
session.close = AsyncMock(side_effect=RuntimeError("aiohttp boom"))
transport.sock = sock
transport.session = session
transport.connected = True

await transport.close()

session.close.assert_awaited_once()
assert transport.connected is False


@pytest.mark.asyncio
async def test_connect_closes_previous_session_on_reconnect():
"""When connect() is called and a previous session already exists (reconnect
path), the previous session must be closed before a new one is created.
"""
from aiohttp import ClientConnectorError

transport = _make_transport()
previous_session = MagicMock()
previous_session.close = AsyncMock()
transport.session = previous_session

# Force connect() to fail fast after the previous-session cleanup so we can
# assert the cleanup happened. ClientConnectorError is one of the existing
# handled exception types in connect().
fake_session = MagicMock()
fake_session.ws_connect = AsyncMock(
side_effect=ClientConnectorError(MagicMock(), OSError("nope"))
)
fake_session.close = AsyncMock()

with patch(
"aiohttp.ClientSession", return_value=fake_session
), pytest.raises(ConnectionError):
await transport.connect()

previous_session.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_connect_closes_new_session_on_client_connector_error():
"""When ws_connect raises ClientConnectorError, the newly created
ClientSession must be closed before the error is re-raised."""
from aiohttp import ClientConnectorError

transport = _make_transport()
transport.session = None

fake_session = MagicMock()
fake_session.ws_connect = AsyncMock(
side_effect=ClientConnectorError(MagicMock(), OSError("nope"))
)
fake_session.close = AsyncMock()

with patch(
"aiohttp.ClientSession", return_value=fake_session
), pytest.raises(ConnectionError):
await transport.connect()

fake_session.close.assert_awaited_once()
assert transport.session is None


@pytest.mark.asyncio
async def test_connect_closes_new_session_on_unexpected_exception():
"""When ws_connect raises something other than ClientConnectorError, the
newly created session must still be closed before the exception
propagates."""
transport = _make_transport()
transport.session = None

fake_session = MagicMock()
fake_session.ws_connect = AsyncMock(side_effect=RuntimeError("unexpected"))
fake_session.close = AsyncMock()

with patch(
"aiohttp.ClientSession", return_value=fake_session
), pytest.raises(RuntimeError):
await transport.connect()

fake_session.close.assert_awaited_once()
assert transport.session is None