tlsserver/tlsserver.go

215 lines
4.7 KiB
Go
Raw Normal View History

2013-07-29 12:08:57 +02:00
package main
import (
"crypto/tls"
"flag"
"fmt"
2013-07-30 13:20:53 +02:00
"log"
"net"
2013-07-29 12:08:57 +02:00
"os"
"os/exec"
"os/signal"
"os/user"
"strconv"
2013-07-30 09:08:37 +02:00
"syscall"
2013-07-29 12:08:57 +02:00
)
// Command-line flags.
var (
	// cfile and kfile name the PEM-encoded certificate and private key that
	// serve loads for the TLS listener.
	cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format")
	kfile = flag.String("key", "key.pem", "Key file in PEM format")
	// port is the listen address handed to tls.Listen, e.g. ":1234".
	port  = flag.String("port", ":1234", "Port to bind to")
	// uid is the user ID the forked child runs as; a negative value makes
	// fork skip forking and serve in-process.
	uid   = flag.Int("uid", -1, "UID to run under")
)

// args is the command line (command followed by its arguments) that serve
// takes from flag.Args() and runs for every accepted connection.
var args []string

// FORKED is the argv[0] that fork gives the re-executed child; main
// recognises it and serves instead of forking again.
const FORKED = "[forked]"
func main() {
flag.Parse()
if flag.NArg() < 1 {
2013-07-29 12:08:57 +02:00
fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
fmt.Println("options:")
flag.PrintDefaults()
os.Exit(1)
}
2020-02-02 19:10:19 +01:00
if os.Args[0] == FORKED {
log.Fatal(serve())
} else {
log.Fatal(fork())
}
}
func fork() error {
if *uid < 0 {
fmt.Println("uid not set → not forking")
return serve()
}
// Collect information about the user
user, err := user.LookupId(strconv.Itoa(*uid))
if err != nil {
return err
}
2020-02-02 19:10:19 +01:00
// Get the primary and secondary groups
gids := []uint32{}
gid, err := strconv.Atoi(user.Gid)
if err != nil {
return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err)
}
2020-02-02 19:10:19 +01:00
gids = append(gids, uint32(gid))
if sgids, err := user.GroupIds(); err != nil {
log.Println("weird, couldn't get supplimentary groups: %v", err)
} else {
for _, s := range sgids {
v, e := strconv.Atoi(s)
if e != nil {
log.Println("weird, strconv.Atoi(%s): %v", s, e)
continue
} else if v != gid {
gids = append(gids, uint32(v))
}
}
}
2020-02-02 19:10:19 +01:00
cmd := exec.Command(os.Args[0])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
2020-02-02 19:10:19 +01:00
// Setup the arguments that are needed
cmd.Args = []string{FORKED} // name of the command
for _, arg := range []string{"cert", "key", "port"} {
fl := flag.Lookup(arg)
// Only set the flag if it's value are different from the default
if fl != nil && fl.Value.String() != fl.DefValue {
cmd.Args = append(cmd.Args, "-"+arg, fl.Value.String())
}
}
cmd.Args = append(cmd.Args, flag.Args()...)
// Set the uid/gid
cmd.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGTERM,
Credential: &syscall.Credential{
Uid: uint32(*uid),
Gid: gids[0],
Groups: gids[1:],
},
}
2020-02-02 19:10:19 +01:00
// Pass on signals that we receive
ch := make(chan os.Signal, 1)
signal.Notify(ch,
os.Interrupt,
syscall.SIGTERM,
syscall.SIGHUP,
)
go func() {
select {
case sig := <-ch:
log.Printf("received signal %s, killing child with pid %d\n", sig, cmd.Process.Pid)
cmd.Process.Kill()
os.Exit(0)
}
}()
2020-02-02 19:10:19 +01:00
// Energy!
return cmd.Run()
}
func serve() error {
args = flag.Args()
2013-07-29 12:08:57 +02:00
// setup certs etc. for TLS-socket
tconf := new(tls.Config)
cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
if err != nil {
fmt.Println("error with certs:", err)
os.Exit(2)
}
tconf.Certificates = append(tconf.Certificates, cert)
tconf.BuildNameToCertificate()
// start listening
2020-02-02 19:10:19 +01:00
sock, err := tls.Listen("tcp", *port, tconf)
2013-07-29 12:08:57 +02:00
if err != nil {
fmt.Println("error with tcp-socket:", err)
os.Exit(3)
}
defer sock.Close()
// accept-loop
for {
conn, err := sock.Accept()
if err != nil {
log.Println("error during Accept()", err)
continue
}
log.Println("Got connection:", conn.RemoteAddr())
go handleConnection(conn, args...)
2013-07-29 12:08:57 +02:00
}
}
// handleConnection runs the command args[0] with argv args for a single
// client connection and closes conn once the command has finished.
//
// The command's stdin and stdout are connected to conn and its stderr to this
// process's stderr. Its environment is replaced by a minimal set following
// tcp-environ(5): PATH (inherited), PROTO=TCP, and TCPLOCALIP/TCPLOCALPORT
// and TCPREMOTEIP/TCPREMOTEPORT taken from the connection's local and remote
// addresses. If args is empty or either address cannot be split into host
// and port, the problem is logged and no command is started.
func handleConnection(conn net.Conn, args ...string) {
	defer conn.Close()

	if len(args) == 0 {
		log.Println("no command to run for connection", conn.RemoteAddr())
		return
	}

	// Prepare the environment according to tcp-environ(5).
	lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		log.Println(err)
		return
	}
	// The remote side must come from RemoteAddr; reading LocalAddr here
	// exported the server's own address as TCPREMOTEIP/TCPREMOTEPORT.
	rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		log.Println(err)
		return
	}

	cmd := exec.Command(args[0])
	cmd.Args = args
	cmd.Stdin = conn
	cmd.Stdout = conn
	cmd.Stderr = os.Stderr
	cmd.Env = []string{
		"PATH=" + os.Getenv("PATH"),
		"PROTO=TCP",
		"TCPLOCALIP=" + lh,
		"TCPLOCALPORT=" + lp,
		"TCPREMOTEIP=" + rh,
		"TCPREMOTEPORT=" + rp,
	}

	// NOTE(review): per the os/exec docs, since Stdin is not an *os.File, Run
	// also waits for the goroutine copying conn into the command's stdin, so
	// it may not return until the peer sends more data or closes its side.
	if err := cmd.Run(); err != nil {
		log.Println("after Run: ", err)
	}
	log.Println("Done with connection", conn.RemoteAddr())
}
2016-01-02 14:49:54 +01:00
// setgid sets the group ID with a raw setgid(2) system call and returns the
// resulting errno as an error (nil on success).
//
// Since go1.4 syscall.Setgid was deliberately left unsupported because the
// kernel call only applies to the calling thread. This raw call has exactly
// that limitation too: only the current OS thread's credentials change.
// NOTE(review): Go 1.16 implemented syscall.Setgid on Linux for all threads;
// consider using it instead. setgid is not referenced elsewhere in this file.
func setgid(gid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, errno := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
	if errno != 0 {
		// syscall.Errno already implements error; returning it keeps the
		// value comparable (==, errors.Is) instead of flattening it to a
		// string through a non-constant format string.
		return errno
	}
	return nil
}
// setuid sets the user ID with a raw setuid(2) system call and returns the
// resulting errno as an error (nil on success).
//
// The raw call changes only the calling OS thread's credentials, not those
// of the whole process. NOTE(review): Go 1.16 implemented syscall.Setuid on
// Linux for all threads; consider using it instead. setuid is not referenced
// elsewhere in this file.
func setuid(uid int) error {
	// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
	_, _, errno := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
	if errno != 0 {
		// Return the Errno itself so callers can compare it, rather than a
		// string copy built from a non-constant format string.
		return errno
	}
	return nil
}