// Command tlsserver listens on a TLS socket and, for every accepted
// connection, runs the given command with the connection attached to its
// standard input and output (in the spirit of tcpserver from ucspi-tcp).
//
// When started with -uid the program first drops privileges to that user and
// then re-executes itself under the name TLSSERVER so that the serving process
// starts with a clean, unprivileged state.
package main

import (
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/exec"
	"os/user"
	"runtime"
	"strconv"
	"syscall"
)

var (
	cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format")
	kfile = flag.String("key", "key.pem", "Key file in PEM format")
	port  = flag.String("port", ":1234", "Port to bind to")
	uid   = flag.Int("uid", -1, "UID to run under")

	// args holds the command (and its arguments) to run per connection.
	args []string
)

// TLSSERVER is the argv[0] the program gives itself when it re-executes after
// dropping privileges; seeing it in os.Args[0] means "serve right away".
const TLSSERVER = "[tlsserver]"

// main parses the flags, requires at least one positional argument (the
// command to run) and then either serves directly (re-executed instance) or
// goes through exc to optionally drop privileges first.
func main() {
	flag.Parse()
	if flag.NArg() < 1 {
		fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
		fmt.Println("options:")
		flag.PrintDefaults()
		os.Exit(1)
	}

	if os.Args[0] == TLSSERVER {
		log.Fatal(serve())
	}
	log.Fatal(exc())
}

// exc drops privileges to the user given by -uid and re-executes the program
// as TLSSERVER. Without -uid (negative value) it serves in-process instead.
//
// It only returns on failure: a successful syscall.Exec replaces the process.
func exc() error {
	if *uid < 0 {
		return serve()
	}

	// Resolve our own binary while still privileged. syscall.Exec performs no
	// PATH search, so os.Args[0] alone fails when started by bare name.
	self, err := os.Executable()
	if err != nil {
		self = os.Args[0]
	}

	// Collect information about the user.
	usr, err := user.LookupId(strconv.Itoa(*uid))
	if err != nil {
		return fmt.Errorf("looking up uid %d: %w", *uid, err)
	}

	gid, err := strconv.Atoi(usr.Gid)
	if err != nil {
		return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %w", usr.Gid, err)
	}

	// Gather the supplementary groups (everything but the primary group).
	// Failures are logged and tolerated: the list simply ends up shorter.
	var gids []int
	sgids, err := usr.GroupIds()
	if err != nil {
		log.Printf("weird, couldn't get supplimentary groups: %v", err)
		sgids = nil
	}
	for _, s := range sgids {
		v, err := strconv.Atoi(s)
		if err != nil {
			log.Printf("weird, strconv.Atoi(%s): %v", s, err)
			continue
		}
		if v != gid {
			gids = append(gids, v)
		}
	}

	// setgid/setuid below use raw syscalls that only affect the calling OS
	// thread. Pin the goroutine so the credential changes and the final exec
	// are guaranteed to happen on the same thread.
	runtime.LockOSThread()

	if err := setgid(gid); err != nil {
		return err
	}
	// Always reset the supplementary groups, even to an empty list; otherwise
	// the groups of the invoking (privileged) user would be retained.
	if err := syscall.Setgroups(gids); err != nil {
		return fmt.Errorf("syscall.Setgroups: %w", err)
	}
	if err := setuid(*uid); err != nil {
		return err
	}

	// Set up the arguments for the re-executed instance.
	argv := []string{TLSSERVER} // name of the command
	for _, name := range []string{"cert", "key", "port"} {
		fl := flag.Lookup(name)
		// Only pass the flag if its value differs from the default.
		if fl != nil && fl.Value.String() != fl.DefValue {
			argv = append(argv, "-"+name, fl.Value.String())
		}
	}
	// Terminate flag parsing explicitly so a command that starts with "-" is
	// not mistaken for one of our own flags by the child.
	argv = append(argv, "--")
	argv = append(argv, flag.Args()...)

	if err := syscall.Exec(self, argv, os.Environ()); err != nil {
		return fmt.Errorf("syscall.Exec(%s): %w", self, err)
	}
	return nil
}

// serve loads the certificate/key pair, binds the TLS listener on *port and
// hands every accepted connection to handleConnection in its own goroutine.
//
// Certificate and listen failures terminate the process with exit status 2
// and 3 respectively. The accept loop itself never returns.
func serve() error {
	args = flag.Args()

	// Set up certs etc. for the TLS socket.
	cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
	if err != nil {
		fmt.Println("error with certs:", err)
		os.Exit(2)
	}
	tconf := &tls.Config{Certificates: []tls.Certificate{cert}}

	// Start listening.
	sock, err := tls.Listen("tcp", *port, tconf)
	if err != nil {
		fmt.Println("error with tcp-socket:", err)
		os.Exit(3)
	}
	defer sock.Close()

	// Accept loop.
	for {
		conn, err := sock.Accept()
		if err != nil {
			log.Println("error during Accept()", err)
			continue
		}
		log.Println("Got connection:", conn.RemoteAddr())
		go handleConnection(conn, args...)
	}
}

// handleConnection runs args[0] with arguments args[1:] for a single
// connection. The connection is the command's stdin and stdout, stderr goes
// to the server's stderr, and the environment is limited to PATH plus the
// tcp-environ(5) variables describing both endpoints. The connection is
// closed when the function returns.
func handleConnection(conn net.Conn, args ...string) {
	defer conn.Close()

	// Without a command there is nothing to run; bail out instead of
	// panicking on args[0] and taking the whole server down.
	if len(args) == 0 {
		log.Println("no command given for connection", conn.RemoteAddr())
		return
	}

	// Set up the command.
	cmd := exec.Command(args[0])
	cmd.Args = args
	cmd.Stdin = conn
	cmd.Stdout = conn
	cmd.Stderr = os.Stderr

	// Prepare the environment according to tcp-environ(5).
	lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		log.Println("net.SplitHostPort(conn.LocalAddr()):", err)
		return
	}
	rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		log.Println("net.SplitHostPort(conn.RemoteAddr()):", err)
		return
	}
	cmd.Env = []string{
		"PATH=" + os.Getenv("PATH"),
		"PROTO=TCP",
		"TCPLOCALIP=" + lh,
		"TCPLOCALPORT=" + lp,
		"TCPREMOTEIP=" + rh,
		"TCPREMOTEPORT=" + rp,
	}

	if err := cmd.Run(); err != nil {
		log.Println("after Run: ", err)
	}
	log.Println("Done with connection", conn.RemoteAddr())
}

// setgid sets the group ID via a raw setgid(2) syscall.
//
// Since go1.4 the setgid syscall is deliberately not supported anymore, as it
// only applies to the calling thread. So we issue it ourselves; callers are
// responsible for staying on the same OS thread (runtime.LockOSThread) until
// they exec.
func setgid(gid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
	if e != 0 {
		return fmt.Errorf("setgid: %w", e)
	}
	return nil
}

// setuid sets the user ID via a raw setuid(2) syscall. Like setgid it only
// affects the calling OS thread.
func setuid(uid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
	if e != 0 {
		return fmt.Errorf("setuid: %w", e)
	}
	return nil
}