Compare commits

..

No commits in common. "a075c84a40d56ba3e830a48574ace24e6d9d06e5" and "975b3a8eeef4e511baabe7e30b15ab28f552d160" have entirely different histories.

2 changed files with 15 additions and 134 deletions

3
go.mod
View File

@ -1,3 +0,0 @@
module tlsserver
go 1.21.6

View File

@ -2,109 +2,34 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"flag"
"fmt"
"log" "log"
"net" "net"
"flag"
"fmt"
"os" "os"
"os/exec" "os/exec"
"os/user"
"strconv"
"syscall"
) )
var ( var (
cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format") cfile = flag.String("cert", "cert.pem", "Certificate file in PEM format")
kfile = flag.String("key", "key.pem", "Key file in PEM format") kfile = flag.String("key", "key.pem", "Key file in PEM format")
port = flag.String("port", ":1234", "Port to bind to") port = flag.Int("port", 1234, "Port to bind to")
uid = flag.Int("uid", -1, "UID to run under") args []string
args []string nargs int
) )
const TLSSERVER = "[tlsserver]"
func main() { func main() {
flag.Parse()
if flag.NArg() < 1 { flag.Parse()
args = flag.Args()
nargs = flag.NArg()
if nargs < 1 {
fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]") fmt.Println("Usage: tlsserver [options] cmd [flags for cmd]")
fmt.Println("options:") fmt.Println("options:")
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(1) os.Exit(1)
} }
if os.Args[0] == TLSSERVER {
log.Fatal(serve())
} else {
log.Fatal(exc())
}
}
func exc() error {
if *uid < 0 {
return serve()
}
// Collect information about the user
user, err := user.LookupId(strconv.Itoa(*uid))
if err != nil {
return err
}
// Get the primary and secondary groups
gids := []int{}
gid, err := strconv.Atoi(user.Gid)
if err != nil {
return fmt.Errorf("couldn't parse user.Gid, strconv.Atoi(%s): %v", user.Gid, err)
}
if sgids, err := user.GroupIds(); err != nil {
log.Println("weird, couldn't get supplimentary groups: %v", err)
} else {
for _, s := range sgids {
v, e := strconv.Atoi(s)
if e != nil {
log.Println("weird, strconv.Atoi(%s): %v", s, e)
continue
} else if v != gid {
gids = append(gids, v)
}
}
}
if err := setgid(gid); err != nil {
return err
}
if len(gids) > 0 {
if err = syscall.Setgroups(gids); err != nil {
return fmt.Errorf("syscall.Setgroups: %w", err)
}
}
if err = setuid(*uid); err != nil {
return err
}
// Setup the arguments that are needed
args := []string{TLSSERVER} // name of the command
for _, arg := range []string{"cert", "key", "port"} {
fl := flag.Lookup(arg)
// Only set the flag if it's value are different from the default
if fl != nil && fl.Value.String() != fl.DefValue {
args = append(args, "-"+arg, fl.Value.String())
}
}
args = append(args, flag.Args()...)
return syscall.Exec(os.Args[0], args, os.Environ())
}
func serve() error {
args = flag.Args()
// setup certs etc. for TLS-socket // setup certs etc. for TLS-socket
tconf := new(tls.Config) tconf := new(tls.Config)
cert, err := tls.LoadX509KeyPair(*cfile, *kfile) cert, err := tls.LoadX509KeyPair(*cfile, *kfile)
@ -117,7 +42,8 @@ func serve() error {
tconf.BuildNameToCertificate() tconf.BuildNameToCertificate()
// start listening // start listening
sock, err := tls.Listen("tcp", *port, tconf) sport := fmt.Sprintf(":%d", *port)
sock, err := tls.Listen("tcp", sport , tconf)
if err != nil { if err != nil {
fmt.Println("error with tcp-socket:", err) fmt.Println("error with tcp-socket:", err)
os.Exit(3) os.Exit(3)
@ -132,11 +58,11 @@ func serve() error {
continue continue
} }
log.Println("Got connection:", conn.RemoteAddr()) log.Println("Got connection:", conn.RemoteAddr())
go handleConnection(conn, args...) go handleConnection(conn)
} }
} }
func handleConnection(conn net.Conn, args ...string) { func handleConnection(conn net.Conn) {
defer conn.Close() defer conn.Close()
// setup cmd // setup cmd
@ -145,51 +71,9 @@ func handleConnection(conn net.Conn, args ...string) {
cmd.Stdin = conn cmd.Stdin = conn
cmd.Stdout = conn cmd.Stdout = conn
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
err := cmd.Run()
// prepare environment according to tcp-environ(5)
lh, lp, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil { if err != nil {
log.Println("net.SplitHostPort(conn.LocalAddr()):", err) log.Println(err)
return
}
rh, rp, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
log.Println("net.SplitHostPort(conn.RemoteAddr()):", err)
return
}
cmd.Env = make([]string, 0)
cmd.Env = append(cmd.Env, "PATH="+os.Getenv("PATH"))
cmd.Env = append(cmd.Env, "PROTO=TCP")
cmd.Env = append(cmd.Env, "TCPLOCALIP="+lh)
cmd.Env = append(cmd.Env, "TCPLOCALPORT="+lp)
cmd.Env = append(cmd.Env, "TCPREMOTEIP="+rh)
cmd.Env = append(cmd.Env, "TCPREMOTEPORT="+rp)
err = cmd.Run()
if err != nil {
log.Println("after Run: ", err)
} }
log.Println("Done with connection", conn.RemoteAddr()) log.Println("Done with connection", conn.RemoteAddr())
} }
// Since go1.4 the setgid syscall is deliberatelly not supported anymore, as it
// only applies to the calling thread. So we try this here:
func setgid(gid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(gid), 0, 0)
if e != 0 {
return fmt.Errorf("setgid: %w", e)
}
return nil
}
func setuid(uid int) error {
// RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
_, _, e := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(uid), 0, 0)
if e != 0 {
return fmt.Errorf("setuid: %w", e)
}
return nil
}