// Command tlsserver accepts TLS connections and, for every connection, runs a
// command with the connection as the command's stdin and stdout. Connection
// metadata is passed to the command in tcp-environ(5)-style environment
// variables.
//
// When -uid is given, the process re-executes itself as a child running with
// that user's uid and groups; the child then does the serving.
package main

import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"os/signal"
	"os/user"
	"strconv"
	"syscall"
)

var (
	cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format")
	kfile = flag.String("key", "key.pem", "Key file in PEM format")
	port  = flag.String("port", ":1234", "Port to bind to")
	uid   = flag.Int("uid", -1, "UID to run under")

	// args is the command (args[0]) and its arguments run for every
	// connection; serve fills it from the non-flag command-line arguments.
	args []string
)

// FORKED is the argv[0] given to the re-executed child so that main can tell
// the child (which serves) from the parent (which forks).
const FORKED = "[forked]"

// main parses the flags and either serves directly (in the forked child) or
// forks a child running under -uid. It exits with status 1 on usage errors
// and via log.Fatal when serving or forking fails.
func main() {
	flag.Parse()
	if flag.NArg() < 1 {
		fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
		fmt.Println("options:")
		flag.PrintDefaults()
		os.Exit(1)
	}

	run := fork
	if os.Args[0] == FORKED {
		run = serve
	}
	// log.Fatal(nil) would print "<nil>" and exit with status 1 even when the
	// child exited cleanly, so only a non-nil error is fatal.
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// fork re-executes the current binary as a child process with the uid from
// -uid and that user's primary and supplementary groups. The child gets
// FORKED as argv[0] and serves the connections. If -uid is not set, fork
// serves in the current process instead.
//
// fork returns the child's exit error (nil on a clean exit). If the parent
// receives an interrupt, SIGTERM or SIGHUP, it kills the child and exits the
// parent with status 0.
func fork() error {
	if *uid < 0 {
		fmt.Println("uid not set → not forking")
		return serve()
	}

	gids, err := groupsOf(*uid)
	if err != nil {
		return err
	}

	cmd := exec.Command(os.Args[0])
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Args = childArgs()
	// Set the uid/gid of the child.
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Pdeathsig: syscall.SIGTERM,
		Credential: &syscall.Credential{
			Uid:    uint32(*uid),
			Gid:    gids[0],
			Groups: gids[1:],
		},
	}

	// Start before installing the signal handler: the handler dereferences
	// cmd.Process, which is nil until the child has been started.
	if err := cmd.Start(); err != nil {
		return err
	}

	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
	go func() {
		sig := <-ch
		log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid)
		// The child may already be gone; we exit either way.
		_ = cmd.Process.Kill()
		os.Exit(0)
	}()

	// Energy!
	return cmd.Wait()
}

// groupsOf returns the group IDs of the user with the given uid. The primary
// group comes first, followed by the supplementary groups (without repeating
// the primary one). Failing to list or parse supplementary groups is logged
// and otherwise ignored; failing to look up the user or parse the primary gid
// is returned as an error.
func groupsOf(id int) ([]uint32, error) {
	u, err := user.LookupId(strconv.Itoa(id))
	if err != nil {
		return nil, err
	}

	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return nil, fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", u.Gid, err)
	}
	gids := []uint32{uint32(gid)}

	sgids, err := u.GroupIds()
	if err != nil {
		log.Printf("weird, couldn't get supplementary groups: %v", err)
		return gids, nil
	}
	for _, s := range sgids {
		v, err := strconv.Atoi(s)
		if err != nil {
			log.Printf("weird, strconv.Atoi(%s): %v", s, err)
			continue
		}
		if v != gid {
			gids = append(gids, uint32(v))
		}
	}
	return gids, nil
}

