package main

import (
	"bufio"
	"debug/elf"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/optimyze-interviews/OezguerKesim/GetRuntimeAddresses/ebpf"
)

// main creates the eBPF map, walks /proc looking for processes (other than
// this one) that map a file matching "*python3*" whose dynamic symbol table
// defines _PyRuntime, stores pid -> computed address of _PyRuntime in the
// map, and finally prints the map contents.
func main() {
	mapFD, err := ebpf.CreateMap()
	if err != nil {
		fmt.Printf("Failed to create eBPF map: %s\n", err)
		os.Exit(1)
	}

	fmt.Printf("Created eBPF map (FD: %d)\n", mapFD)

	proc, err := os.Open("/proc")
	if err != nil {
		fmt.Printf("Failed to open /proc: %v\n", err)
		os.Exit(1)
	}

	infos, err := proc.Readdir(-1)
	proc.Close() // all entries have been read; the directory handle is no longer needed
	if err != nil {
		fmt.Printf("Failed to read /proc: %v\n", err)
		os.Exit(1)
	}

	ownPID := os.Getpid()

	for _, pinfo := range infos {
		// The entry /proc/NNN/ must be a directory with an integer name.
		if !pinfo.IsDir() {
			continue
		}
		pidStr := pinfo.Name()
		pid, err := strconv.Atoi(pidStr)
		if err != nil || pid == ownPID { // not a process directory, or our own pid
			continue
		}

		addr, found := searchSymbolIn(pidStr, "*python3*", "_PyRuntime")
		if !found {
			continue
		}
		// NOTE(review): whatever Add returns (if anything) is discarded, as
		// before; confirm whether the ebpf package reports errors here.
		mapFD.Add(pid, addr)
	}

	mapContents, err := mapFD.GetMap()
	if err != nil {
		fmt.Printf("Failed to get the map contents: %s\n", err)
		os.Exit(1)
	}

	fmt.Printf("Printing contents of map %d\n", mapFD)
	for k, v := range mapContents {
		fmt.Printf("\t%d -> 0x%x\n", k, v)
	}
	os.Exit(0)
}

// region holds the start and end addresses of one mapping, as printed in
// the first column of /proc/<pid>/maps.
type region struct {
	start uint64 // could be uintptr
	end   uint64
}

// parseRegion parses an address range of the form "<start>-<end>", where
// both addresses are hexadecimal without a 0x prefix, e.g.
// "7fdd8fece000-7fdd8ff74000".
//
// On error the returned region is always the zero value (never partially
// filled), and the error wraps the underlying strconv error when an address
// fails to parse.
func parseRegion(in string) (region, error) {
	parts := strings.Split(in, "-")
	if len(parts) != 2 {
		return region{}, fmt.Errorf("[parseRegion] unrecognized format for region: %#q", in)
	}

	start, err := strconv.ParseUint(parts[0], 16, 64)
	if err != nil {
		return region{}, fmt.Errorf("[parseRegion] couldn't parse start-address %#q in %#q: %w", parts[0], in, err)
	}

	end, err := strconv.ParseUint(parts[1], 16, 64)
	if err != nil {
		return region{}, fmt.Errorf("[parseRegion] couldn't parse end-address %#q in %#q: %w", parts[1], in, err)
	}

	return region{start: start, end: end}, nil
}

