// Command tlsserver accepts TLS connections and, for each one, runs the
// command given on its command line with the connection wired to the
// command's stdin/stdout and the peer described by tcp-environ(5) style
// environment variables. With -uid it drops privileges and re-executes itself
// before listening.
package main

import (
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/exec"
	"os/user"
	"runtime"
	"strconv"
	"sync"
	"syscall"
)

var (
	cfile       = flag.String("cert", "cert.pem", "Certificate file in PEM format")
	kfile       = flag.String("key", "key.pem", "Key file in PEM format")
	port        = flag.String("port", ":1234", "Port to bind to")
	maxFailures = flag.Int("max", 5, "Maximum allowed failed attempts from IP")
	uid         = flag.Int("uid", -1, "UID to run under")
)

// TLSSERVER is the argv[0] given to the re-executed server; main uses it to
// tell the unprivileged re-exec apart from the initial invocation.
const TLSSERVER = "[tlsserver]"

// forwardedFlags are the flags the re-executed server needs. -uid is
// deliberately absent: privileges have already been dropped by then.
var forwardedFlags = []string{"cert", "key", "port", "max"}

func main() {
	flag.Parse()
	if flag.NArg() < 1 {
		fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
		fmt.Println("options:")
		flag.PrintDefaults()
		os.Exit(1)
	}

	if os.Args[0] == TLSSERVER {
		log.Fatal(serve())
	}
	log.Fatal(exc())
}

// exc drops privileges to the user selected by -uid and re-executes this
// binary as TLSSERVER, so the listener and every spawned command run
// unprivileged. If -uid is negative it calls serve in this process instead.
//
// On success syscall.Exec does not return; a returned error means the lookup,
// the privilege drop, or the exec failed.
func exc() error {
	if *uid < 0 {
		return serve()
	}

	// Collect information about the user.
	u, err := user.LookupId(strconv.Itoa(*uid))
	if err != nil {
		return fmt.Errorf("looking up uid %d: %w", *uid, err)
	}
	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", u.Gid, err)
	}
	gids := supplementaryGroups(u, gid)

	// Resolve our own binary while still privileged. os.Args[0] is not
	// usable as an exec path when we were started through $PATH.
	self, err := os.Executable()
	if err != nil {
		self = os.Args[0]
	}

	// setgid/setuid below are raw syscalls that only change the calling OS
	// thread's credentials. Pin this goroutine to its thread so the
	// credential changes and the final execve all happen on that same
	// thread; execve then makes those credentials process-wide. The thread
	// is intentionally never unlocked: we either exec or exit.
	runtime.LockOSThread()

	if err := setgid(gid); err != nil {
		return err
	}
	// Always reset the group list, even to empty. Otherwise a user without
	// supplementary groups would inherit those of the (typically root)
	// parent process.
	if err := syscall.Setgroups(gids); err != nil {
		return fmt.Errorf("syscall.Setgroups: %w", err)
	}
	if err := setuid(*uid); err != nil {
		return err
	}

	return syscall.Exec(self, reexecArgs(), os.Environ())
}

// supplementaryGroups returns u's numeric supplementary group IDs, excluding
// the primary group gid. Failures are logged and the offending entries (or
// the whole list) skipped, so the caller ends up with fewer groups, never
// more.
func supplementaryGroups(u *user.User, gid int) []int {
	sgids, err := u.GroupIds()
	if err != nil {
		log.Printf("weird, couldn't get supplimentary groups: %v", err)
		return nil
	}
	gids := make([]int, 0, len(sgids))
	for _, s := range sgids {
		v, err := strconv.Atoi(s)
		if err != nil {
			log.Printf("weird, strconv.Atoi(%s): %v", s, err)
			continue
		}
		if v != gid {
			gids = append(gids, v)
		}
	}
	return gids
}

// reexecArgs builds the argv for the re-executed server, in this order:
//   - TLSSERVER as argv[0];
//   - each forwarded flag whose value differs from its default;
//   - a "--" terminator;
//   - the command and its arguments.
func reexecArgs() []string {
	args := []string{TLSSERVER}
	for _, name := range forwardedFlags {
		fl := flag.Lookup(name)
		// Only pass flags whose value differs from the default.
		if fl != nil && fl.Value.String() != fl.DefValue {
			args = append(args, "-"+name, fl.Value.String())
		}
	}
	// "--" stops flag parsing in the child, so a command or argument that
	// starts with "-" is not mistaken for one of our own flags.
	args = append(args, "--")
	return append(args, flag.Args()...)
}