// childArgs builds argv for the re-executed child. It contains FORKED as
// argv[0], then those of -cert/-key/-port whose values differ from their
// defaults, then "--", then the command to serve.
func childArgs() []string {
	argv := []string{FORKED}
	for _, name := range []string{"cert", "key", "port"} {
		fl := flag.Lookup(name)
		// Only pass a flag on if its value differs from the default.
		if fl != nil && fl.Value.String() != fl.DefValue {
			argv = append(argv, "-"+name, fl.Value.String())
		}
	}
	// "--" ends the child's flag parsing, so a served command that starts
	// with "-" is not mistaken for one of our flags.
	argv = append(argv, "--")
	return append(argv, flag.Args()...)
}

// serve listens for TLS connections on -port, using the -cert/-key key pair.
// For every accepted connection it runs the command from the non-flag
// arguments (see handleConnection). It exits the process with status 2 if the
// key pair cannot be loaded and with status 3 if the listener cannot be
// created; otherwise it accepts connections forever.
func serve() error {
	args = flag.Args()

	cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
	if err != nil {
		fmt.Println("error with certs:", err)
		os.Exit(2)
	}
	// With NameToCertificate left nil, crypto/tls picks from Certificates;
	// BuildNameToCertificate is deprecated.
	tconf := &tls.Config{Certificates: []tls.Certificate{cert}}

	sock, err := tls.Listen("tcp", *port, tconf)
	if err != nil {
		fmt.Println("error with tcp-socket:", err)
		os.Exit(3)
	}
	defer sock.Close()

	for {
		conn, err := sock.Accept()
		if err != nil {
			log.Println("error during Accept()", err)
			continue
		}
		log.Println("Got connection:", conn.RemoteAddr())
		go handleConnection(conn, args...)
	}
}

// handleConnection runs the command args[0] with arguments args (args must be
// non-empty). conn is the command's stdin and stdout, and the server's stderr
// is its stderr. conn is closed when the command has finished. The command's
// environment contains only PATH plus the tcp-environ(5) variables PROTO,
// TCPLOCALIP, TCPLOCALPORT, TCPREMOTEIP and TCPREMOTEPORT taken from conn.
func handleConnection(conn net.Conn, args ...string) {
	defer conn.Close()

	lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		log.Println(err)
		return
	}
	rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		log.Println(err)
		return
	}

	cmd := exec.Command(args[0])
	cmd.Args = args
	cmd.Stdout = conn
	cmd.Stderr = os.Stderr
	cmd.Env = []string{
		"PATH=" + os.Getenv("PATH"),
		"PROTO=TCP",
		"TCPLOCALIP=" + lh,
		"TCPLOCALPORT=" + lp,
		"TCPREMOTEIP=" + rh,
		"TCPREMOTEPORT=" + rp,
	}

	// Feed stdin ourselves instead of setting cmd.Stdin = conn. For a
	// non-*os.File Stdin, Wait blocks until the copy from conn stops, which
	// can be long after the command has exited (until the client sends more
	// data or disconnects). With StdinPipe, Wait closes the pipe once the
	// command exits and returns; the deferred conn.Close then unblocks this
	// copier.
	stdin, err := cmd.StdinPipe()
	if err != nil {
		log.Println(err)
		return
	}
	go func() {
		// A copy error only means the client or the command went away.
		_, _ = io.Copy(stdin, conn)
		stdin.Close() // the command sees EOF once the client stops sending
	}()

	if err := cmd.Run(); err != nil {
		log.Println("after Run: ", err)
	}
	log.Println("Done with connection", conn.RemoteAddr())
}

// setgid invokes the raw setgid(2) system call. Since go1.4, syscall.Setgid is
// deliberately unsupported because the call only applies to the calling
// thread, so this wrapper issues the system call directly.
// NOTE(review): Go 1.16+ implements syscall.Setgid for all threads on Linux;
// consider switching to it.
func setgid(gid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
	if e != 0 {
		return e
	}
	return nil
}

// setuid invokes the raw setuid(2) system call. Like setgid, it only affects
// the calling thread.
// NOTE(review): Go 1.16+ implements syscall.Setuid for all threads on Linux;
// consider switching to it.
func setuid(uid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
	if e != 0 {
		return e
	}
	return nil
}