Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 28 additions & 199 deletions python/packages/jumpstarter/jumpstarter/exporter/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import logging
import os
import select
import time
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Literal
Expand All @@ -21,13 +19,6 @@

logger = logging.getLogger(__name__)

MAX_DRAIN_BYTES = 256 * 1024
DRAIN_TIMEOUT_SECONDS = 2.0
DRAIN_MAX_EMPTY_POLLS = 10

# Module-level reference to time.monotonic so tests can patch it without
# affecting the asyncio event loop (which also uses time.monotonic).
_monotonic = time.monotonic


def _flush_lines(buffer: bytes, output_lines: list[str]) -> bytes:
Expand Down Expand Up @@ -73,19 +64,6 @@ def should_end_lease(self) -> bool:
return self.on_failure in ("endLease", "exit")


@dataclass
class PtyState:
"""Mutable state for PTY file descriptors and reader coordination.

Tracks which fds are still open (for cleanup) and provides a separate
stop flag to signal the reader task without affecting fd lifecycle.
"""

parent_fd_open: bool = True
child_fd_open: bool = True
reader_stop: bool = False


@dataclass(kw_only=True)
class HookExecutor:
"""Executes lifecycle hooks with access to the j CLI."""
Expand Down Expand Up @@ -232,52 +210,32 @@ async def _execute_hook_process( # noqa: C901
logging_session: Session,
hook_type: Literal["before_lease", "after_lease"],
) -> str | None:
"""Execute the hook process with the given environment and logging session.

Uses subprocess with a PTY to force line buffering in the subprocess,
ensuring logs stream in real-time rather than being block-buffered.
"""Execute the hook process and capture its output via pipes.

