Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 211 additions & 1 deletion crates/pilotty-cli/src/daemon/pty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,83 @@ use portable_pty::{native_pty_system, Child, CommandBuilder, MasterPty, PtySize}
use tokio::sync::mpsc;
use tracing::{debug, error, warn};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DsrParseState {
Ground,
Esc,
Csi,
Csi6,
}

/// Tracks terminal cursor position and detects DSR cursor-position requests.
struct CursorTracker {
parser: vt100::Parser,
dsr_state: DsrParseState,
}

impl CursorTracker {
fn new(size: TermSize) -> Self {
Self {
parser: vt100::Parser::new(size.rows, size.cols, 0),
dsr_state: DsrParseState::Ground,
}
}

fn resize(&mut self, size: TermSize) {
self.parser.screen_mut().set_size(size.rows, size.cols);
}

fn process_output(&mut self, bytes: &[u8]) -> Vec<(u16, u16)> {
let mut positions = Vec::new();
let mut segment_start = 0;

for (idx, &b) in bytes.iter().enumerate() {
if self.advance_dsr_state(b) {
self.parser.process(&bytes[segment_start..=idx]);
positions.push(self.cursor_position_one_indexed());
segment_start = idx + 1;
}
}

if segment_start < bytes.len() {
self.parser.process(&bytes[segment_start..]);
}

positions
}

fn cursor_position_one_indexed(&self) -> (u16, u16) {
let (row, col) = self.parser.screen().cursor_position();
(row.saturating_add(1), col.saturating_add(1))
}

fn advance_dsr_state(&mut self, b: u8) -> bool {
let mut detected = false;

self.dsr_state = match (self.dsr_state, b) {
(DsrParseState::Ground, 0x1b) => DsrParseState::Esc,
(DsrParseState::Ground, _) => DsrParseState::Ground,

(DsrParseState::Esc, b'[') => DsrParseState::Csi,
(DsrParseState::Esc, 0x1b) => DsrParseState::Esc,
(DsrParseState::Esc, _) => DsrParseState::Ground,

(DsrParseState::Csi, b'6') => DsrParseState::Csi6,
(DsrParseState::Csi, 0x1b) => DsrParseState::Esc,
(DsrParseState::Csi, _) => DsrParseState::Ground,

(DsrParseState::Csi6, b'n') => {
detected = true;
DsrParseState::Ground
}
(DsrParseState::Csi6, 0x1b) => DsrParseState::Esc,
(DsrParseState::Csi6, _) => DsrParseState::Ground,
};

detected
}
}

/// Terminal size in columns and rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TermSize {
Expand Down Expand Up @@ -130,6 +207,8 @@ pub struct AsyncPtyHandle {
reader_thread: Option<std::thread::JoinHandle<()>>,
/// Handle to the writer thread for cleanup.
writer_thread: Option<std::thread::JoinHandle<()>>,
/// Tracks cursor position from PTY output for DSR responses.
cursor_tracker: Arc<std::sync::Mutex<CursorTracker>>,
}

impl AsyncPtyHandle {
Expand All @@ -148,10 +227,20 @@ impl AsyncPtyHandle {
let (write_tx, write_rx) = mpsc::channel::<Vec<u8>>(64);
let (read_tx, read_rx) = mpsc::channel::<Vec<u8>>(64);

let cursor_tracker = Arc::new(std::sync::Mutex::new(CursorTracker::new(initial_size)));

// Spawn reader thread
let reader_shutdown = shutdown.clone();
let reader_write_tx = write_tx.clone();
let reader_tracker = cursor_tracker.clone();
let reader_thread = std::thread::spawn(move || {
Self::reader_loop(reader, read_tx, reader_shutdown);
Self::reader_loop(
reader,
read_tx,
reader_write_tx,
reader_tracker,
reader_shutdown,
);
});

// Spawn writer thread
Expand All @@ -168,6 +257,7 @@ impl AsyncPtyHandle {
size: std::sync::Mutex::new(initial_size),
reader_thread: Some(reader_thread),
writer_thread: Some(writer_thread),
cursor_tracker,
})
}

