// Command tlsserver accepts TLS connections and, for each one, runs a
// command with the connection wired to its stdin/stdout and tcp-environ(5)
// style variables in its environment.
package main

import (
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"os/signal"
	"os/user"
	"strconv"
	"syscall"
	"time"
)

var (
	cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format")
	kfile = flag.String("key", "key.pem", "Key file in PEM format")
	port  = flag.Int("port", 1234, "Port to bind to")
	uid   = flag.Int("uid", -1, "UID to run under")
	child = flag.Bool("child", false, "running as child")

	// args holds the command (and its arguments) to run per connection;
	// it is filled in by serve from the positional command-line arguments.
	args []string
)

// handshakeTimeout bounds how long a client may take to complete the TLS
// handshake before its connection is dropped without running the command.
const handshakeTimeout = 30 * time.Second

func main() {
	flag.Parse()
	if flag.NArg() < 1 {
		fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
		fmt.Println("options:")
		flag.PrintDefaults()
		os.Exit(1)
	}
	if *child {
		log.Fatal(serve())
	}
	log.Fatal(fork())
}

// fork re-executes this binary with -child under the credentials of -uid
// (primary group plus supplementary groups) and waits for it. When -uid is
// not set it serves directly in this process instead.
//
// SIGINT, SIGTERM and SIGHUP received by the parent kill the child and make
// the parent exit with status 0; Pdeathsig asks the kernel to SIGTERM the
// child if the parent goes away.
//
// NOTE(review): the child loads the certificate/key and binds the port only
// after the credential switch, so root-only key files or privileged ports
// will presumably fail in that mode — confirm this is intended.
func fork() error {
	if *uid < 0 {
		fmt.Println("uid not set → not forking")
		return serve()
	}

	// Collect information about the target user.
	u, err := user.LookupId(strconv.Itoa(*uid))
	if err != nil {
		return err
	}

	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", u.Gid, err)
	}
	sgids, err := u.GroupIds()
	if err != nil {
		// Not fatal: the child still runs with the primary group.
		log.Printf("weird, couldn't get supplimentary groups: %v", err)
	}
	gids := groupIDs(gid, sgids)

	childArgs := []string{
		"-child",
		"-cert", *cfile,
		"-key", *kfile,
		"-port", strconv.Itoa(*port),
		// Terminate flag parsing so the child never interprets the
		// command to run as one of its own flags.
		"--",
	}
	childArgs = append(childArgs, flag.Args()...)

	cmd := exec.Command(os.Args[0], childArgs...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Pdeathsig: syscall.SIGTERM,
		Credential: &syscall.Credential{
			Uid:    uint32(*uid),
			Gid:    gids[0],
			Groups: gids[1:],
		},
	}

	// Subscribe before starting so no signal is missed, but only consume
	// signals once the child exists: cmd.Process is nil until Start returns.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
	if err := cmd.Start(); err != nil {
		signal.Stop(ch)
		return err
	}
	go func() {
		sig := <-ch
		log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid)
		if err := cmd.Process.Kill(); err != nil {
			log.Println("kill:", err)
		}
		os.Exit(0)
	}()

	return cmd.Wait()
}

// groupIDs returns primary followed by every entry of supplementary that
// parses as an integer and differs from primary, in input order.
// Unparseable entries are logged and skipped.
func groupIDs(primary int, supplementary []string) []uint32 {
	gids := make([]uint32, 1, 1+len(supplementary))
	gids[0] = uint32(primary)
	for _, s := range supplementary {
		v, err := strconv.Atoi(s)
		if err != nil {
			log.Printf("weird, strconv.Atoi(%s): %v", s, err)
			continue
		}
		if v != primary {
			gids = append(gids, uint32(v))
		}
	}
	return gids
}

// serve loads the TLS key pair, listens on -port and handles every accepted
// connection in its own goroutine. It exits the process with status 2 on a
// certificate error and 3 on a listen error; otherwise it never returns.
func serve() error {
	args = flag.Args()

	cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
	if err != nil {
		fmt.Println("error with certs:", err)
		os.Exit(2)
	}
	// With a single certificate crypto/tls selects it automatically; the
	// deprecated BuildNameToCertificate is not needed.
	tconf := &tls.Config{Certificates: []tls.Certificate{cert}}

	sock, err := tls.Listen("tcp", fmt.Sprintf(":%d", *port), tconf)
	if err != nil {
		fmt.Println("error with tcp-socket:", err)
		os.Exit(3)
	}
	defer sock.Close()

	for {
		conn, err := sock.Accept()
		if err != nil {
			log.Println("error during Accept()", err)
			continue
		}
		log.Println("Got connection:", conn.RemoteAddr())
		go handleConnection(conn, args...)
	}
}

