diff --git a/tlsserver.go b/tlsserver.go index 4564a4f..77963c2 100644 --- a/tlsserver.go +++ b/tlsserver.go @@ -8,6 +8,9 @@ import ( "net" "os" "os/exec" + "os/signal" + "os/user" + "strconv" "syscall" ) @@ -15,27 +18,108 @@ var ( cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format") kfile = flag.String("key", "key.pem", "Key file in PEM format") port = flag.Int("port", 1234, "Port to bind to") - /* - Rather than using setuid/setgid we rely on setcap CAP_NET_BIND_SERVICE - uid = flag.Int("uid", -1, "UID to run under") - gid = flag.Int("gid", -1, "GID to run under") - */ + uid = flag.Int("uid", -1, "UID to run under") + child = flag.Bool("child", false, "running as child") args []string - nargs int ) func main() { - flag.Parse() - args = flag.Args() - nargs = flag.NArg() - if nargs < 1 { + + if flag.NArg() < 1 { fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]") fmt.Println("options:") flag.PrintDefaults() os.Exit(1) } + if *child { + log.Fatal(serve()) + } else { + log.Fatal(fork()) + } +} + +func fork() error { + if *uid < 0 { + fmt.Println("uid not set → not forking") + return serve() + } + + // Collect information about the user + user, err := user.LookupId(strconv.Itoa(*uid)) + if err != nil { + return err + } + + // Get the groups + gid, err := strconv.Atoi(user.Gid) + if err != nil { + return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err) + } + + gids := []uint32{uint32(gid)} + + if sgids, err := user.GroupIds(); err != nil { + log.Println("weird, couldn't get supplimentary groups: %v", err) + } else { + for _, s := range sgids { + v, e := strconv.Atoi(s) + if e != nil { + log.Println("weird, strconv.Atoi(%s): %v", s, e) + continue + } else if v != gid { + gids = append(gids, uint32(v)) + } + } + } + + args := []string{ + "-child", + "-cert", *cfile, + "-key", *kfile, + "-port", strconv.Itoa(*port)} + args = append(args, flag.Args()...) + + cmd := exec.Command(os.Args[0], args...) + + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // set the uid/gid + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, + Credential: &syscall.Credential{ + Uid: uint32(*uid), + Gid: gids[0], + Groups: gids[1:], + }, + } + + // pass on signals that we receive + ch := make(chan os.Signal, 1) + signal.Notify(ch, + os.Interrupt, + syscall.SIGTERM, + syscall.SIGHUP, + ) + go func() { + select { + case sig := <-ch: + log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid) + cmd.Process.Kill() + os.Exit(0) + } + }() + + // energy! + return cmd.Run() +} + +func serve() error { + args = flag.Args() + // setup certs etc. for TLS-socket tconf := new(tls.Config) cert, err := tls.LoadX509KeyPair(*cfile, *kfile) @@ -56,29 +140,6 @@ func main() { } defer sock.Close() - /* - The right way to handle/drop privileges is to start with a - low-privileged user and use setcap CAP_NET_BIND_SERVICE on the - binary to allow for the listen-operation. - - // set uid/gid - if *gid >= 0 { - err := setgid(*gid) // syscall.Setgid(*gid) - if err != nil { - fmt.Println("Couldn't setgid to", *gid, ":", err) - os.Exit(4) - } - } - - if *uid >= 0 { - err := setuid(*uid) // syscall.Setuid(*uid) - if err != nil { - fmt.Println("Couldn't setuid to", *uid, ":", err) - os.Exit(4) - } - } - */ - // accept-loop for { conn, err := sock.Accept() @@ -87,11 +148,11 @@ func main() { continue } log.Println("Got connection:", conn.RemoteAddr()) - go handleConnection(conn) + go handleConnection(conn, args...) } } -func handleConnection(conn net.Conn) { +func handleConnection(conn net.Conn, args ...string) { defer conn.Close() // setup cmd @@ -100,7 +161,6 @@ func handleConnection(conn net.Conn) { cmd.Stdin = conn cmd.Stdout = conn cmd.Stderr = os.Stderr - cmd.SysProcAttr = &syscall.SysProcAttr{} // prepare environment according to tcp-environ(5) lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())