Expand All @@ -185,6 +275,11 @@ impl AsyncPtyHandle {
.size
.lock()
.map_err(|_| anyhow::anyhow!("Size mutex poisoned"))? = size;

self.cursor_tracker
.lock()
.map_err(|_| anyhow::anyhow!("Cursor tracker mutex poisoned"))?
.resize(size);
Ok(())
}
/// Send bytes to the PTY stdin.
Expand Down Expand Up @@ -245,6 +340,8 @@ impl AsyncPtyHandle {
fn reader_loop(
mut reader: Box<dyn Read + Send>,
read_tx: mpsc::Sender<Vec<u8>>,
write_tx: mpsc::Sender<Vec<u8>>,
cursor_tracker: Arc<std::sync::Mutex<CursorTracker>>,
shutdown: Arc<std::sync::atomic::AtomicBool>,
) {
let mut buf = vec![0u8; READ_BUFFER_SIZE];
Expand All @@ -261,6 +358,32 @@ impl AsyncPtyHandle {
break;
}
Ok(n) => {
let dsr_positions = match cursor_tracker.lock() {
Ok(mut tracker) => tracker.process_output(&buf[..n]),
Err(_) => {
warn!("Cursor tracker mutex poisoned; skipping DSR detection");
Vec::new()
}
};

if !dsr_positions.is_empty() {
for (row, col) in dsr_positions {
let response = format!("\x1b[{};{}R", row, col).into_bytes();
match write_tx.try_send(response) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
warn!(
"PTY write channel full; dropping synthetic DSR response"
);
}
Err(mpsc::error::TrySendError::Closed(_)) => {
debug!("PTY write channel closed while sending DSR response");
break;
}
}
}
}

// Use blocking send since we're in a thread
if read_tx.blocking_send(buf[..n].to_vec()).is_err() {
debug!("PTY read channel closed");
Expand Down Expand Up @@ -347,9 +470,50 @@ impl Drop for AsyncPtyHandle {
#[cfg(test)]
mod tests {
use super::*;
use regex::Regex;
use std::io::Read;
use std::time::Duration;

#[test]
fn test_cursor_tracker_collects_cursor_positions_for_dsr_queries() {
let mut tracker = CursorTracker::new(TermSize { cols: 80, rows: 24 });

assert_eq!(tracker.process_output(b"hello"), Vec::<(u16, u16)>::new());
assert_eq!(tracker.process_output(b"\x1b[6n"), vec![(1, 6)]);
assert_eq!(
tracker.process_output(b"x\x1b[6n\x1b[6n"),
vec![(1, 7), (1, 7)]
);
}

#[test]
fn test_cursor_tracker_detects_dsr_across_chunk_boundary() {
let mut tracker = CursorTracker::new(TermSize { cols: 80, rows: 24 });

assert_eq!(tracker.process_output(b"\x1b["), Vec::<(u16, u16)>::new());
assert_eq!(tracker.process_output(b"6"), Vec::<(u16, u16)>::new());
assert_eq!(tracker.process_output(b"n"), vec![(1, 1)]);
}

#[test]
fn test_cursor_tracker_captures_position_at_dsr_boundary_with_trailing_output() {
let mut tracker = CursorTracker::new(TermSize { cols: 80, rows: 24 });

let positions = tracker.process_output(b"abc\x1b[6nxyz");

assert_eq!(positions, vec![(1, 4)]);
assert_eq!(tracker.cursor_position_one_indexed(), (1, 7));
}

#[test]
fn test_cursor_tracker_reports_one_indexed_position() {
let mut tracker = CursorTracker::new(TermSize { cols: 80, rows: 24 });
tracker.process_output(b"abc");

// Cursor is after 3 chars on first row => 1-indexed (1, 4)
assert_eq!(tracker.cursor_position_one_indexed(), (1, 4));
}

#[test]
fn test_spawn_echo_and_read_output() {
let session = PtySession::spawn(
Expand Down Expand Up @@ -509,6 +673,52 @@ mod tests {
.expect("resize to smaller should succeed");
}

#[tokio::test]
async fn test_async_pty_responds_to_dsr_cursor_query() {
let session = PtySession::spawn(
&[
"bash".to_string(),
"-lc".to_string(),
"printf 'abc\\033[6nxyz'; IFS=';' read -r -d R row col; printf 'ok:%s;%s\\n' \"$row\" \"$col\""
.to_string(),
],
TermSize::default(),
None,
)
.expect("Failed to spawn bash DSR test");

let handle = AsyncPtyHandle::new(session).expect("Failed to create async handle");

let result = tokio::time::timeout(Duration::from_secs(2), async {
let mut collected = Vec::new();
while let Some(chunk) = handle.read().await {
collected.extend_from_slice(&chunk);
if String::from_utf8_lossy(&collected).contains("ok:") {
break;
}
}
collected
})
.await;

let bytes = result.expect("Timed out waiting for DSR response");
let output = String::from_utf8_lossy(&bytes);
let captures = Regex::new(r"ok:\x1b\[(\d+);(\d+)")
.expect("valid regex")
.captures(&output)
.expect("Expected shell to receive DSR response with cursor coordinates");
let row: u16 = captures[1].parse().expect("row should parse");
let col: u16 = captures[2].parse().expect("col should parse");

assert_eq!(
(row, col),
(1, 4),
"Unexpected DSR response in output: {output:?}"
);

Comment on lines +705 to +718
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strengthened the async E2E test to validate the actual cursor report. The test now writes a known sequence (abc + ESC[6n + xyz), parses ok:\x1b[{row};{col}, and asserts (row, col) == (1, 4).

handle.shutdown().await;
}

#[test]
fn test_spawn_with_cwd() {
// Spawn pwd in /tmp and verify it outputs a path containing "tmp"
Expand Down