diff --git a/tlsserver.go b/tlsserver.go index 7bae358..64c5516 100644 --- a/tlsserver.go +++ b/tlsserver.go @@ -10,6 +10,7 @@ import ( "os/exec" "os/user" "strconv" + "sync" "syscall" ) @@ -17,6 +18,7 @@ var ( cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format") kfile = flag.String("key", "key.pem", "Key file in PEM format") port = flag.String("port", ":1234", "Port to bind to") + max = flag.Int("max", 5, "Maximum allowed failed attempts from IP") uid = flag.Int("uid", -1, "UID to run under") args []string ) @@ -60,12 +62,12 @@ func exc() error { } if sgids, err := user.GroupIds(); err != nil { - log.Println("weird, couldn't get supplimentary groups: %v", err) + log.Printf("weird, couldn't get supplimentary groups: %v", err) } else { for _, s := range sgids { v, e := strconv.Atoi(s) if e != nil { - log.Println("weird, strconv.Atoi(%s): %v", s, e) + log.Printf("weird, strconv.Atoi(%s): %v", s, e) continue } else if v != gid { gids = append(gids, v) @@ -136,40 +138,79 @@ func serve() error { } } +// Track the number of bad attempts. + +var mux sync.RWMutex +var badActors = map[string]int{} + +func isBadActor(host string) bool { + mux.RLock() + defer mux.RUnlock() + if n, ok := badActors[host]; ok && n >= 10 { + return true + } + return false +} + +func markBadActor(host string) { + mux.Lock() + defer mux.Unlock() + n, _ := badActors[host] + badActors[host] = n + 1 +} + +func markGoodActor(host string) { + mux.Lock() + defer mux.Unlock() + n, _ := badActors[host] + if n > 0 { + n -= 1 + } + badActors[host] = n +} + func handleConnection(conn net.Conn, args ...string) { defer conn.Close() + // prepare environment according to tcp-environ(5) + rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + log.Println("net.SplitHostPort(conn.RemoteAddr()):", err) + return + } + + if isBadActor(rh) { + log.Printf("Too many bad attempts from %s, dropping connection\n", rh) + return + } + + lh, lp, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + log.Println("net.SplitHostPort(conn.LocalAddr()):", err) + return + } + // setup cmd cmd := exec.Command(args[0]) cmd.Args = args cmd.Stdin = conn cmd.Stdout = conn cmd.Stderr = os.Stderr - - // prepare environment according to tcp-environ(5) - lh, lp, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - log.Println("net.SplitHostPort(conn.LocalAddr()):", err) - return + cmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "PROTO=TCP", + "TCPLOCALIP=" + lh, + "TCPLOCALPORT=" + lp, + "TCPREMOTEIP=" + rh, + "TCPREMOTEPORT=" + rp, } - rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String()) - if err != nil { - log.Println("net.SplitHostPort(conn.RemoteAddr()):", err) - return - } - - cmd.Env = make([]string, 0) - cmd.Env = append(cmd.Env, "PATH="+os.Getenv("PATH")) - cmd.Env = append(cmd.Env, "PROTO=TCP") - cmd.Env = append(cmd.Env, "TCPLOCALIP="+lh) - cmd.Env = append(cmd.Env, "TCPLOCALPORT="+lp) - cmd.Env = append(cmd.Env, "TCPREMOTEIP="+rh) - cmd.Env = append(cmd.Env, "TCPREMOTEPORT="+rp) - err = cmd.Run() if err != nil { - log.Println("after Run: ", err) + log.Println("Cmd error return: ", err) + markBadActor(rh) + } else { + markGoodActor(rh) } log.Println("Done with connection", conn.RemoteAddr()) }