package cgi

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"io"
	"io/fs"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	sr "tildegit.org/tjp/sliderule"
)

// ResolveCGI finds a CGI program corresponding to a request path.
//
// It returns the path to the executable file and the PATH_INFO that should be passed,
// or an error. When no executable is found it returns two empty strings and a nil
// error.
//
// It will find executables which are just part way through the path, so for example
// a request for /foo/bar/baz can run an executable found at /foo or /foo/bar. In such
// a case the PATH_INFO is the remaining portion of the URI path with a leading slash
// ("/bar/baz" or "/baz" respectively); when nothing follows the executable, PATH_INFO
// is "/".
//
// Only files that executableFile reports as executable are considered. The search
// stops at the first path component that is neither a directory nor an executable,
// and it never steps through a ".." segment, so it cannot walk out of fsRoot.
func ResolveCGI(requestPath, fsRoot string) (string, string, error) {
	fsRoot = strings.TrimRight(fsRoot, "/")
	segments := strings.Split(strings.TrimLeft(requestPath, "/"), "/")

	// i is the number of request segments appended to fsRoot. It runs through
	// len(segments) inclusive so the full request path is checked as well.
	for i := 0; i <= len(segments); i++ {
		// Refuse to step through a parent-directory segment; it could reach
		// files outside fsRoot.
		if i > 0 && segments[i-1] == ".." {
			break
		}

		candidate := strings.Join(append([]string{fsRoot}, segments[:i]...), "/")
		isDir, isExecutable, err := executableFile(candidate)
		if err != nil {
			return "", "", err
		}

		if isExecutable {
			// Whatever follows the executable becomes PATH_INFO, which RFC 3875
			// requires to begin with "/". No remaining segments, or only the
			// empty segment left by a trailing slash, both yield just "/".
			pathInfo := "/" + strings.Join(segments[i:], "/")
			return candidate, pathInfo, nil
		}

		if !isDir {
			break
		}
	}

	return "", "", nil
}

// executableFile stats filepath and reports (isDir, isExecutable, err).
//
// isDir is true when the path is a directory. isExecutable is true only for a
// non-directory whose "other" read and "other" execute permission bits are
// both set. A path that does not exist (or is invalid) yields
// (false, false, nil); any other stat failure is returned as err.
func executableFile(filepath string) (bool, bool, error) {
	info, statErr := os.Stat(filepath)
	switch {
	case isNotExistError(statErr):
		return false, false, nil
	case statErr != nil:
		return false, false, statErr
	case info.IsDir():
		return true, false, nil
	}

	// 0o004 = readable by others, 0o001 = executable by others; both required.
	const worldReadExec = 0o004 | 0o001
	return false, info.Mode()&worldReadExec == worldReadExec, nil
}

// isNotExistError reports whether err carries a *fs.PathError (directly or
// wrapped) whose underlying cause is fs.ErrNotExist or fs.ErrInvalid.
// A nil error, or one without a *fs.PathError in its chain, reports false.
func isNotExistError(err error) bool {
	var pathErr *fs.PathError
	if err == nil || !errors.As(err, &pathErr) {
		return false
	}

	cause := pathErr.Err
	return errors.Is(cause, fs.ErrNotExist) || errors.Is(cause, fs.ErrInvalid)
}

