use syscall.Exec after set(g|u)id and setgroups

This commit is contained in:
Özgür Kesim 2020-02-02 21:14:54 +01:00
parent f7be4ee72f
commit 2d6bdd9fb5

View File

@ -8,7 +8,6 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"os/user" "os/user"
"strconv" "strconv"
"syscall" "syscall"
@ -22,7 +21,7 @@ var (
args []string args []string
) )
const FORKED = "[forked]" const TLSSERVER = "[tlsserver]"
func main() { func main() {
flag.Parse() flag.Parse()
@ -34,16 +33,15 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if os.Args[0] == FORKED { if os.Args[0] == TLSSERVER {
log.Fatal(serve()) log.Fatal(serve())
} else { } else {
log.Fatal(fork()) log.Fatal(exc())
} }
} }
func fork() error { func exc() error {
if *uid < 0 { if *uid < 0 {
fmt.Println("uid not set → not forking")
return serve() return serve()
} }
@ -54,13 +52,12 @@ func fork() error {
} }
// Get the primary and secondary groups // Get the primary and secondary groups
gids := []uint32{} gids := []int{}
gid, err := strconv.Atoi(user.Gid) gid, err := strconv.Atoi(user.Gid)
if err != nil { if err != nil {
return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err) return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err)
} }
gids = append(gids, uint32(gid))
if sgids, err := user.GroupIds(); err != nil { if sgids, err := user.GroupIds(); err != nil {
log.Println("weird, couldn't get supplimentary groups: %v", err) log.Println("weird, couldn't get supplimentary groups: %v", err)
@ -71,56 +68,38 @@ func fork() error {
log.Println("weird, strconv.Atoi(%s): %v", s, e) log.Println("weird, strconv.Atoi(%s): %v", s, e)
continue continue
} else if v != gid { } else if v != gid {
gids = append(gids, uint32(v)) gids = append(gids, v)
} }
} }
} }
cmd := exec.Command(os.Args[0]) if err := setgid(gid); err != nil {
cmd.Stdin = os.Stdin return err
cmd.Stdout = os.Stdout }
cmd.Stderr = os.Stderr
if len(gids) > 0 {
if err = syscall.Setgroups(gids); err != nil {
return fmt.Errorf("syscall.Setgroups: %w", err)
}
}
if err = setuid(*uid); err != nil {
return err
}
// Setup the arguments that are needed // Setup the arguments that are needed
cmd.Args = []string{FORKED} // name of the command args := []string{TLSSERVER} // name of the command
for _, arg := range []string{"cert", "key", "port"} { for _, arg := range []string{"cert", "key", "port"} {
fl := flag.Lookup(arg) fl := flag.Lookup(arg)
// Only set the flag if it's value are different from the default // Only set the flag if it's value are different from the default
if fl != nil && fl.Value.String() != fl.DefValue { if fl != nil && fl.Value.String() != fl.DefValue {
cmd.Args = append(cmd.Args, "-"+arg, fl.Value.String()) args = append(args, "-"+arg, fl.Value.String())
} }
} }
cmd.Args = append(cmd.Args, flag.Args()...) args = append(args, flag.Args()...)
// Set the uid/gid return syscall.Exec(os.Args[0], args, os.Environ())
cmd.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGTERM,
Credential: &syscall.Credential{
Uid: uint32(*uid),
Gid: gids[0],
Groups: gids[1:],
},
}
// Pass on signals that we receive
ch := make(chan os.Signal, 1)
signal.Notify(ch,
os.Interrupt,
syscall.SIGTERM,
syscall.SIGHUP,
)
go func() {
select {
case sig := <-ch:
log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid)
cmd.Process.Kill()
os.Exit(0)
}
}()
// Energy!
return cmd.Run()
} }
func serve() error { func serve() error {
@ -199,7 +178,7 @@ func setgid(gid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) // RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0) _, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
if e != 0 { if e != 0 {
return fmt.Errorf(e.Error()) return fmt.Errorf("setgid: %w", e)
} }
return nil return nil
} }
@ -208,7 +187,7 @@ func setuid(uid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) // RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0) _, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
if e != 0 { if e != 0 {
return fmt.Errorf(e.Error()) return fmt.Errorf("setuid: %w", e)
} }
return nil return nil
} }