// searchSymbolIn scans /proc/<pid>/maps for a mapping that
//   - is backed by a file path (the pathname column starts with '/'),
//   - has a base name matching glob (filepath.Match syntax),
//   - has permissions exactly "rw-p", and
//   - whose file defines symbol in its ELF dynamic symbol table.
//
// For the first such mapping it returns the computed address of symbol in
// the process and true. It returns 0 and false when no mapping qualifies or
// the maps file cannot be opened (e.g. the process is gone or access is
// denied).
func searchSymbolIn(pid, glob, symbol string) (offset uint64, ok bool) {
	// read the maps file for the binary and shared libraries
	path := filepath.Join("/proc", pid, "maps")
	maps, err := os.Open(path)
	if err != nil {
		return 0, false
	}
	defer maps.Close()

	scanner := bufio.NewScanner(maps)
	for scanner.Scan() {
		// address                   perms offset  dev   inode       pathname
		// 7fdd8fece000-7fdd8ff74000 rw-p 00423000 fd:01 14156759    /usr/lib/x86_64-linux-gnu/libpython3.7m.so.1.0
		fields := strings.Fields(scanner.Text())

		// TODO: we assume that the pathname contains no spaces so
		// strings.Fields splits the line exactly into six fields.
		if len(fields) != 6 {
			continue
		}
		addrField, perms, offField, pathname := fields[0], fields[1], fields[2], fields[5]

		if !strings.HasPrefix(pathname, "/") { // not a file path
			continue
		}

		matched, err := filepath.Match(glob, filepath.Base(pathname))
		if err != nil || !matched {
			continue
		}

		if perms != "rw-p" { // symbol needs to be writable
			continue
		}

		sym, section, err := findSymbol(symbol, pathname)
		if err != nil || section == nil || sym == nil {
			// TODO: error handling; unreadable or non-matching files are skipped.
			continue
		}

		arange, err := parseRegion(addrField)
		if err != nil {
			fmt.Printf("%v\n", err)
			continue
		}

		fileoffset, err := strconv.ParseUint(offField, 16, 64)
		if err != nil {
			fmt.Printf("Error while parsing fileoffset %#q: %v\n", offField, err)
			continue
		}

		// Position of the symbol in the file: its offset inside its section
		// plus the section's (alignment-adjusted) file offset.
		memoff := sym.Value - section.Addr + alignedOffset(section)

		// The mapping starts at file offset `fileoffset`, which is placed at
		// arange.start. Stop at the first match.
		return arange.start + memoff - fileoffset, true
	}

	// A read error from the scanner is treated the same as "not found".
	return 0, false
}

// findSymbol opens the ELF file at pathname and looks up symbol in its
// dynamic symbol table. It returns the symbol and the header of the section
// that defines it.
//
// It returns (nil, nil, nil) when the symbol is absent, undefined in this
// file (SHN_UNDEF, i.e. imported from elsewhere), or refers to a reserved
// or out-of-range section index. It returns a non-nil error when the file
// cannot be opened as ELF or its dynamic symbols cannot be read.
func findSymbol(symbol string, pathname string) (*elf.Symbol, *elf.SectionHeader, error) {
	// TODO: caching

	file, err := elf.Open(pathname)
	if err != nil {
		return nil, nil, err
	}
	defer file.Close()

	symbols, err := file.DynamicSymbols()
	if err != nil {
		return nil, nil, err
	}

	idx := -1
	for i := range symbols {
		if symbols[i].Name == symbol {
			idx = i
			break
		}
	}
	if idx < 0 {
		return nil, nil, nil
	}
	sym := &symbols[idx]

	// An undefined symbol lives in another object, and indices at or above
	// SHN_LORESERVE (SHN_ABS, SHN_COMMON, ...) do not name a real section.
	if sym.Section == elf.SHN_UNDEF || sym.Section >= elf.SHN_LORESERVE {
		return nil, nil, nil
	}
	if int(sym.Section) >= len(file.Sections) {
		return nil, nil, nil
	}

	section := file.Sections[sym.Section]
	if section == nil {
		return nil, nil, nil
	}

	// The header is a plain value inside the parsed Section, so it stays
	// valid after the file is closed.
	return sym, &section.SectionHeader, nil
}

// alignedOffset returns section.Offset rounded up to a multiple of
// section.Addralign.
//
// In ELF, an Addralign of 0 or 1 means "no alignment constraint", so Offset
// is returned unchanged in that case. Without this guard, Addralign == 0
// would make the mask all-ones and the result 0. Other values are assumed
// to be powers of two, as the ELF specification requires.
//
// NOTE(review): presumably the rounding exists because some sections (e.g.
// NOBITS ones like .bss) may carry an Offset that is not aligned like their
// Addr. Confirm against the binaries being targeted.
func alignedOffset(section *elf.SectionHeader) uint64 {
	align := section.Addralign
	if align <= 1 {
		return section.Offset
	}
	mask := align - 1
	return (section.Offset + mask) &^ mask
}