summaryrefslogtreecommitdiff
path: root/tlsserver.go
diff options
context:
space:
mode:
authorÖzgür Kesim <oec@kesim.org>2020-02-02 21:14:54 +0100
committerÖzgür Kesim <oec@kesim.org>2020-02-02 21:14:54 +0100
commit2d6bdd9fb54a5bef5883d51076bb4e41b68036e5 (patch)
tree1e239a80638e73c15c72835e2f4655b61c14f893 /tlsserver.go
parentf7be4ee72ff31c06b4dd10115a3dac445fb86ede (diff)
use syscall.Exec after set(g|u)id and setgroups
Diffstat (limited to 'tlsserver.go')
-rw-r--r--tlsserver.go71
1 files changed, 25 insertions, 46 deletions
diff --git a/tlsserver.go b/tlsserver.go
index 8af03ad..c92c660 100644
--- a/tlsserver.go
+++ b/tlsserver.go
@@ -8,7 +8,6 @@ import (
"net"
"os"
"os/exec"
- "os/signal"
"os/user"
"strconv"
"syscall"
@@ -22,7 +21,7 @@ var (
args []string
)
-const FORKED = "[forked]"
+const TLSSERVER = "[tlsserver]"
func main() {
flag.Parse()
@@ -34,16 +33,15 @@ func main() {
os.Exit(1)
}
- if os.Args[0] == FORKED {
+ if os.Args[0] == TLSSERVER {
log.Fatal(serve())
} else {
- log.Fatal(fork())
+ log.Fatal(exc())
}
}
-func fork() error {
+func exc() error {
if *uid < 0 {
- fmt.Println("uid not set → not forking")
return serve()
}
@@ -54,13 +52,12 @@ func fork() error {
}
// Get the primary and secondary groups
- gids := []uint32{}
+ gids := []int{}
gid, err := strconv.Atoi(user.Gid)
if err != nil {
return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err)
}
- gids = append(gids, uint32(gid))
if sgids, err := user.GroupIds(); err != nil {
log.Println("weird, couldn't get supplimentary groups: %v", err)
@@ -71,56 +68,38 @@ func fork() error {
log.Println("weird, strconv.Atoi(%s): %v", s, e)
continue
} else if v != gid {
- gids = append(gids, uint32(v))
+ gids = append(gids, v)
}
}
}
- cmd := exec.Command(os.Args[0])
- cmd.Stdin = os.Stdin
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
+ if err := setgid(gid); err != nil {
+ return err
+ }
+
+ if len(gids) > 0 {
+ if err = syscall.Setgroups(gids); err != nil {
+ return fmt.Errorf("syscall.Setgroups: %w", err)
+ }
+ }
+
+ if err = setuid(*uid); err != nil {
+ return err
+ }
// Setup the arguments that are needed
- cmd.Args = []string{FORKED} // name of the command
+ args := []string{TLSSERVER} // name of the command
for _, arg := range []string{"cert", "key", "port"} {
fl := flag.Lookup(arg)
// Only set the flag if it's value are different from the default
if fl != nil && fl.Value.String() != fl.DefValue {
- cmd.Args = append(cmd.Args, "-"+arg, fl.Value.String())
+ args = append(args, "-"+arg, fl.Value.String())
}
}
- cmd.Args = append(cmd.Args, flag.Args()...)
-
- // Set the uid/gid
- cmd.SysProcAttr = &syscall.SysProcAttr{
- Pdeathsig: syscall.SIGTERM,
- Credential: &syscall.Credential{
- Uid: uint32(*uid),
- Gid: gids[0],
- Groups: gids[1:],
- },
- }
-
- // Pass on signals that we receive
- ch := make(chan os.Signal, 1)
- signal.Notify(ch,
- os.Interrupt,
- syscall.SIGTERM,
- syscall.SIGHUP,
- )
- go func() {
- select {
- case sig := <-ch:
- log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid)
- cmd.Process.Kill()
- os.Exit(0)
- }
- }()
+ args = append(args, flag.Args()...)
- // Energy!
- return cmd.Run()
+ return syscall.Exec(os.Args[0], args, os.Environ())
}
func serve() error {
@@ -199,7 +178,7 @@ func setgid(gid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
if e != 0 {
- return fmt.Errorf(e.Error())
+ return fmt.Errorf("setgid: %w", e)
}
return nil
}
@@ -208,7 +187,7 @@ func setuid(uid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
if e != 0 {
- return fmt.Errorf(e.Error())
+ return fmt.Errorf("setuid: %w", e)
}
return nil
}