diff --git a/tlsserver.go b/tlsserver.go index 77963c2..8af03ad 100644 --- a/tlsserver.go +++ b/tlsserver.go @@ -17,12 +17,13 @@ import ( var ( cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format") kfile = flag.String("key", "key.pem", "Key file in PEM format") - port = flag.Int("port", 1234, "Port to bind to") + port = flag.String("port", ":1234", "Port to bind to") uid = flag.Int("uid", -1, "UID to run under") - child = flag.Bool("child", false, "running as child") args []string ) +const FORKED = "[forked]" + func main() { flag.Parse() @@ -33,7 +34,7 @@ func main() { os.Exit(1) } - if *child { + if os.Args[0] == FORKED { log.Fatal(serve()) } else { log.Fatal(fork()) @@ -52,13 +53,14 @@ func fork() error { return err } - // Get the groups + // Get the primary and secondary groups + gids := []uint32{} + gid, err := strconv.Atoi(user.Gid) if err != nil { return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err) } - - gids := []uint32{uint32(gid)} + gids = append(gids, uint32(gid)) if sgids, err := user.GroupIds(); err != nil { log.Println("weird, couldn't get supplimentary groups: %v", err) @@ -74,20 +76,24 @@ func fork() error { } } - args := []string{ - "-child", - "-cert", *cfile, - "-key", *kfile, - "-port", strconv.Itoa(*port)} - args = append(args, flag.Args()...) - - cmd := exec.Command(os.Args[0], args...) - + cmd := exec.Command(os.Args[0]) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // set the uid/gid + // Setup the arguments that are needed + cmd.Args = []string{FORKED} // name of the command + + for _, arg := range []string{"cert", "key", "port"} { + fl := flag.Lookup(arg) + // Only set the flag if it's value are different from the default + if fl != nil && fl.Value.String() != fl.DefValue { + cmd.Args = append(cmd.Args, "-"+arg, fl.Value.String()) + } + } + cmd.Args = append(cmd.Args, flag.Args()...) + + // Set the uid/gid cmd.SysProcAttr = &syscall.SysProcAttr{ Pdeathsig: syscall.SIGTERM, Credential: &syscall.Credential{ @@ -97,7 +103,7 @@ func fork() error { }, } - // pass on signals that we receive + // Pass on signals that we receive ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt, @@ -113,7 +119,7 @@ func fork() error { } }() - // energy! + // Energy! return cmd.Run() } @@ -132,8 +138,7 @@ func serve() error { tconf.BuildNameToCertificate() // start listening - sport := fmt.Sprintf(":%d", *port) - sock, err := tls.Listen("tcp", sport, tconf) + sock, err := tls.Listen("tcp", *port, tconf) if err != nil { fmt.Println("error with tcp-socket:", err) os.Exit(3)