// RunCGI runs a specific program as a CGI script.
//
// executable is resolved to an absolute path and run with workdir as its
// working directory and an environment built from request, the derived
// SCRIPT_NAME, and pathInfo. pathInfo is the PATH_INFO value for the script;
// "/" means there is no extra path beyond the script itself. If request.Meta
// implements io.Reader it is used as the script's standard input. The script's
// standard output is captured in the returned buffer and its standard error is
// written to stderr.
//
// It returns the captured output and the process exit code. An unsuccessful
// exit is not treated as an error: its code is returned with a nil error. A
// non-nil error means the program could not be started or run to completion;
// in that case the exit code is -1 if the process never started.
func RunCGI(
	ctx context.Context,
	request *sr.Request,
	executable string,
	pathInfo string,
	workdir string,
	stderr io.Writer,
) (*bytes.Buffer, int, error) {
	// SCRIPT_NAME is the request path minus its PATH_INFO suffix. TrimSuffix
	// (rather than slicing by length) cannot panic if the request path does not
	// actually end with pathInfo.
	scriptName := request.Path
	if pathInfo != "/" {
		scriptName = strings.TrimSuffix(scriptName, pathInfo)
	}
	scriptName = strings.TrimSuffix(scriptName, "/")

	execpath, err := filepath.Abs(executable)
	if err != nil {
		return nil, 0, err
	}

	cmd := exec.CommandContext(ctx, execpath)
	cmd.Env = prepareCGIEnv(ctx, request, scriptName, pathInfo)
	cmd.Dir = workdir

	if body, ok := request.Meta.(io.Reader); ok {
		cmd.Stdin = body
	}
	responseBuffer := &bytes.Buffer{}
	cmd.Stdout = responseBuffer
	cmd.Stderr = stderr

	runErr := cmd.Run()

	var exitErr *exec.ExitError
	if errors.As(runErr, &exitErr) {
		// The program ran but exited unsuccessfully (or was killed); report
		// that through the exit code rather than as an error.
		return responseBuffer, exitErr.ExitCode(), nil
	}

	// Any other failure (e.g. the program could not be started) is returned to
	// the caller. ProcessState is nil if the process never started, and its
	// ExitCode method reports -1 in that case.
	return responseBuffer, cmd.ProcessState.ExitCode(), runErr
}

// prepareCGIEnv builds the environment for a CGI child process.
//
// It emits CGI/1.1 meta-variables derived from the request, leaving several
// (CONTENT_LENGTH, CONTENT_TYPE, PATH_TRANSLATED, REMOTE_HOST, REMOTE_IDENT)
// always empty. When the client presented a TLS certificate, AUTH_TYPE is
// "Certificate" and TLS_CLIENT_* variables describe the first peer
// certificate. ctx is currently unused.
func prepareCGIEnv(
	ctx context.Context,
	request *sr.Request,
	scriptName string,
	pathInfo string,
) []string {
	tlsState := request.TLSState
	hasClientCert := tlsState != nil && len(tlsState.PeerCertificates) > 0

	authType := ""
	if hasClientCert {
		authType = "Certificate"
	}

	// The split error is deliberately ignored: if the remote address is not in
	// host:port form, REMOTE_ADDR and REMOTE_PORT are simply left empty.
	remoteHost, remotePort, _ := net.SplitHostPort(request.RemoteAddr.String())

	env := []string{
		"AUTH_TYPE=" + authType,
		"CONTENT_LENGTH=",
		"CONTENT_TYPE=",
		"GATEWAY_INTERFACE=CGI/1.1",
		"PATH_INFO=" + pathInfo,
		"PATH_TRANSLATED=",
		"QUERY_STRING=" + request.RawQuery,
		"REMOTE_ADDR=" + remoteHost,
		"REMOTE_PORT=" + remotePort,
		"REMOTE_HOST=",
		"REMOTE_IDENT=",
		"SCRIPT_NAME=" + scriptName,
		"SERVER_NAME=" + request.Server.Hostname(),
		"SERVER_PORT=" + request.Server.Port(),
		"SERVER_PROTOCOL=" + request.Server.Protocol(),
		"SERVER_SOFTWARE=SLIDERULE",
	}

	if !hasClientCert {
		return env
	}

	clientCert := tlsState.PeerCertificates[0]
	return append(env,
		"TLS_CLIENT_HASH="+fingerprint(clientCert.Raw),
		"TLS_CLIENT_CERT="+hex.EncodeToString(clientCert.Raw),
		"TLS_CLIENT_ISSUER="+clientCert.Issuer.String(),
		"TLS_CLIENT_ISSUER_CN="+clientCert.Issuer.CommonName,
		"TLS_CLIENT_SUBJECT="+clientCert.Subject.String(),
		"TLS_CLIENT_SUBJECT_CN="+clientCert.Subject.CommonName,
	)
}

// fingerprint returns the lowercase hexadecimal SHA-256 digest of raw.
func fingerprint(raw []byte) string {
	digest := sha256.New()
	digest.Write(raw) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}