Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1852ad6
fix(broker): clean up connection error output (#539)
oboehmer Mar 30, 2026
21a239d
fix(broker): fix deadlock in BrokerClient.connect() on connection fai…
oboehmer Mar 30, 2026
6c539cf
test(broker): add failure-mode unit tests for ConnectionBroker (#486)
oboehmer Mar 30, 2026
459e42c
fix(ssh): harden broker socket validation and expand tests (#485)
oboehmer Mar 30, 2026
7d3cada
refactor(broker): extract _teardown_connection() to eliminate teardow…
oboehmer Mar 30, 2026
9db8035
fix(ssh): validate NAC_TEST_BROKER_SOCKET with is_socket()
oboehmer Mar 30, 2026
9406e3d
move broker failure-mode tests to integration and tighten assertions
oboehmer Mar 31, 2026
7ea56fd
test(broker): cover stale socket failure and organize failure-mode tests
oboehmer Mar 31, 2026
fb08d8b
test(broker): cover socket deletion mid-run
oboehmer Mar 31, 2026
f7b407b
test(broker): cover reconnect when cached connection is unhealthy
oboehmer Mar 31, 2026
b955af1
test(broker): add concurrent connect race coverage
oboehmer Mar 31, 2026
4adc76e
## Commit message
oboehmer Mar 31, 2026
025cd63
## Commit message
oboehmer Mar 31, 2026
f251e5b
test(broker): add missing-param + protocol/framing integration coverage
oboehmer Mar 31, 2026
6abe5d4
## Commit message
oboehmer Mar 31, 2026
2eda128
fix(test): add missing pytestmark to integration broker tests
oboehmer Apr 14, 2026
715d9bf
refactor(test): consolidate socket_dir fixture to root conftest
oboehmer Apr 14, 2026
693f2fc
refactor(test): move broker unit tests to mirror source path
oboehmer Apr 14, 2026
0f0ea46
refactor(test): remove tautological broker client unit tests
oboehmer Apr 14, 2026
4a6d1f1
refactor(test): move non-socket tests from integration to unit file
oboehmer Apr 14, 2026
f4a85f1
refactor(test): convert remaining inline teardown to _run_broker helper
oboehmer Apr 14, 2026
ff76be4
fix(broker): make _create_connection raise consistently instead of re…
oboehmer Apr 14, 2026
83c0869
fix(broker): use fixed-format error log in _create_connection
oboehmer Apr 14, 2026
b11053d
fix(broker): downgrade BrokerClient.connect() log to debug
oboehmer Apr 14, 2026
af88e7c
fix(broker): replace assert with RuntimeError guards in _send_request
oboehmer Apr 14, 2026
0e0cbae
fix(broker): add MAX_BROKER_MESSAGE_BYTES limit to _handle_client
oboehmer Apr 14, 2026
d2db443
fix(broker): suppress redundant hostname in _create_connection error log
oboehmer Apr 14, 2026
8a85306
Merge remote-tracking branch 'origin/main' into feat/539-486-585-brok…
oboehmer Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions nac_test/pyats_core/broker/broker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ def _get_socket_path(self) -> Path:
"or ensure connection broker is running."
)

def _teardown_connection(self) -> None:
    """Reset connection state without acquiring the lock.

    Must be called while already holding ``_connection_lock``. Shared by
    ``connect()``'s error path and ``disconnect()`` so both tear down the
    connection the same way.

    The writer is closed (if there is one) and ``reader``, ``writer`` and
    ``_connected`` are cleared. The reset runs in a ``finally`` block so the
    client is never left half torn down (a stale writer with ``_connected``
    still true) when ``close()`` itself raises; the exception still
    propagates to the caller.

    Note: does not call ``writer.wait_closed()`` because on the ``connect()``
    failure path there is nothing meaningful buffered to flush.
    ``disconnect()`` calls ``wait_closed()`` explicitly after this method
    returns.
    """
    writer = self.writer
    try:
        if writer:
            writer.close()
    finally:
        # Always drop the references, even if close() failed, so the next
        # connect() starts from a clean state.
        self.reader = None
        self.writer = None
        self._connected = False

async def connect(self) -> None:
"""Connect to the broker service."""
async with self._connection_lock:
Expand All @@ -72,20 +89,17 @@ async def connect(self) -> None:
await self._send_request({"command": "ping"})

except Exception as e:
logger.error(f"Failed to connect to broker: {e}")
await self.disconnect()
logger.debug(f"Failed to connect to broker: {e}")
self._teardown_connection()
raise ConnectionError(f"Cannot connect to broker: {e}") from e

async def disconnect(self) -> None:
"""Disconnect from the broker service."""
async with self._connection_lock:
if self.writer:
self.writer.close()
await self.writer.wait_closed()

self.reader = None
self.writer = None
self._connected = False
writer = self.writer
self._teardown_connection()
if writer:
await writer.wait_closed()

async def _send_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send request to broker and return response."""
Expand All @@ -98,8 +112,10 @@ async def _send_request(self, request: dict[str, Any]) -> dict[str, Any]:
request_length = len(request_data).to_bytes(4, byteorder="big")

# Send request
assert self.writer is not None, "Writer must be connected"
assert self.reader is not None, "Reader must be connected"
if self.writer is None:
raise RuntimeError("Writer must be connected")
if self.reader is None:
raise RuntimeError("Reader must be connected")

self.writer.write(request_length + request_data)
await self.writer.drain()
Expand All @@ -120,7 +136,7 @@ async def _send_request(self, request: dict[str, Any]) -> dict[str, Any]:
return response # type: ignore[no-any-return]

except Exception as e:
logger.error(f"Error communicating with broker: {e}")
logger.debug(f"Error communicating with broker: {e}")
# Reset connection on error
await self.disconnect()
raise
Expand Down Expand Up @@ -165,7 +181,7 @@ async def ensure_connection(self, hostname: str) -> bool:
return response.get("result", False) # type: ignore[no-any-return]

except Exception as e:
logger.error(f"Failed to ensure connection to {hostname}: {e}")
logger.debug(f"Failed to ensure connection to {hostname}: {e}")
return False

async def disconnect_device(self, hostname: str) -> None:
Expand Down
56 changes: 38 additions & 18 deletions nac_test/pyats_core/broker/connection_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pathlib import Path
from typing import Any

from nac_test.pyats_core.constants import MAX_BROKER_MESSAGE_BYTES
from nac_test.pyats_core.ssh.command_cache import CommandCache
from nac_test.utils import get_or_create_event_loop

Expand Down Expand Up @@ -147,6 +148,13 @@ async def _handle_client(
if message_length == 0:
break

if message_length > MAX_BROKER_MESSAGE_BYTES:
logger.warning(
f"Client {client_addr} sent oversized frame "
f"({message_length} bytes, limit {MAX_BROKER_MESSAGE_BYTES})"
)
break

# Read message data
message_data = await reader.readexactly(message_length)
message = json.loads(message_data.decode("utf-8"))
Expand Down Expand Up @@ -198,8 +206,14 @@ async def _process_request(self, message: dict[str, Any]) -> dict[str, Any]:
if not hostname:
return {"status": "error", "error": "Missing hostname parameter"}

success = await self._ensure_connection(hostname)
return {"status": "success" if success else "error", "result": success}
success, error_msg = await self._ensure_connection(hostname)
if success:
return {"status": "success", "result": True}
else:
return {
"status": "error",
"error": error_msg or f"Failed to connect to {hostname}",
}

elif command == "disconnect":
hostname = message.get("hostname")
Expand All @@ -218,7 +232,7 @@ async def _process_request(self, message: dict[str, Any]) -> dict[str, Any]:
return {"status": "error", "error": f"Unknown command: {command}"}

except Exception as e:
logger.error(f"Error processing request: {e}", exc_info=True)
logger.error(f"Error processing request: {e}")
return {"status": "error", "error": str(e)}

async def _execute_command(self, hostname: str, cmd: str) -> str:
Expand Down Expand Up @@ -249,8 +263,6 @@ async def _execute_command(self, hostname: str, cmd: str) -> str:

# Ensure device is connected
connection = await self._get_connection(hostname)
if not connection:
raise ConnectionError(f"Failed to connect to device: {hostname}")

# Execute command in thread pool (since Unicon is synchronous)
loop = get_or_create_event_loop()
Expand All @@ -271,7 +283,7 @@ async def _execute_command(self, hostname: str, cmd: str) -> str:
await self._disconnect_device(hostname)
raise

async def _get_connection(self, hostname: str) -> Any | None:
async def _get_connection(self, hostname: str) -> Any:
"""Get or create connection to device."""
if hostname not in self.connection_locks:
self.connection_locks[hostname] = asyncio.Lock()
Expand Down Expand Up @@ -302,14 +314,12 @@ async def _get_connection(self, hostname: str) -> Any | None:
)
return await self._create_connection(hostname)

async def _create_connection(self, hostname: str) -> Any | None:
async def _create_connection(self, hostname: str) -> Any:
"""Create new connection to device using testbed."""
if not self.testbed:
logger.error("No testbed loaded")
return None
raise ConnectionError(f"No testbed loaded for {hostname}")
if hostname not in self.testbed.devices:
logger.error(f"Device {hostname} not found in testbed")
return None
raise ConnectionError(f"Device {hostname} not found in testbed")

async with self.connection_semaphore:
try:
Expand Down Expand Up @@ -340,13 +350,23 @@ async def _create_connection(self, hostname: str) -> Any | None:
return device

except Exception as e:
logger.error(f"Failed to connect to {hostname}: {e}", exc_info=True)
return None

async def _ensure_connection(self, hostname: str) -> bool:
"""Ensure device is connected, return success status."""
connection = await self._get_connection(hostname)
return connection is not None
# pyATS exceptions often embed the hostname already
# (e.g. "failed to connect to iosxe-r1"), so only prepend
# our "Failed to connect to <host>:" prefix when the
# hostname is absent — otherwise the log looks redundant.
msg = f"{type(e).__name__}: {e}"
if hostname not in str(e):
msg = f"Failed to connect to {hostname}: {msg}"
logger.error(msg)
raise

async def _ensure_connection(self, hostname: str) -> tuple[bool, str]:
    """Try to obtain a connection to *hostname* and report the outcome.

    Returns:
        ``(True, "")`` when ``_get_connection()`` completes, otherwise
        ``(False, <str of the raised exception>)``. Exceptions derived from
        ``Exception`` are converted into the failure tuple, not re-raised.
    """
    try:
        await self._get_connection(hostname)
    except Exception as exc:
        outcome = (False, str(exc))
    else:
        outcome = (True, "")
    return outcome

async def _disconnect_device(self, hostname: str) -> None:
"""Disconnect from device and clean up."""
Expand Down
26 changes: 17 additions & 9 deletions nac_test/pyats_core/common/ssh_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any

from pyats import aetest
Expand Down Expand Up @@ -188,9 +189,17 @@ async def _async_setup(self, hostname: str) -> None:
"""Helper for async setup operations with connection error handling."""
try:
# Check if broker is active (priority over testbed to enable connection pooling)
broker_active = "NAC_TEST_BROKER_SOCKET" in os.environ
broker_socket_env = os.environ.get("NAC_TEST_BROKER_SOCKET")
broker_socket = Path(broker_socket_env) if broker_socket_env else None

if broker_active:
if broker_socket is not None and not broker_socket.is_socket():
self.logger.warning(
f"NAC_TEST_BROKER_SOCKET is set but {broker_socket} is not a valid "
f"Unix socket, falling back to direct connection"
)
broker_socket = None

if broker_socket is not None:
# Use broker client for connection management
# Testbed may still be available for Genie parsers
self.logger.info(
Expand All @@ -217,15 +226,14 @@ async def _async_setup(self, hostname: str) -> None:
"broker not active and testbed not available"
)

except ConnectionError:
# Already logged at source (broker or testbed layer) — just re-raise
raise
except Exception as e:
# Connection failed - raise exception to be caught in setup_ssh_context
error_msg = f"Failed to connect to device {hostname}: {str(e)}"
# Unexpected error — log here since no lower layer will have done so
error_msg = f"Failed to connect to {hostname}: {e}"
self.logger.error(error_msg)

# Raise with a clear message that will be caught by the calling method
raise ConnectionError(
f"Device connection failed: {hostname}\nError: {str(e)}"
) from e
raise ConnectionError(error_msg) from e

# 2. Create and attach the command cache
self.command_cache = CommandCache(hostname)
Expand Down
7 changes: 7 additions & 0 deletions nac_test/pyats_core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
PYATS_POST_DISCONNECT_WAIT_SECONDS: int = 0
PYATS_GRACEFUL_DISCONNECT_WAIT_SECONDS: int = 0

# Connection broker protocol limits
# A 4-byte unsigned length prefix can represent ~4 GB; without this guard a buggy
# client could exhaust broker memory via a single oversized frame.
MAX_BROKER_MESSAGE_BYTES: int = 10 * 1024 * 1024 # 10 MB

# Multi-job execution configuration (to avoid reporter crashes)
TESTS_PER_JOB: int = 15 # Reduced from 20 for safety margin - each test ~1500 steps
MAX_PARALLEL_JOBS: int = 2 # Conservative parallelism to avoid resource exhaustion
Expand Down Expand Up @@ -136,6 +141,8 @@
"AUTH_CACHE_DIR",
"PYATS_POST_DISCONNECT_WAIT_SECONDS",
"PYATS_GRACEFUL_DISCONNECT_WAIT_SECONDS",
# Connection broker protocol limits
"MAX_BROKER_MESSAGE_BYTES",
# Multi-job execution
"TESTS_PER_JOB",
"MAX_PARALLEL_JOBS",
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import os
import re
import tempfile
from collections.abc import Generator
from pathlib import Path

Expand Down Expand Up @@ -141,3 +142,15 @@ def mock_api_server_preflight_401() -> Generator[MockAPIServer, None, None]:
mock_api_server so no mutation of the session-wide server is needed.
"""
yield from _start_mock_server(MOCK_API_CONFIG_PREFLIGHT_401_PATH)


# =============================================================================
# Shared test fixtures (used by unit, integration, and e2e tests)
# =============================================================================


@pytest.fixture()
def socket_dir() -> Generator[Path, None, None]:
    """Yield a throwaway directory in which tests can create Unix sockets.

    The directory comes from ``tempfile.TemporaryDirectory()`` and is removed
    once the test using the fixture finishes. It is intended to keep socket
    paths short enough for macOS's 104-character limit.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
Loading
Loading