Returns:
Warning message string if hook failed with on_failure='warn', None otherwise
"""
import pty
import subprocess

command = hook_config.script
timeout = hook_config.timeout
on_failure = hook_config.on_failure

# Exception handling
error_msg: str | None = None
cause: Exception | None = None
timed_out = False

# Route hook output logs to the client via the session's log stream
logger.debug("Entering log source context for %s", log_source)
with logging_session.context_log_source(__name__, log_source):
# Create a PTY pair - this forces line buffering in the subprocess
logger.debug("Starting hook subprocess...")
logger.debug("Creating PTY pair...")
try:
parent_fd, child_fd = pty.openpty()
except Exception as e:
logger.error("Failed to create PTY: %s", e, exc_info=True)
raise
logger.debug("PTY created: parent_fd=%d, child_fd=%d", parent_fd, child_fd)

pty_state = PtyState()

process: subprocess.Popen | None = None
try:
# Use subprocess.Popen with the PTY child as stdin/stdout/stderr
# This avoids the issues with os.fork() in async contexts
# Determine interpreter and invocation mode
script_stripped = command.strip()
is_file = "\n" not in script_stripped and os.path.isfile(script_stripped)

interpreter = hook_config.exec_
if is_file and interpreter is None:
# Auto-detect interpreter from file extension
import sys

ext = os.path.splitext(script_stripped)[1].lower()
Expand All @@ -301,216 +259,97 @@ async def _execute_hook_process( # noqa: C901
try:
process = subprocess.Popen(
cmd,
stdin=child_fd,
stdout=child_fd,
stderr=child_fd,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=hook_env,
start_new_session=True, # Equivalent to os.setsid()
close_fds=True, # Close inherited fds to prevent interference with gRPC connections
process_group=0,
Comment thread
mangelajo marked this conversation as resolved.
close_fds=True,
)
except Exception as e:
logger.error("Failed to spawn subprocess: %s", e, exc_info=True)
raise
logger.debug("Subprocess spawned with PID %d", process.pid)
# Close child fd in parent process - subprocess has it now
os.close(child_fd)
pty_state.child_fd_open = False
logger.debug("Closed child_fd in parent process")

output_lines: list[str] = []

# Set parent fd to non-blocking mode
import fcntl

flags = fcntl.fcntl(parent_fd, fcntl.F_GETFL)
fcntl.fcntl(parent_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
logger.debug("Parent fd set to non-blocking")
pipe_fd = process.stdout.fileno()
flags = fcntl.fcntl(pipe_fd, fcntl.F_GETFL)
fcntl.fcntl(pipe_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

async def read_pty_output() -> None: # noqa: C901
"""Read from PTY parent fd line by line using non-blocking I/O."""
logger.debug("read_pty_output task started")
async def read_output() -> None:
"""Read subprocess output via pipe using async non-blocking I/O."""
buffer = b""
read_count = 0
last_heartbeat = 0

start_time = _monotonic()
try:
while not pty_state.reader_stop:
while True:
try:
# Wait for fd to be readable with timeout
with anyio.move_on_after(0.1):
await anyio.wait_readable(parent_fd)

# Check stop flag immediately after timeout
# (main task may have signaled us to stop)
if pty_state.reader_stop:
logger.debug("read_pty_output: stop flag set, exiting")
await anyio.wait_readable(pipe_fd)
chunk = os.read(pipe_fd, 4096)
if not chunk:
break

read_count += 1
# Log heartbeat every 2 seconds
elapsed = _monotonic() - start_time
if elapsed - last_heartbeat >= 2.0:
logger.debug(
"read_pty_output: heartbeat at %.1fs, iterations=%d", elapsed, read_count
)
last_heartbeat = elapsed

# Read available data (non-blocking)
try:
chunk = os.read(parent_fd, 4096)
if not chunk:
# EOF
logger.debug("read_pty_output: EOF received")
break
buffer += chunk
except BlockingIOError:
# No data available right now, continue loop
continue
except OSError as e:
# PTY closed or error
logger.debug("read_pty_output: OSError on read: %s", e)
break

# Process complete lines
buffer = _flush_lines(buffer, output_lines)

buffer += chunk
except BlockingIOError:
continue
except OSError as e:
# PTY closed or read error
logger.debug("read_pty_output: OSError in loop: %s", e)
logger.debug("read_output: OSError: %s", e)
break
finally:
# Drain any remaining data from the PTY buffer.
# On macOS, PTY output may still be in the kernel buffer
# after the subprocess exits and the stop flag is set.
# Use select() with a timeout to poll for readability
# instead of immediately breaking on BlockingIOError,
# giving the macOS PTY kernel buffer time to deliver
# remaining data.
# Bound the drain to prevent spinning indefinitely if a
# grandchild process holds the PTY slave fd open.
try:
drain_deadline = _monotonic() + DRAIN_TIMEOUT_SECONDS
drained = 0
consecutive_empty = 0
while drained < MAX_DRAIN_BYTES and _monotonic() < drain_deadline:
# Poll for readability with a short timeout.
# This avoids the race where a non-blocking read
# raises BlockingIOError because the macOS PTY
# kernel buffer hasn't delivered the data yet.
remaining = drain_deadline - _monotonic()
if remaining <= 0:
break
timeout_s = min(remaining, 0.1)
try:
readable, _, _ = select.select([parent_fd], [], [], timeout_s)
except (ValueError, OSError):
# fd closed or invalid
break
if not readable:
# On macOS, data may not be available on the
# first select() call even though the subprocess
# has already written and exited. Keep retrying
# until we see several consecutive empty polls,
# which indicates the buffer is truly drained.
consecutive_empty += 1
if consecutive_empty >= DRAIN_MAX_EMPTY_POLLS:
break
continue
consecutive_empty = 0
try:
chunk = os.read(parent_fd, 4096)
if not chunk:
break
buffer += chunk
drained += len(chunk)
except (BlockingIOError, OSError):
break

buffer = _flush_lines(buffer, output_lines)
except Exception:
logger.debug("read_pty_output: error during drain", exc_info=True)

logger.debug("read_pty_output: exiting, processed %d iterations", read_count)
finally:
if buffer:
line_decoded = buffer.decode(errors="replace").rstrip()
if line_decoded:
output_lines.append(line_decoded)
logger.info("%s", line_decoded)

async def wait_for_process() -> int:
"""Wait for the subprocess to complete.

