diff --git a/README.md b/README.md index e52b319..c7b9185 100644 --- a/README.md +++ b/README.md @@ -219,7 +219,7 @@ All public APIs are covered by runnable examples under `./examples`, and the tes | **Process** | [GracefulShutdown](#gracefulshutdown) [Interrupt](#interrupt) [KillAfter](#killafter) [Send](#send) [Terminate](#terminate) [Wait](#wait) | | **Results** | [IsExitCode](#isexitcode) [IsSignal](#issignal) [OK](#ok) | | **Shadow Print** | [ShadowOff](#shadowoff) [ShadowOn](#shadowon) [ShadowPrint](#shadowprint) [WithFormatter](#withformatter) [WithMask](#withmask) [WithPrefix](#withprefix) | -| **Streaming** | [OnStderr](#onstderr) [OnStdout](#onstdout) [StderrWriter](#stderrwriter) [StdoutWriter](#stdoutwriter) | +| **Streaming** | [OnStderr](#onstderr) [OnStdout](#onstdout) [StderrWriter](#stderrwriter) [StdoutWriter](#stdoutwriter) [WithPTY](#withpty) | | **WorkingDir** | [Dir](#dir) | @@ -1061,6 +1061,21 @@ fmt.Print(out.String()) // hello ``` +### WithPTY + +WithPTY attaches stdout/stderr to a pseudo-terminal. + +Output is combined; OnStdout and OnStderr receive the same lines, and Result.Stderr remains empty. +Platforms without PTY support return an error when the command runs. + +```go +_, _ = execx.Command("printf", "hi"). + WithPTY(). + OnStdout(func(line string) { fmt.Println(line) }). + Run() +// hi +``` + ## WorkingDir ### Dir diff --git a/examples/withpty/main.go b/examples/withpty/main.go new file mode 100644 index 0000000..690ec15 --- /dev/null +++ b/examples/withpty/main.go @@ -0,0 +1,22 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "fmt" + + "github.com/goforj/execx" +) + +func main() { + // WithPTY attaches stdout/stderr to a pseudo-terminal. + // Output is combined; OnStdout and OnStderr receive the same lines. + + // Example: with pty + _, _ = execx.Command("printf", "hi\n"). + WithPTY(). + OnStdout(func(line string) { fmt.Println(line) }). + Run() + // hi +} diff --git a/execx.go b/execx.go index 07695b5..e2a3642 100644 --- a/execx.go +++ b/execx.go @@ -74,6 +74,7 @@ type Cmd struct { sysProcAttr *syscall.SysProcAttr onExecCmd func(*exec.Cmd) + usePTY bool next *Cmd root *Cmd @@ -416,6 +417,25 @@ func (c *Cmd) StderrWriter(w io.Writer) *Cmd { return c } +// WithPTY attaches stdout/stderr to a pseudo-terminal. +// @group Streaming +// +// When enabled, stdout and stderr are merged into a single stream. OnStdout and +// OnStderr both receive the same lines, and Result.Stderr remains empty. +// Platforms without PTY support return an error when the command runs. +// +// Example: with pty +// +// _, _ = execx.Command("printf", "hi"). +// WithPTY(). +// OnStdout(func(line string) { fmt.Println(line) }). +// Run() +// // hi +func (c *Cmd) WithPTY() *Cmd { + c.rootCmd().usePTY = true + return c +} + // OnExecCmd registers a callback to mutate the underlying exec.Cmd before start. // @group Execution // @@ -688,6 +708,9 @@ func WithFormatter(fn func(ShadowEvent) string) ShadowOption { // fmt.Println(res.ExitCode == 0) // // #bool true func (c *Cmd) Run() (Result, error) { + if err := c.validatePTY(); err != nil { + return Result{Err: err, ExitCode: -1}, err + } shadow := c.shadowPrintStart(false) pipe := c.newPipeline(false, shadow) pipe.start() @@ -749,6 +772,9 @@ func (c *Cmd) OutputTrimmed() (string, error) { // // Run 'go help env' for details. // // false func (c *Cmd) CombinedOutput() (string, error) { + if err := c.validatePTY(); err != nil { + return "", err + } shadow := c.shadowPrintStart(false) pipe := c.newPipeline(true, shadow) pipe.start() @@ -772,6 +798,9 @@ func (c *Cmd) CombinedOutput() (string, error) { // // {Stdout:GO Stderr: ExitCode:0 Err: Duration:4.976291ms signal:} // // ] func (c *Cmd) PipelineResults() ([]Result, error) { + if err := c.validatePTY(); err != nil { + return nil, err + } shadow := c.shadowPrintStart(false) pipe := c.newPipeline(false, shadow) pipe.start() @@ -791,6 +820,11 @@ func (c *Cmd) PipelineResults() ([]Result, error) { // fmt.Println(res.ExitCode == 0) // // #bool true func (c *Cmd) Start() *Process { + if err := c.validatePTY(); err != nil { + proc := &Process{done: make(chan struct{})} + proc.finish(Result{Err: err, ExitCode: -1}) + return proc + } shadow := c.shadowPrintStart(true) pipe := c.newPipeline(false, shadow) pipe.start() @@ -839,6 +873,8 @@ func (c *Cmd) execCmd() *exec.Cmd { } var isTerminalFunc = term.IsTerminal +var openPTYFunc = openPTY +var ptyCheckFunc = ptyCheck func isTerminalWriter(w io.Writer) bool { f, ok := w.(*os.File) @@ -848,6 +884,20 @@ func isTerminalWriter(w io.Writer) bool { return isTerminalFunc(int(f.Fd())) } +func (c *Cmd) validatePTY() error { + root := c.rootCmd() + if !root.usePTY { + return nil + } + if root.next != nil { + return errors.New("execx: WithPTY is not supported with pipelines") + } + if err := ptyCheckFunc(); err != nil { + return err + } + return nil +} + func (c *Cmd) stdoutWriter(buf *bytes.Buffer, withCombined bool, combined *bytes.Buffer, shadow *shadowContext) io.Writer { if c.stdoutW != nil && c.onStdout == nil && !withCombined { if isTerminalWriter(c.stdoutW) { @@ -896,6 +946,28 @@ func (c *Cmd) stderrWriter(buf *bytes.Buffer, withCombined bool, combined *bytes return wrapShadowWriter(out, shadow) } +func (c *Cmd) ptyWriter(buf *bytes.Buffer, withCombined bool, combined *bytes.Buffer, shadow *shadowContext) io.Writer { + writers := []io.Writer{} + if c.stdoutW != nil { + writers = append(writers, c.stdoutW) + } + if c.stderrW != nil && c.stderrW != c.stdoutW { + writers = append(writers, c.stderrW) + } + writers = append(writers, buf) + if withCombined { + writers = append(writers, combined) + } + if c.onStdout != nil || c.onStderr != nil { + writers = append(writers, &ptyLineWriter{onStdout: c.onStdout, onStderr: c.onStderr}) + } + var out io.Writer = buf + if len(writers) > 1 { + out = io.MultiWriter(writers...) + } + return wrapShadowWriter(out, shadow) +} + type lineWriter struct { onLine func(string) buf bytes.Buffer @@ -919,6 +991,35 @@ func (l *lineWriter) Write(p []byte) (int, error) { return len(p), nil } +type ptyLineWriter struct { + onStdout func(string) + onStderr func(string) + buf bytes.Buffer +} + +// Write buffers output and emits completed lines to stdout/stderr callbacks. +func (l *ptyLineWriter) Write(p []byte) (int, error) { + if l.onStdout == nil && l.onStderr == nil { + return len(p), nil + } + for _, b := range p { + if b == '\n' { + line := l.buf.String() + l.buf.Reset() + line = strings.TrimSuffix(line, "\r") + if l.onStdout != nil { + l.onStdout(line) + } + if l.onStderr != nil { + l.onStderr(line) + } + continue + } + _ = l.buf.WriteByte(b) + } + return len(p), nil +} + func buildEnv(mode envMode, env map[string]string) []string { merged := map[string]string{} if mode != envOnly { diff --git a/execx_test.go b/execx_test.go index 2747631..45a5b73 100644 --- a/execx_test.go +++ b/execx_test.go @@ -515,6 +515,286 @@ func TestLineWriterNil(t *testing.T) { } } +func TestWithPTYPipelineUnsupported(t *testing.T) { + prevCheck := ptyCheckFunc + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + ptyCheckFunc = prevCheck + }) + _, err := Command("printf", "hi"). + WithPTY(). + Pipe("tr", "a-z", "A-Z"). + Run() + if err == nil || !strings.Contains(err.Error(), "WithPTY is not supported") { + t.Fatalf("expected WithPTY pipeline error, got %v", err) + } +} + +func TestWithPTYOpenError(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + return nil, nil, errors.New("pty open failed") + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + _, err := Command("printf", "hi").WithPTY().Run() + if err == nil || !strings.Contains(err.Error(), "pty open failed") { + t.Fatalf("expected openpty error, got %v", err) + } +} + +func TestWithPTYCombinedStream(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + stdoutLines := []string{} + stderrLines := []string{} + res, err := Command("printf", "a\nb\n"). + WithPTY(). + OnStdout(func(line string) { stdoutLines = append(stdoutLines, line) }). + OnStderr(func(line string) { stderrLines = append(stderrLines, line) }). + Run() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if res.Stdout != "a\nb\n" { + t.Fatalf("expected stdout to contain output, got %q", res.Stdout) + } + if res.Stderr != "" { + t.Fatalf("expected stderr to be empty, got %q", res.Stderr) + } + if strings.Join(stdoutLines, ",") != "a,b" { + t.Fatalf("unexpected stdout lines: %v", stdoutLines) + } + if strings.Join(stderrLines, ",") != "a,b" { + t.Fatalf("unexpected stderr lines: %v", stderrLines) + } +} + +func TestWithPTYCombinedOutput(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + out, err := Command("printf", "hi").WithPTY().CombinedOutput() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if out != "hi" { + t.Fatalf("expected combined output, got %q", out) + } +} + +func TestWithPTYPipelineResults(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + results, err := Command("printf", "ok").WithPTY().PipelineResults() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(results) != 1 || results[0].Stdout != "ok" { + t.Fatalf("unexpected pipeline results: %+v", results) + } +} + +func TestWithPTYStart(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + proc := Command("printf", "hi").WithPTY().Start() + res, err := proc.Wait() + if err != nil || res.Stdout != "hi" { + t.Fatalf("expected stdout from Start, got %q err=%v", res.Stdout, err) + } +} + +func TestWithPTYCheckError(t *testing.T) { + prevCheck := ptyCheckFunc + ptyCheckFunc = func() error { return errors.New("pty unsupported") } + t.Cleanup(func() { + ptyCheckFunc = prevCheck + }) + _, err := Command("printf", "hi").WithPTY().CombinedOutput() + if err == nil || !strings.Contains(err.Error(), "pty unsupported") { + t.Fatalf("expected pty check error, got %v", err) + } +} + +func TestWithPTYStartCheckError(t *testing.T) { + prevCheck := ptyCheckFunc + ptyCheckFunc = func() error { return errors.New("pty unsupported") } + t.Cleanup(func() { + ptyCheckFunc = prevCheck + }) + proc := Command("printf", "hi").WithPTY().Start() + res, err := proc.Wait() + if err == nil || !strings.Contains(err.Error(), "pty unsupported") { + t.Fatalf("expected pty check error, got %v", err) + } + if res.ExitCode != -1 { + t.Fatalf("expected exit code -1, got %d", res.ExitCode) + } +} + +func TestWithPTYPipelineResultsCheckError(t *testing.T) { + prevCheck := ptyCheckFunc + ptyCheckFunc = func() error { return errors.New("pty unsupported") } + t.Cleanup(func() { + ptyCheckFunc = prevCheck + }) + _, err := Command("printf", "hi").WithPTY().PipelineResults() + if err == nil || !strings.Contains(err.Error(), "pty unsupported") { + t.Fatalf("expected pty check error, got %v", err) + } +} + +type errWriter struct { + called bool +} + +func (w *errWriter) Write(p []byte) (int, error) { + w.called = true + return 0, errors.New("write failed") +} + +func TestWithPTYWriterError(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + writer := &errWriter{} + res, err := Command("printf", "hi").WithPTY().StdoutWriter(writer).Run() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !writer.called { + t.Fatalf("expected writer to be called") + } + if res.Stdout != "" { + t.Fatalf("expected empty stdout, got %q", res.Stdout) + } +} + +func TestWithPTYStartError(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + _, err := Command("execx-does-not-exist").WithPTY().Run() + if err == nil { + t.Fatalf("expected start error") + } +} + +func TestWithPTYWritersNoCallbacks(t *testing.T) { + prevOpen := openPTYFunc + prevCheck := ptyCheckFunc + openPTYFunc = func() (*os.File, *os.File, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, nil, err + } + return r, w, nil + } + ptyCheckFunc = func() error { return nil } + t.Cleanup(func() { + openPTYFunc = prevOpen + ptyCheckFunc = prevCheck + }) + var stdoutBuf bytes.Buffer + var stderrBuf bytes.Buffer + res, err := Command("printf", "hi"). + WithPTY(). + StdoutWriter(&stdoutBuf). + StderrWriter(&stderrBuf). + Run() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if res.Stdout != "hi" { + t.Fatalf("expected stdout buffer to capture output, got %q", res.Stdout) + } + if stdoutBuf.String() != "hi" || stderrBuf.String() != "hi" { + t.Fatalf("unexpected writers: stdout=%q stderr=%q", stdoutBuf.String(), stderrBuf.String()) + } +} + +func TestPTYLineWriterNil(t *testing.T) { + writer := &ptyLineWriter{} + n, err := writer.Write([]byte("data")) + if err != nil || n != 4 { + t.Fatalf("unexpected write result n=%d err=%v", n, err) + } +} + func TestOnExecCmdApplied(t *testing.T) { called := false cmd := Command("printf", "hi").OnExecCmd(func(ec *exec.Cmd) { @@ -631,6 +911,21 @@ func TestPipelineResultsError(t *testing.T) { } } +func TestPipelineStartErrorPropagation(t *testing.T) { + results, err := Command("execx-does-not-exist"). + Pipe("printf", "ok"). + PipelineResults() + if err == nil { + t.Fatalf("expected error") + } + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } + if results[1].Err == nil { + t.Fatalf("expected downstream start error") + } +} + func TestProcessSignals(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("signals not supported on windows") diff --git a/pipeline.go b/pipeline.go index 8f101de..0986023 100644 --- a/pipeline.go +++ b/pipeline.go @@ -5,6 +5,7 @@ import ( "context" "errors" "io" + "os" "os/exec" "time" ) @@ -16,9 +17,14 @@ type stage struct { stderrBuf bytes.Buffer combinedBuf bytes.Buffer startErr error + setupErr error waitErr error startTime time.Time pipeWriter *io.PipeWriter + ptyMaster *os.File + ptySlave *os.File + ptyWriter io.Writer + ptyDone chan error } type pipeline struct { @@ -31,10 +37,23 @@ func (c *Cmd) newPipeline(withCombined bool, shadow *shadowContext) *pipeline { for _, stage := range stages { stage.startTime = time.Now() stage.cmd = stage.def.execCmd() - stdoutWriter := stage.def.stdoutWriter(&stage.stdoutBuf, withCombined, &stage.combinedBuf, shadow) - stderrWriter := stage.def.stderrWriter(&stage.stderrBuf, withCombined, &stage.combinedBuf, shadow) - stage.cmd.Stdout = stdoutWriter - stage.cmd.Stderr = stderrWriter + if stage.def.rootCmd().usePTY { + master, slave, err := openPTYFunc() + if err != nil { + stage.setupErr = err + continue + } + stage.ptyMaster = master + stage.ptySlave = slave + stage.ptyWriter = stage.def.ptyWriter(&stage.stdoutBuf, withCombined, &stage.combinedBuf, shadow) + stage.cmd.Stdout = slave + stage.cmd.Stderr = slave + } else { + stdoutWriter := stage.def.stdoutWriter(&stage.stdoutBuf, withCombined, &stage.combinedBuf, shadow) + stderrWriter := stage.def.stderrWriter(&stage.stderrBuf, withCombined, &stage.combinedBuf, shadow) + stage.cmd.Stdout = stdoutWriter + stage.cmd.Stderr = stderrWriter + } } for i := range stages { @@ -52,14 +71,37 @@ func (c *Cmd) newPipeline(withCombined bool, shadow *shadowContext) *pipeline { } func (p *pipeline) start() { - for i, stage := range p.stages { - stage.startErr = stage.cmd.Start() - if stage.startErr != nil { + for i, stg := range p.stages { + if stg.setupErr != nil { + stg.startErr = stg.setupErr + break + } + stg.startErr = stg.cmd.Start() + if stg.startErr != nil { + if stg.ptyMaster != nil { + _ = stg.ptyMaster.Close() + } + if stg.ptySlave != nil { + _ = stg.ptySlave.Close() + } for j := i + 1; j < len(p.stages); j++ { - p.stages[j].startErr = stage.startErr + p.stages[j].startErr = stg.startErr } break } + if stg.ptyMaster != nil { + stg.ptyDone = make(chan error, 1) + go func(st *stage) { + _, err := io.Copy(st.ptyWriter, st.ptyMaster) + if err != nil { + st.ptyDone <- err + } else { + st.ptyDone <- nil + } + _ = st.ptyMaster.Close() + }(stg) + _ = stg.ptySlave.Close() + } } } @@ -75,6 +117,11 @@ func (p *pipeline) wait() { if p.stages[i].pipeWriter != nil { _ = p.stages[i].pipeWriter.Close() } + if p.stages[i].ptyDone != nil { + if err := <-p.stages[i].ptyDone; err != nil && p.stages[i].waitErr == nil { + p.stages[i].waitErr = err + } + } } } diff --git a/pty_darwin.go b/pty_darwin.go new file mode 100644 index 0000000..2376ad8 --- /dev/null +++ b/pty_darwin.go @@ -0,0 +1,53 @@ +//go:build darwin + +package execx + +import ( + "bytes" + "os" + "syscall" + "unsafe" +) + +func ptyCheck() error { + return nil +} + +func openPTY() (*os.File, *os.File, error) { + return openPTYWith(os.OpenFile, ptyIoctl) +} + +func openPTYWith(openFile func(string, int, os.FileMode) (*os.File, error), ioctl func(uintptr, uintptr, uintptr) error) (*os.File, *os.File, error) { + master, err := openFile("/dev/ptmx", os.O_RDWR|syscall.O_NOCTTY, 0) + if err != nil { + return nil, nil, err + } + if err := ioctl(master.Fd(), syscall.TIOCPTYGRANT, 0); err != nil { + _ = master.Close() + return nil, nil, err + } + if err := ioctl(master.Fd(), syscall.TIOCPTYUNLK, 0); err != nil { + _ = master.Close() + return nil, nil, err + } + var nameBuf [128]byte + if err := ioctl(master.Fd(), syscall.TIOCPTYGNAME, uintptr(unsafe.Pointer(&nameBuf[0]))); err != nil { + _ = master.Close() + return nil, nil, err + } + name := string(bytes.TrimRight(nameBuf[:], "\x00")) + slave, err := openFile(name, os.O_RDWR|syscall.O_NOCTTY, 0) + if err != nil { + _ = master.Close() + return nil, nil, err + } + return master, slave, nil +} + +func ptyIoctl(fd uintptr, req uintptr, arg uintptr) error { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, req, arg) + if errno != 0 { + return errno + } + return nil +} diff --git a/pty_darwin_test.go b/pty_darwin_test.go new file mode 100644 index 0000000..f3aadb6 --- /dev/null +++ b/pty_darwin_test.go @@ -0,0 +1,133 @@ +//go:build darwin + +package execx + +import ( + "errors" + "os" + "syscall" + "testing" + "unsafe" +) + +func TestPTYDarwinOpen(t *testing.T) { + if err := ptyCheck(); err != nil { + t.Fatalf("unexpected pty check error: %v", err) + } + master, slave, err := openPTY() + if err != nil { + t.Fatalf("expected openPTY to succeed, got %v", err) + } + _ = master.Close() + _ = slave.Close() +} + +func TestPTYIoctlSuccessAndError(t *testing.T) { + master, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0) + if err != nil { + t.Fatalf("open ptmx: %v", err) + } + defer master.Close() + if err := ptyIoctl(master.Fd(), syscall.TIOCPTYGRANT, 0); err != nil { + t.Fatalf("expected ioctl success, got %v", err) + } + if err := ptyIoctl(0, 0, 0); err == nil { + t.Fatalf("expected ioctl error") + } +} + +func TestOpenPTYWithOpenError(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("open failed") + } + _, _, err := openPTYWith(openFile, func(uintptr, uintptr, uintptr) error { return nil }) + if err == nil || err.Error() != "open failed" { + t.Fatalf("expected open error, got %v", err) + } +} + +func TestOpenPTYWithGrantError(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + _, _, err := openPTYWith(openFile, func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCPTYGRANT { + return errors.New("grant failed") + } + return nil + }) + if err == nil || err.Error() != "grant failed" { + t.Fatalf("expected grant error, got %v", err) + } +} + +func TestOpenPTYWithUnlockError(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCPTYUNLK { + return errors.New("unlock failed") + } + return nil + } + _, _, err := openPTYWith(openFile, ioctl) + if err == nil || err.Error() != "unlock failed" { + t.Fatalf("expected unlock error, got %v", err) + } +} + +func TestOpenPTYWithNameError(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCPTYGNAME { + return errors.New("name failed") + } + return nil + } + _, _, err := openPTYWith(openFile, ioctl) + if err == nil || err.Error() != "name failed" { + t.Fatalf("expected name error, got %v", err) + } +} + +func TestOpenPTYWithSlaveError(t *testing.T) { + openFile := func(name string, flag int, perm os.FileMode) (*os.File, error) { + if name == "/dev/ptmx" { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + return nil, errors.New("slave open failed") + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCPTYGNAME { + buf := (*[128]byte)(unsafe.Pointer(arg)) + copy(buf[:], []byte("/dev/doesnotexist")) + } + return nil + } + _, _, err := openPTYWith(openFile, ioctl) + if err == nil || err.Error() != "slave open failed" { + t.Fatalf("expected slave open error, got %v", err) + } +} + +func TestOpenPTYWithSuccess(t *testing.T) { + openFile := func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCPTYGNAME { + buf := (*[128]byte)(unsafe.Pointer(arg)) + copy(buf[:], []byte(os.DevNull)) + } + return nil + } + master, slave, err := openPTYWith(openFile, ioctl) + if err != nil { + t.Fatalf("expected success, got %v", err) + } + _ = master.Close() + _ = slave.Close() +} diff --git a/pty_linux.go b/pty_linux.go new file mode 100644 index 0000000..e774bf0 --- /dev/null +++ b/pty_linux.go @@ -0,0 +1,51 @@ +//go:build linux + +package execx + +import ( + "fmt" + "os" + "syscall" + "unsafe" +) + +func ptyCheck() error { + return nil +} + +func openPTY() (*os.File, *os.File, error) { + return openPTYWith(os.OpenFile, ptyIoctl) +} + +func openPTYWith(openFile func(string, int, os.FileMode) (*os.File, error), ioctl func(uintptr, uintptr, uintptr) error) (*os.File, *os.File, error) { + master, err := openFile("/dev/ptmx", os.O_RDWR|syscall.O_NOCTTY, 0) + if err != nil { + return nil, nil, err + } + fd := master.Fd() + unlock := int32(0) + if err := ioctl(fd, syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&unlock))); err != nil { + _ = master.Close() + return nil, nil, err + } + var ptyNum uint32 + if err := ioctl(fd, syscall.TIOCGPTN, uintptr(unsafe.Pointer(&ptyNum))); err != nil { + _ = master.Close() + return nil, nil, err + } + name := fmt.Sprintf("/dev/pts/%d", ptyNum) + slave, err := openFile(name, os.O_RDWR|syscall.O_NOCTTY, 0) + if err != nil { + _ = master.Close() + return nil, nil, err + } + return master, slave, nil +} + +func ptyIoctl(fd uintptr, req uintptr, arg uintptr) error { + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, req, arg) + if errno != 0 { + return errno + } + return nil +} diff --git a/pty_linux_test.go b/pty_linux_test.go new file mode 100644 index 0000000..6d5fdce --- /dev/null +++ b/pty_linux_test.go @@ -0,0 +1,118 @@ +//go:build linux + +package execx + +import ( + "errors" + "os" + "syscall" + "testing" + "unsafe" +) + +func TestPTYLinuxOpen(t *testing.T) { + if err := ptyCheck(); err != nil { + t.Fatalf("unexpected pty check error: %v", err) + } + master, slave, err := openPTY() + if err != nil { + t.Fatalf("expected openPTY to succeed, got %v", err) + } + _ = master.Close() + _ = slave.Close() +} + +func TestPTYIoctlSuccessAndErrorLinux(t *testing.T) { + master, err := os.OpenFile("/dev/ptmx", os.O_RDWR|syscall.O_NOCTTY, 0) + if err != nil { + t.Fatalf("open ptmx: %v", err) + } + defer master.Close() + unlock := int32(0) + if err := ptyIoctl(master.Fd(), syscall.TIOCSPTLCK, uintptr(unsafe.Pointer(&unlock))); err != nil { + t.Fatalf("expected ioctl success, got %v", err) + } + if err := ptyIoctl(0, 0, 0); err == nil { + t.Fatalf("expected ioctl error") + } +} + +func TestOpenPTYWithOpenErrorLinux(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return nil, errors.New("open failed") + } + _, _, err := openPTYWith(openFile, func(uintptr, uintptr, uintptr) error { return nil }) + if err == nil || err.Error() != "open failed" { + t.Fatalf("expected open error, got %v", err) + } +} + +func TestOpenPTYWithUnlockErrorLinux(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + _, _, err := openPTYWith(openFile, func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCSPTLCK { + return errors.New("unlock failed") + } + return nil + }) + if err == nil || err.Error() != "unlock failed" { + t.Fatalf("expected unlock error, got %v", err) + } +} + +func TestOpenPTYWithPTNErrorLinux(t *testing.T) { + openFile := func(string, int, os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + _, _, err := openPTYWith(openFile, func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCGPTN { + return errors.New("ptn failed") + } + return nil + }) + if err == nil || err.Error() != "ptn failed" { + t.Fatalf("expected ptn error, got %v", err) + } +} + +func TestOpenPTYWithSlaveErrorLinux(t *testing.T) { + openFile := func(name string, flag int, perm os.FileMode) (*os.File, error) { + if name == "/dev/ptmx" { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + return nil, errors.New("slave open failed") + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCGPTN { + *(*uint32)(unsafe.Pointer(arg)) = 1234 + } + return nil + } + _, _, err := openPTYWith(openFile, ioctl) + if err == nil || err.Error() != "slave open failed" { + t.Fatalf("expected slave open error, got %v", err) + } +} + +func TestOpenPTYWithSuccessLinux(t *testing.T) { + openFile := func(name string, flag int, perm os.FileMode) (*os.File, error) { + return os.OpenFile(os.DevNull, os.O_RDWR, 0) + } + ioctl := func(fd uintptr, req uintptr, arg uintptr) error { + if req == syscall.TIOCGPTN { + *(*uint32)(unsafe.Pointer(arg)) = 0 + } + return nil + } + master, slave, err := openPTYWith(openFile, ioctl) + if err != nil { + t.Fatalf("expected success, got %v", err) + } + _ = master.Close() + _ = slave.Close() + if master.Name() != os.DevNull || slave.Name() != os.DevNull { + t.Fatalf("expected dev null files, got %q %q", master.Name(), slave.Name()) + } +} diff --git a/pty_unsupported.go b/pty_unsupported.go new file mode 100644 index 0000000..7f66181 --- /dev/null +++ b/pty_unsupported.go @@ -0,0 +1,16 @@ +//go:build !linux && !darwin + +package execx + +import ( + "errors" + "os" +) + +func ptyCheck() error { + return errors.New("execx: WithPTY is not supported on this platform") +} + +func openPTY() (*os.File, *os.File, error) { + return nil, nil, ptyCheck() +}