From 82fbd4e21ea412a93e957397fbe644714ff83eb6 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 8 Jan 2024 19:04:39 -0500 Subject: [PATCH] feat: separate conpty into its own package --- cmd_windows.go | 343 +-------------------------------------- conpty/conpty_windows.go | 260 +++++++++++++++++++++++++++++ conpty/doc.go | 5 + conpty/exec_windows.go | 214 ++++++++++++++++++++++++ examples/go.mod | 4 +- examples/go.sum | 8 +- examples/shell/main.go | 7 +- pty_windows.go | 103 +----------- 8 files changed, 509 insertions(+), 435 deletions(-) create mode 100644 conpty/conpty_windows.go create mode 100644 conpty/doc.go create mode 100644 conpty/exec_windows.go diff --git a/cmd_windows.go b/cmd_windows.go index 529e6ed..0b3253e 100644 --- a/cmd_windows.go +++ b/cmd_windows.go @@ -7,18 +7,12 @@ import ( "errors" "fmt" "os" - "os/exec" - "path/filepath" - "strings" "syscall" - "unicode/utf16" - "unsafe" "golang.org/x/sys/windows" ) type conPtySys struct { - attrs *windows.ProcThreadAttributeListContainer done chan error cmdErr error } @@ -29,142 +23,26 @@ func (c *Cmd) start() error { return ErrInvalidCommand } - if c.SysProcAttr == nil { - c.SysProcAttr = &syscall.SysProcAttr{} - } - - argv0, err := lookExtensions(c.Path, c.Dir) - if err != nil { - return err - } - if len(c.Dir) != 0 { - // Windows CreateProcess looks for argv0 relative to the current - // directory, and, only once the new process is started, it does - // Chdir(attr.Dir). We are adjusting for that difference here by - // making argv0 absolute. - var err error - argv0, err = joinExeDirAndFName(c.Dir, c.Path) - if err != nil { - return err - } - } - - argv0p, err := windows.UTF16PtrFromString(argv0) - if err != nil { - return err - } - - var cmdline string - if c.SysProcAttr.CmdLine != "" { - cmdline = c.SysProcAttr.CmdLine - } else { - cmdline = windows.ComposeCommandLine(c.Args) - } - argvp, err := windows.UTF16PtrFromString(cmdline) - if err != nil { - return err - } - - var dirp *uint16 - if len(c.Dir) != 0 { - dirp, err = windows.UTF16PtrFromString(c.Dir) - if err != nil { - return err - } - } - - if c.Env == nil { - c.Env, err = execEnvDefault(c.SysProcAttr) - if err != nil { - return err - } - } - - siEx := new(windows.StartupInfoEx) - siEx.Flags = windows.STARTF_USESTDHANDLES - pi := new(windows.ProcessInformation) - - // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. - flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | c.SysProcAttr.CreationFlags - - // Allocate an attribute list that's large enough to do the operations we care about - // 2. Pseudo console setup if one was requested. - // Therefore we need a list of size 1. - attrs, err := windows.NewProcThreadAttributeList(1) - if err != nil { - return fmt.Errorf("failed to initialize process thread attribute list: %w", err) - } - c.sys = &conPtySys{ - attrs: attrs, - done: make(chan error, 1), - } - - if err := pty.updateProcThreadAttribute(attrs); err != nil { - return err - } - - var zeroSec windows.SecurityAttributes - pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - if c.SysProcAttr.ProcessAttributes != nil { - pSec = &windows.SecurityAttributes{ - Length: c.SysProcAttr.ProcessAttributes.Length, - InheritHandle: c.SysProcAttr.ProcessAttributes.InheritHandle, - } - } - tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} - if c.SysProcAttr.ThreadAttributes != nil { - tSec = &windows.SecurityAttributes{ - Length: c.SysProcAttr.ThreadAttributes.Length, - InheritHandle: c.SysProcAttr.ThreadAttributes.InheritHandle, - } + done: make(chan error, 1), } - siEx.ProcThreadAttributeList = attrs.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall - siEx.Cb = uint32(unsafe.Sizeof(*siEx)) - if c.SysProcAttr.Token != 0 { - err = windows.CreateProcessAsUser( - windows.Token(c.SysProcAttr.Token), - argv0p, - argvp, - pSec, - tSec, - false, - flags, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), - dirp, - &siEx.StartupInfo, - pi, - ) - } else { - err = windows.CreateProcess( - argv0p, - argvp, - pSec, - tSec, - false, - flags, - createEnvBlock(addCriticalEnv(dedupEnvCase(true, c.Env))), - dirp, - &siEx.StartupInfo, - pi, - ) - } + pid, proc, err := pty.Spawn(c.Path, c.Args, &syscall.ProcAttr{ + Dir: c.Dir, + Env: c.Env, + Sys: c.SysProcAttr, + }) if err != nil { - return fmt.Errorf("failed to create process: %w", err) + return err } - // Don't need the thread handle for anything. - defer func() { - _ = windows.CloseHandle(pi.Thread) - }() // Grab an *os.Process to avoid reinventing the wheel here. The stdlib has great logic around waiting, exit code status/cleanup after a // process has been launched. - c.Process, err = os.FindProcess(int(pi.ProcessId)) + c.Process, err = os.FindProcess(pid) if err != nil { // If we can't find the process via os.FindProcess, terminate the process as that's what we rely on for all further operations on the // object. - if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil { + if tErr := windows.TerminateProcess(windows.Handle(proc), 1); tErr != nil { return fmt.Errorf("failed to terminate process after process not found: %w", tErr) } return fmt.Errorf("failed to find process after starting: %w", err) @@ -199,7 +77,6 @@ func (c *Cmd) wait() (retErr error) { } defer func() { sys := c.sys.(*conPtySys) - sys.attrs.Delete() sys.done <- nil if retErr == nil { retErr = sys.cmdErr @@ -211,205 +88,3 @@ func (c *Cmd) wait() (retErr error) { } return } - -// -// Below are a bunch of helpers for working with Windows' CreateProcess family of functions. These are mostly exact copies of the same utilities -// found in the go stdlib. -// - -func lookExtensions(path, dir string) (string, error) { - if filepath.Base(path) == path { - path = filepath.Join(".", path) - } - - if dir == "" { - return exec.LookPath(path) - } - - if filepath.VolumeName(path) != "" { - return exec.LookPath(path) - } - - if len(path) > 1 && os.IsPathSeparator(path[0]) { - return exec.LookPath(path) - } - - dirandpath := filepath.Join(dir, path) - - // We assume that LookPath will only add file extension. - lp, err := exec.LookPath(dirandpath) - if err != nil { - return "", err - } - - ext := strings.TrimPrefix(lp, dirandpath) - - return path + ext, nil -} - -func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { - if sys == nil || sys.Token == 0 { - return syscall.Environ(), nil - } - - var block *uint16 - err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) - if err != nil { - return nil, err - } - - defer windows.DestroyEnvironmentBlock(block) - blockp := uintptr(unsafe.Pointer(block)) - - for { - // find NUL terminator - end := unsafe.Pointer(blockp) - for *(*uint16)(end) != 0 { - end = unsafe.Pointer(uintptr(end) + 2) - } - - n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 - if n == 0 { - // environment block ends with empty string - break - } - - entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] - env = append(env, string(utf16.Decode(entry))) - blockp += 2 * (uintptr(len(entry)) + 1) - } - return -} - -func isSlash(c uint8) bool { - return c == '\\' || c == '/' -} - -func normalizeDir(dir string) (name string, err error) { - ndir, err := syscall.FullPath(dir) - if err != nil { - return "", err - } - if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { - // dir cannot have \\server\share\path form - return "", syscall.EINVAL - } - return ndir, nil -} - -func volToUpper(ch int) int { - if 'a' <= ch && ch <= 'z' { - ch += 'A' - 'a' - } - return ch -} - -func joinExeDirAndFName(dir, p string) (name string, err error) { - if len(p) == 0 { - return "", syscall.EINVAL - } - if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { - // \\server\share\path form - return p, nil - } - if len(p) > 1 && p[1] == ':' { - // has drive letter - if len(p) == 2 { - return "", syscall.EINVAL - } - if isSlash(p[2]) { - return p, nil - } else { - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if volToUpper(int(p[0])) == volToUpper(int(d[0])) { - return syscall.FullPath(d + "\\" + p[2:]) - } else { - return syscall.FullPath(p) - } - } - } else { - // no drive letter - d, err := normalizeDir(dir) - if err != nil { - return "", err - } - if isSlash(p[0]) { - return windows.FullPath(d[:2] + p) - } else { - return windows.FullPath(d + "\\" + p) - } - } -} - -// createEnvBlock converts an array of environment strings into -// the representation required by CreateProcess: a sequence of NUL -// terminated strings followed by a nil. -// Last bytes are two UCS-2 NULs, or four NUL bytes. -func createEnvBlock(envv []string) *uint16 { - if len(envv) == 0 { - return &utf16.Encode([]rune("\x00\x00"))[0] - } - length := 0 - for _, s := range envv { - length += len(s) + 1 - } - length++ - - b := make([]byte, length) - i := 0 - for _, s := range envv { - l := len(s) - copy(b[i:i+l], []byte(s)) - copy(b[i+l:i+l+1], []byte{0}) - i = i + l + 1 - } - copy(b[i:i+1], []byte{0}) - - return &utf16.Encode([]rune(string(b)))[0] -} - -// dedupEnvCase is dedupEnv with a case option for testing. -// If caseInsensitive is true, the case of keys is ignored. -func dedupEnvCase(caseInsensitive bool, env []string) []string { - out := make([]string, 0, len(env)) - saw := make(map[string]int, len(env)) // key => index into out - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - out = append(out, kv) - continue - } - k := kv[:eq] - if caseInsensitive { - k = strings.ToLower(k) - } - if dupIdx, isDup := saw[k]; isDup { - out[dupIdx] = kv - continue - } - saw[k] = len(out) - out = append(out, kv) - } - return out -} - -// addCriticalEnv adds any critical environment variables that are required -// (or at least almost always required) on the operating system. -// Currently this is only used for Windows. -func addCriticalEnv(env []string) []string { - for _, kv := range env { - eq := strings.Index(kv, "=") - if eq < 0 { - continue - } - k := kv[:eq] - if strings.EqualFold(k, "SYSTEMROOT") { - // We already have it. - return env - } - } - return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) -} diff --git a/conpty/conpty_windows.go b/conpty/conpty_windows.go new file mode 100644 index 0000000..f659180 --- /dev/null +++ b/conpty/conpty_windows.go @@ -0,0 +1,260 @@ +package conpty + +import ( + "errors" + "fmt" + "io" + "os" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Default size. +const ( + DefaultWidth = 80 + DefaultHeight = 25 +) + +// ConPty represents a Windows Console Pseudo-terminal. +// https://learn.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#preparing-the-communication-channels +type ConPty struct { + hpc *windows.Handle + inPipeFd, outPipeFd windows.Handle + inPipe, outPipe *os.File + attrList *windows.ProcThreadAttributeListContainer + size windows.Coord +} + +var _ io.Writer = &ConPty{} +var _ io.Reader = &ConPty{} + +// New creates a new ConPty device. +// Accepts a custom width, height, and flags that will get passed to +// windows.CreatePseudoConsole. +func New(w int, h int, flags int) (c *ConPty, err error) { + if w <= 0 { + w = DefaultWidth + } + if h <= 0 { + h = DefaultHeight + } + + c = &ConPty{ + hpc: new(windows.Handle), + size: windows.Coord{ + X: int16(w), Y: int16(h), + }, + } + + var ptyIn, ptyOut windows.Handle + if err := windows.CreatePipe(&ptyIn, &c.inPipeFd, nil, 0); err != nil { + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) + } + + if err := windows.CreatePipe(&c.outPipeFd, &ptyOut, nil, 0); err != nil { + return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) + } + + if err := windows.CreatePseudoConsole(c.size, ptyIn, ptyOut, uint32(flags), c.hpc); err != nil { + return nil, fmt.Errorf("failed to create pseudo console: %w", err) + } + + // We don't need the pty pipes anymore, these will get dup'd when the + // new process starts. + if err := windows.CloseHandle(ptyOut); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) + } + if err := windows.CloseHandle(ptyIn); err != nil { + return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) + } + + c.inPipe = os.NewFile(uintptr(c.inPipeFd), "|0") + c.outPipe = os.NewFile(uintptr(c.outPipeFd), "|1") + + // Allocate an attribute list that's large enough to do the operations we care about + // 1. Pseudo console setup + c.attrList, err = windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + + if err := c.attrList.Update( + windows.PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, + unsafe.Pointer(*c.hpc), + unsafe.Sizeof(*c.hpc), + ); err != nil { + return nil, fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err) + } + + return +} + +// Handle returns the ConPty handle. +func (p *ConPty) Handle() windows.Handle { + return *p.hpc +} + +// Close closes the ConPty device. +func (p *ConPty) Close() error { + if p.attrList != nil { + p.attrList.Delete() + } + + windows.ClosePseudoConsole(*p.hpc) + return errors.Join(p.inPipe.Close(), p.outPipe.Close()) +} + +// InPipe returns the ConPty input pipe. +func (p *ConPty) InPipe() *os.File { + return p.inPipe +} + +// OutPipe returns the ConPty output pipe. +func (p *ConPty) OutPipe() *os.File { + return p.outPipe +} + +// Write safely writes bytes to the ConPty. +func (c *ConPty) Write(p []byte) (n int, err error) { + var l uint32 + err = windows.WriteFile(c.inPipeFd, p, &l, nil) + return int(l), err +} + +// Read safely reads bytes from the ConPty. +func (c *ConPty) Read(p []byte) (n int, err error) { + var l uint32 + err = windows.ReadFile(c.outPipeFd, p, &l, nil) + return int(l), err +} + +// Resize resizes the pseudo-console. +func (c *ConPty) Resize(w int, h int) error { + if err := windows.ResizePseudoConsole(*c.hpc, windows.Coord{X: int16(w), Y: int16(h)}); err != nil { + return fmt.Errorf("failed to resize pseudo console: %w", err) + } + return nil +} + +var zeroAttr syscall.ProcAttr + +// Spawn creates a new process attached to the pseudo-console. +func (c *ConPty) Spawn(name string, args []string, attr *syscall.ProcAttr) (pid int, handle uintptr, err error) { + if attr == nil { + attr = &zeroAttr + } + + argv0, err := lookExtensions(name, attr.Dir) + if err != nil { + return 0, 0, err + } + if len(attr.Dir) != 0 { + // Windows CreateProcess looks for argv0 relative to the current + // directory, and, only once the new process is started, it does + // Chdir(attr.Dir). We are adjusting for that difference here by + // making argv0 absolute. + var err error + argv0, err = joinExeDirAndFName(attr.Dir, argv0) + if err != nil { + return 0, 0, err + } + } + + argv0p, err := windows.UTF16PtrFromString(argv0) + if err != nil { + return 0, 0, err + } + + var cmdline string + if attr.Sys != nil && attr.Sys.CmdLine != "" { + cmdline = attr.Sys.CmdLine + } else { + cmdline = windows.ComposeCommandLine(args) + } + argvp, err := windows.UTF16PtrFromString(cmdline) + if err != nil { + return 0, 0, err + } + + var dirp *uint16 + if len(attr.Dir) != 0 { + dirp, err = windows.UTF16PtrFromString(attr.Dir) + if err != nil { + return 0, 0, err + } + } + + if attr.Env == nil { + attr.Env, err = execEnvDefault(attr.Sys) + if err != nil { + return 0, 0, err + } + } + + siEx := new(windows.StartupInfoEx) + siEx.Flags = windows.STARTF_USESTDHANDLES + + pi := new(windows.ProcessInformation) + + // Need EXTENDED_STARTUPINFO_PRESENT as we're making use of the attribute list field. + flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT + if attr.Sys != nil && attr.Sys.CreationFlags != 0 { + flags |= attr.Sys.CreationFlags + } + + var zeroSec windows.SecurityAttributes + pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if attr.Sys != nil && attr.Sys.ProcessAttributes != nil { + pSec = &windows.SecurityAttributes{ + Length: attr.Sys.ProcessAttributes.Length, + InheritHandle: attr.Sys.ProcessAttributes.InheritHandle, + } + } + tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1} + if attr.Sys != nil && attr.Sys.ThreadAttributes != nil { + tSec = &windows.SecurityAttributes{ + Length: attr.Sys.ThreadAttributes.Length, + InheritHandle: attr.Sys.ThreadAttributes.InheritHandle, + } + } + + siEx.ProcThreadAttributeList = c.attrList.List() //nolint:govet // unusedwrite: ProcThreadAttributeList will be read in syscall + siEx.Cb = uint32(unsafe.Sizeof(*siEx)) + if attr.Sys != nil && attr.Sys.Token != 0 { + err = windows.CreateProcessAsUser( + windows.Token(attr.Sys.Token), + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, attr.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } else { + err = windows.CreateProcess( + argv0p, + argvp, + pSec, + tSec, + false, + flags, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, attr.Env))), + dirp, + &siEx.StartupInfo, + pi, + ) + } + if err != nil { + return 0, 0, fmt.Errorf("failed to create process: %w", err) + } + + defer windows.CloseHandle(pi.Thread) + + return int(pi.ProcessId), uintptr(pi.Process), nil +} diff --git a/conpty/doc.go b/conpty/doc.go new file mode 100644 index 0000000..1f1778d --- /dev/null +++ b/conpty/doc.go @@ -0,0 +1,5 @@ +// Package conpty implements Windows Console Pseudo-terminal support. +// +// https://learn.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session + +package conpty diff --git a/conpty/exec_windows.go b/conpty/exec_windows.go new file mode 100644 index 0000000..8455e2d --- /dev/null +++ b/conpty/exec_windows.go @@ -0,0 +1,214 @@ +package conpty + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Below are a bunch of helpers for working with Windows' CreateProcess family +// of functions. These are mostly exact copies of the same utilities found in +// the go stdlib. + +func lookExtensions(path, dir string) (string, error) { + if filepath.Base(path) == path { + path = filepath.Join(".", path) + } + + if dir == "" { + return exec.LookPath(path) + } + + if filepath.VolumeName(path) != "" { + return exec.LookPath(path) + } + + if len(path) > 1 && os.IsPathSeparator(path[0]) { + return exec.LookPath(path) + } + + dirandpath := filepath.Join(dir, path) + + // We assume that LookPath will only add file extension. + lp, err := exec.LookPath(dirandpath) + if err != nil { + return "", err + } + + ext := strings.TrimPrefix(lp, dirandpath) + + return path + ext, nil +} + +func execEnvDefault(sys *syscall.SysProcAttr) (env []string, err error) { + if sys == nil || sys.Token == 0 { + return syscall.Environ(), nil + } + + var block *uint16 + err = windows.CreateEnvironmentBlock(&block, windows.Token(sys.Token), false) + if err != nil { + return nil, err + } + + defer windows.DestroyEnvironmentBlock(block) + blockp := uintptr(unsafe.Pointer(block)) + + for { + // find NUL terminator + end := unsafe.Pointer(blockp) + for *(*uint16)(end) != 0 { + end = unsafe.Pointer(uintptr(end) + 2) + } + + n := (uintptr(end) - uintptr(unsafe.Pointer(blockp))) / 2 + if n == 0 { + // environment block ends with empty string + break + } + + entry := (*[(1 << 30) - 1]uint16)(unsafe.Pointer(blockp))[:n:n] + env = append(env, string(utf16.Decode(entry))) + blockp += 2 * (uintptr(len(entry)) + 1) + } + return +} + +func isSlash(c uint8) bool { + return c == '\\' || c == '/' +} + +func normalizeDir(dir string) (name string, err error) { + ndir, err := syscall.FullPath(dir) + if err != nil { + return "", err + } + if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) { + // dir cannot have \\server\share\path form + return "", syscall.EINVAL + } + return ndir, nil +} + +func volToUpper(ch int) int { + if 'a' <= ch && ch <= 'z' { + ch += 'A' - 'a' + } + return ch +} + +func joinExeDirAndFName(dir, p string) (name string, err error) { + if len(p) == 0 { + return "", syscall.EINVAL + } + if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) { + // \\server\share\path form + return p, nil + } + if len(p) > 1 && p[1] == ':' { + // has drive letter + if len(p) == 2 { + return "", syscall.EINVAL + } + if isSlash(p[2]) { + return p, nil + } else { + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if volToUpper(int(p[0])) == volToUpper(int(d[0])) { + return syscall.FullPath(d + "\\" + p[2:]) + } else { + return syscall.FullPath(p) + } + } + } else { + // no drive letter + d, err := normalizeDir(dir) + if err != nil { + return "", err + } + if isSlash(p[0]) { + return windows.FullPath(d[:2] + p) + } else { + return windows.FullPath(d + "\\" + p) + } + } +} + +// createEnvBlock converts an array of environment strings into +// the representation required by CreateProcess: a sequence of NUL +// terminated strings followed by a nil. +// Last bytes are two UCS-2 NULs, or four NUL bytes. +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length++ + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/examples/go.mod b/examples/go.mod index c4c1a3e..3518c5b 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -12,8 +12,8 @@ require ( require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect - github.com/creack/pty v1.1.18 // indirect + github.com/creack/pty v1.1.21 // indirect github.com/u-root/u-root v0.11.0 // indirect golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.16.0 // indirect ) diff --git a/examples/go.sum b/examples/go.sum index 1036f03..630e449 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -2,8 +2,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09 h1:ZDIQmTtohv0S/AAYE//w8mYTxCzqphhF1+4ACPDMiLU= github.com/charmbracelet/ssh v0.0.0-20230822194956-1a051f898e09/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg= -github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= -github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= +github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/u-root/gobusybox/src v0.0.0-20221229083637-46b2883a7f90 h1:zTk5683I9K62wtZ6eUa6vu6IWwVHXPnoKK5n2unAwv0= github.com/u-root/u-root v0.11.0 h1:6gCZLOeRyevw7gbTwMj3fKxnr9+yHFlgF3N7udUVNO8= github.com/u-root/u-root v0.11.0/go.mod h1:DBkDtiZyONk9hzVEdB/PWI9B4TxDkElWlVTHseglrZY= @@ -15,8 +15,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= diff --git a/examples/shell/main.go b/examples/shell/main.go index 1b54464..334be9d 100644 --- a/examples/shell/main.go +++ b/examples/shell/main.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "runtime" "github.com/aymanbagabas/go-pty" "golang.org/x/term" @@ -22,7 +23,11 @@ func test() error { defer ptmx.Close() - c := ptmx.Command(`bash`) + cmd := "bash" + if runtime.GOOS == "windows" { + cmd = "powershell.exe" + } + c := ptmx.Command(cmd) if err := c.Start(); err != nil { return err } diff --git a/pty_windows.go b/pty_windows.go index ae7fc27..9ea81fc 100644 --- a/pty_windows.go +++ b/pty_windows.go @@ -6,21 +6,13 @@ package pty import ( "context" "errors" - "fmt" "os" - "sync" - "unsafe" - "golang.org/x/sys/windows" -) - -const ( - _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE = 0x20016 // nolint:revive + "github.com/aymanbagabas/go-pty/conpty" ) var ( - errClosedConPty = errors.New("pseudo console is closed") - errNotStarted = errors.New("process not started") + errNotStarted = errors.New("process not started") ) // conPty is a Windows console pseudo-terminal. @@ -29,52 +21,18 @@ var ( // // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session type conPty struct { - handle windows.Handle - inPipe, outPipe *os.File - mtx sync.RWMutex + *conpty.ConPty } var _ Pty = &conPty{} func newPty() (ConPty, error) { - ptyIn, inPipeOurs, err := os.Pipe() - if err != nil { - return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) - } - - outPipeOurs, ptyOut, err := os.Pipe() - if err != nil { - return nil, fmt.Errorf("failed to create pipes for pseudo console: %w", err) - } - - var hpc windows.Handle - coord := windows.Coord{X: 80, Y: 25} - err = windows.CreatePseudoConsole(coord, windows.Handle(ptyIn.Fd()), windows.Handle(ptyOut.Fd()), 0, &hpc) + c, err := conpty.New(conpty.DefaultWidth, conpty.DefaultHeight, 0) if err != nil { - return nil, fmt.Errorf("failed to create pseudo console: %w", err) - } - - if err := ptyOut.Close(); err != nil { - return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) - } - if err := ptyIn.Close(); err != nil { - return nil, fmt.Errorf("failed to close pseudo console handle: %w", err) + return nil, err } - return &conPty{ - handle: hpc, - inPipe: inPipeOurs, - outPipe: outPipeOurs, - }, nil -} - -// Close implements Pty. -func (p *conPty) Close() error { - p.mtx.Lock() - defer p.mtx.Unlock() - - windows.ClosePseudoConsole(p.handle) - return errors.Join(p.inPipe.Close(), p.outPipe.Close()) + return &conPty{ConPty: c}, nil } // Command implements Pty. @@ -105,60 +63,17 @@ func (*conPty) Name() string { return "windows-pty" } -// Read implements Pty. -func (p *conPty) Read(b []byte) (n int, err error) { - return p.outPipe.Read(b) -} - -// Resize implements Pty. -func (p *conPty) Resize(width int, height int) error { - p.mtx.RLock() - defer p.mtx.RUnlock() - if err := windows.ResizePseudoConsole(p.handle, windows.Coord{X: int16(width), Y: int16(height)}); err != nil { - return fmt.Errorf("failed to resize pseudo console: %w", err) - } - return nil -} - -// Write implements Pty. -func (p *conPty) Write(b []byte) (n int, err error) { - return p.inPipe.Write(b) -} - // Fd implements Pty. func (p *conPty) Fd() uintptr { - p.mtx.RLock() - defer p.mtx.RUnlock() - return uintptr(p.handle) + return uintptr(p.Handle()) } // InputPipe implements ConPty. func (p *conPty) InputPipe() *os.File { - return p.inPipe + return p.InPipe() } // OutputPipe implements ConPty. func (p *conPty) OutputPipe() *os.File { - return p.outPipe -} - -// updateProcThreadAttribute updates the passed in attribute list to contain the entry necessary for use with -// CreateProcess. -func (p *conPty) updateProcThreadAttribute(attrList *windows.ProcThreadAttributeListContainer) error { - p.mtx.RLock() - defer p.mtx.RUnlock() - - if p.handle == 0 { - return errClosedConPty - } - - if err := attrList.Update( - _PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE, - unsafe.Pointer(p.handle), - unsafe.Sizeof(p.handle), - ); err != nil { - return fmt.Errorf("failed to update proc thread attributes for pseudo console: %w", err) - } - - return nil + return p.OutPipe() }