diff --git a/tlsserver.go b/tlsserver.go index 8af03ad..c92c660 100644 --- a/tlsserver.go +++ b/tlsserver.go @@ -8,7 +8,6 @@ import ( "net" "os" "os/exec" - "os/signal" "os/user" "strconv" "syscall" @@ -22,7 +21,7 @@ var ( args []string ) -const FORKED = "[forked]" +const TLSSERVER = "[tlsserver]" func main() { flag.Parse() @@ -34,16 +33,15 @@ func main() { os.Exit(1) } - if os.Args[0] == FORKED { + if os.Args[0] == TLSSERVER { log.Fatal(serve()) } else { - log.Fatal(fork()) + log.Fatal(exc()) } } -func fork() error { +func exc() error { if *uid < 0 { - fmt.Println("uid not set → not forking") return serve() } @@ -54,13 +52,12 @@ func fork() error { } // Get the primary and secondary groups - gids := []uint32{} + gids := []int{} gid, err := strconv.Atoi(user.Gid) if err != nil { return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err) } - gids = append(gids, uint32(gid)) if sgids, err := user.GroupIds(); err != nil { log.Println("weird, couldn't get supplimentary groups: %v", err) @@ -71,56 +68,38 @@ func fork() error { log.Println("weird, strconv.Atoi(%s): %v", s, e) continue } else if v != gid { - gids = append(gids, uint32(v)) + gids = append(gids, v) } } } - cmd := exec.Command(os.Args[0]) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + if err := setgid(gid); err != nil { + return err + } + + if len(gids) > 0 { + if err = syscall.Setgroups(gids); err != nil { + return fmt.Errorf("syscall.Setgroups: %w", err) + } + } + + if err = setuid(*uid); err != nil { + return err + } // Setup the arguments that are needed - cmd.Args = []string{FORKED} // name of the command + args := []string{TLSSERVER} // name of the command for _, arg := range []string{"cert", "key", "port"} { fl := flag.Lookup(arg) // Only set the flag if it's value are different from the default if fl != nil && fl.Value.String() != fl.DefValue { - cmd.Args = append(cmd.Args, "-"+arg, fl.Value.String()) + args = append(args, "-"+arg, fl.Value.String()) } } - cmd.Args = append(cmd.Args, flag.Args()...) + args = append(args, flag.Args()...) - // Set the uid/gid - cmd.SysProcAttr = &syscall.SysProcAttr{ - Pdeathsig: syscall.SIGTERM, - Credential: &syscall.Credential{ - Uid: uint32(*uid), - Gid: gids[0], - Groups: gids[1:], - }, - } - - // Pass on signals that we receive - ch := make(chan os.Signal, 1) - signal.Notify(ch, - os.Interrupt, - syscall.SIGTERM, - syscall.SIGHUP, - ) - go func() { - select { - case sig := <-ch: - log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid) - cmd.Process.Kill() - os.Exit(0) - } - }() - - // Energy! - return cmd.Run() + return syscall.Exec(os.Args[0], args, os.Environ()) } func serve() error { @@ -199,7 +178,7 @@ func setgid(gid int) error { // RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) _, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0) if e != 0 { - return fmt.Errorf(e.Error()) + return fmt.Errorf("setgid: %w", e) } return nil } @@ -208,7 +187,7 @@ func setuid(uid int) error { // RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) _, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0) if e != 0 { - return fmt.Errorf(e.Error()) + return fmt.Errorf("setuid: %w", e) } return nil }