Ensures the subprocess is properly reaped even if cancelled,
preventing zombie processes.
"""
"""Wait for the subprocess to complete."""
logger.debug("wait_for_process: waiting for PID %d", process.pid)
try:
result = await anyio.to_thread.run_sync(process.wait, abandon_on_cancel=True)
logger.debug("wait_for_process: PID %d exited with code %d", process.pid, result)
return result
finally:
# Ensure subprocess is reaped on cancellation to prevent zombies
if process.poll() is None:
logger.debug("wait_for_process: cleaning up still-running PID %d", process.pid)
try:
process.terminate()
# Give it a moment to terminate gracefully
for _ in range(10):
if process.poll() is not None:
break
await anyio.sleep(0.1)
# Force kill if still running
if process.poll() is None:
logger.debug("wait_for_process: force killing PID %d", process.pid)
process.kill()
# Final reap with non-abandoning wait
await anyio.to_thread.run_sync(process.wait, abandon_on_cancel=False)
except Exception as e:
logger.debug("wait_for_process: error during cleanup: %s", e)

# Use move_on_after for timeout
returncode: int | None = None
logger.debug("Starting PTY output reader and process waiter (timeout=%d)", timeout)

# Yield to event loop to ensure other tasks can progress
# This helps prevent race conditions in task scheduling
await anyio.sleep(0)
logger.debug("Starting output reader and process waiter (timeout=%d)", timeout)

with anyio.move_on_after(timeout) as cancel_scope:
# Run output reading and process waiting concurrently
async with anyio.create_task_group() as tg:
logger.debug("Task group created, starting tasks...")
tg.start_soon(read_pty_output)
logger.debug("Waiting for subprocess to complete...")
tg.start_soon(read_output)
returncode = await wait_for_process()
logger.debug("Subprocess completed with code: %s", returncode)
# Give a brief moment for any final output to be read
await anyio.sleep(0.2)
# Signal the read task to stop via the dedicated stop flag.
# The read task checks this flag after each 0.1s timeout
# and also receives EOF when the subprocess exits.
# Note: pty_state.parent_fd_open stays True so the finally block
# properly closes parent_fd.
pty_state.reader_stop = True
logger.debug("Stop flag set, waiting for read task to exit")
# Don't cancel - let the task exit naturally via EOF or flag check
# Cancellation can cause unexpected side effects on gRPC connections
# Yield to let the LogStream deliver any pending
# messages before reporting the hook result.
await anyio.sleep(0)

if cancel_scope.cancelled_caught:
timed_out = True
error_msg = f"Hook timed out after {timeout} seconds"
logger.error(error_msg)
# Terminate the process
if process and process.poll() is None:
process.terminate()
# Give it a moment to terminate gracefully
try:
with anyio.move_on_after(5):
await anyio.to_thread.run_sync(process.wait, abandon_on_cancel=True)
except Exception:
pass
# Force kill if still running
if process.poll() is None:
process.kill()
try:
Expand All @@ -529,23 +368,13 @@ async def wait_for_process() -> int:
cause = e
logger.error(error_msg, exc_info=True)
finally:
# Clean up file descriptors - only close those still open to avoid
# closing an unrelated fd that reused the same number.
if pty_state.parent_fd_open:
try:
os.close(parent_fd)
except OSError:
pass
if pty_state.child_fd_open:
if process and process.stdout:
try:
os.close(child_fd)
process.stdout.close()
except OSError:
pass

# Handle failure inside context_log_source so the WARNING log is
# routed to the client as a hook log (visible without --exporter-logs).
if error_msg is not None:
# For timeout, create a TimeoutError as the cause
if timed_out and cause is None:
cause = TimeoutError(error_msg)
return self._handle_hook_failure(error_msg, on_failure, hook_type, cause)
Expand Down
Loading
Loading