// handleConnection runs args[0] with args[1:] for one client: the
// connection is the command's stdin and stdout, stderr goes to ours, and the
// environment holds PATH plus the tcp-environ(5) variables. The connection
// is closed when the command exits.
func handleConnection(conn net.Conn, args ...string) {
	defer conn.Close()

	if len(args) == 0 {
		log.Println("no command to run for", conn.RemoteAddr())
		return
	}

	// Finish TLS before spawning anything, so clients that never complete
	// a handshake do not cost a process.
	if err := handshake(conn); err != nil {
		log.Println("TLS handshake failed:", conn.RemoteAddr(), err)
		return
	}

	env, err := tcpEnviron(conn.LocalAddr().String(), conn.RemoteAddr().String(), os.Getenv("PATH"))
	if err != nil {
		log.Println(err)
		return
	}

	cmd := exec.Command(args[0], args[1:]...)
	cmd.Stdout = conn
	cmd.Stderr = os.Stderr
	cmd.Env = env

	// Feed stdin ourselves: with cmd.Stdin = conn, Wait would block on the
	// internal copy goroutine until the client sent more data or hung up,
	// even after the command had exited. The deferred conn.Close unblocks
	// this copier once Wait returns.
	stdin, err := cmd.StdinPipe()
	if err != nil {
		log.Println("after Run: ", err)
		return
	}
	if err := cmd.Start(); err != nil {
		log.Println("after Run: ", err)
		return
	}
	go func() {
		_, _ = io.Copy(stdin, conn) // ends on client EOF or when conn is closed
		_ = stdin.Close()
	}()

	if err := cmd.Wait(); err != nil {
		log.Println("after Run: ", err)
	}
	log.Println("Done with connection", conn.RemoteAddr())
}

// handshake completes the TLS handshake on conn within handshakeTimeout and
// then clears the deadline. Non-TLS connections are returned untouched.
func handshake(conn net.Conn) error {
	tc, ok := conn.(*tls.Conn)
	if !ok {
		return nil
	}
	if err := tc.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		return err
	}
	if err := tc.Handshake(); err != nil {
		return err
	}
	return tc.SetDeadline(time.Time{})
}

// tcpEnviron builds the child environment from the "host:port" forms of the
// local and remote addresses: PATH=path, PROTO=TCP, TCPLOCALIP,
// TCPLOCALPORT, TCPREMOTEIP and TCPREMOTEPORT (tcp-environ(5)). It fails if
// either address cannot be split by net.SplitHostPort.
func tcpEnviron(local, remote, path string) ([]string, error) {
	lh, lp, err := net.SplitHostPort(local)
	if err != nil {
		return nil, err
	}
	rh, rp, err := net.SplitHostPort(remote)
	if err != nil {
		return nil, err
	}
	return []string{
		"PATH=" + path,
		"PROTO=TCP",
		"TCPLOCALIP=" + lh,
		"TCPLOCALPORT=" + lp,
		"TCPREMOTEIP=" + rh,
		"TCPREMOTEPORT=" + rp,
	}, nil
}

// setgid issues the raw setgid(2) system call. Go 1.4 made syscall.Setgid
// return EOPNOTSUPP on Linux because the raw call only changes the calling
// OS thread; the same caveat applies here: other threads of this process
// keep their old group ID. Since Go 1.16, syscall.Setgid applies the change
// to all threads on Linux and is the safer choice.
func setgid(gid int) error {
	if _, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0); e != 0 {
		return e
	}
	return nil
}

// setuid issues the raw setuid(2) system call. Like setgid, it affects only
// the calling OS thread; since Go 1.16, syscall.Setuid applies the change to
// all threads on Linux and is the safer choice.
func setuid(uid int) error {
	if _, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0); e != 0 {
		return e
	}
	return nil
}