// serve listens for TLS connections on -port using the -cert/-key pair and
// handles each accepted connection in its own goroutine. Certificate or
// listen failures terminate the process with exit status 2 or 3.
func serve() error {
	cmdArgs := flag.Args()

	cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
	if err != nil {
		fmt.Println("error with certs:", err)
		os.Exit(2)
	}
	// With a single certificate, the deprecated BuildNameToCertificate adds
	// nothing: that certificate is always the one served.
	tconf := &tls.Config{Certificates: []tls.Certificate{cert}}

	sock, err := tls.Listen("tcp", *port, tconf)
	if err != nil {
		fmt.Println("error with tcp-socket:", err)
		os.Exit(3)
	}
	defer sock.Close()

	for {
		conn, err := sock.Accept()
		if err != nil {
			log.Println("error during Accept()", err)
			continue
		}
		log.Println("Got connection:", conn.RemoteAddr())
		go handleConnection(conn, cmdArgs...)
	}
}

// Failure bookkeeping. badActors maps a remote host to its current count of
// failed command runs and is guarded by mux. Hosts whose count drops back to
// zero are removed, so the map holds only hosts with outstanding failures.
var (
	mux       sync.RWMutex
	badActors = map[string]int{}
)

// isBadActor reports whether host has accumulated at least -max failures.
func isBadActor(host string) bool {
	mux.RLock()
	defer mux.RUnlock()
	n, ok := badActors[host]
	return ok && n >= *maxFailures
}

// markBadActor records one more failure for host.
func markBadActor(host string) {
	mux.Lock()
	defer mux.Unlock()
	badActors[host]++
}

// markGoodActor forgives one failure for host, forgetting it entirely once
// its count reaches zero.
func markGoodActor(host string) {
	mux.Lock()
	defer mux.Unlock()
	if badActors[host] > 1 {
		badActors[host]--
		return
	}
	delete(badActors, host)
}

// handleConnection runs the command args[0] (with argv args) for a single
// accepted connection.
//
// I/O and environment:
//   - the connection becomes the command's stdin and stdout;
//   - stderr goes to the server's stderr;
//   - the environment holds PATH plus the tcp-environ(5) variables
//     PROTO, TCPLOCALIP, TCPLOCALPORT, TCPREMOTEIP and TCPREMOTEPORT.
//
// Failure accounting:
//   - hosts that have reached -max failures are dropped before anything runs;
//   - an error from cmd.Run counts as a failure for the remote host;
//   - a clean run forgives one failure.
func handleConnection(conn net.Conn, args ...string) {
	defer conn.Close()

	if len(args) == 0 {
		log.Println("no command to run, dropping connection", conn.RemoteAddr())
		return
	}

	// Prepare the environment according to tcp-environ(5).
	rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		log.Println("net.SplitHostPort(conn.RemoteAddr()):", err)
		return
	}
	if isBadActor(rh) {
		log.Printf("Too many bad attempts from %s, dropping connection\n", rh)
		return
	}
	lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		log.Println("net.SplitHostPort(conn.LocalAddr()):", err)
		return
	}

	cmd := exec.Command(args[0])
	cmd.Args = args
	cmd.Stdin = conn
	cmd.Stdout = conn
	cmd.Stderr = os.Stderr
	cmd.Env = []string{
		"PATH=" + os.Getenv("PATH"),
		"PROTO=TCP",
		"TCPLOCALIP=" + lh,
		"TCPLOCALPORT=" + lp,
		"TCPREMOTEIP=" + rh,
		"TCPREMOTEPORT=" + rp,
	}

	if err := cmd.Run(); err != nil {
		log.Println("Cmd error return: ", err)
		markBadActor(rh)
	} else {
		markGoodActor(rh)
	}
	log.Println("Done with connection", conn.RemoteAddr())
}

// setgid sets the group ID with a raw syscall. That affects only the calling
// OS thread, which is why exc pins its goroutine with runtime.LockOSThread
// before calling this. Older Go releases deliberately refused syscall.Setgid
// on Linux for exactly this per-thread reason.
// NOTE(review): newer Go releases implement syscall.Setgid/Setuid for all
// threads on Linux; consider switching once the minimum Go version allows.
// NOTE(review): on linux/386 and linux/arm SYS_SETGID may be the legacy
// 16-bit call (SYS_SETGID32 exists there) — confirm if those targets matter.
func setgid(gid int) error {
	_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
	if e != 0 {
		return fmt.Errorf("setgid: %w", e)
	}
	return nil
}

// setuid sets the user ID with a raw, per-thread syscall; see setgid for why
// the caller must have locked its OS thread.
func setuid(uid int) error {
	_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
	if e != 0 {
		return fmt.Errorf("setuid: %w", e)
	}
	return nil
}