package main

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"fmt"
	"log"
	"os"
	"strings"

	sr "tildegit.org/tjp/sliderule"
	"tildegit.org/tjp/sliderule/gemini"
	"tildegit.org/tjp/sliderule/logging"
)

// main starts a Gemini server that answers every request with a description
// of the TLS state of the connection it arrived on.
//
// The certificate and private key file paths come from the environment (see
// envConfig). Any failure while loading the TLS material or constructing the
// server terminates the process through log.Fatal.
func main() {
	// Paths of the TLS certificate and key files, read from the environment.
	certPath, keyPath := envConfig()

	// Turn the two files into a TLS configuration for the gemini server.
	tlsConfig, err := gemini.FileTLS(certPath, keyPath)
	if err != nil {
		log.Fatal(err)
	}

	// Only the second and fourth of the default loggers are used: one for
	// per-request logging, one handed to the server for error reporting.
	_, requestLog, _, failureLog := logging.DefaultLoggers()

	// Wrap the inspection handler so that requests are logged.
	withRequestLog := logging.LogRequests(requestLog)
	handler := withRequestLog(inspectHandler)

	srv, err := gemini.NewServer(
		context.Background(),
		"localhost",
		"tcp4",
		":1965",
		handler,
		failureLog,
		tlsConfig,
	)
	if err != nil {
		log.Fatal(err)
	}

	// NOTE(review): whatever Serve returns is discarded here — confirm
	// whether it reports an error that should be logged.
	srv.Serve()
}

// envConfig reads the TLS file locations from the environment.
//
// It returns the values of SERVER_CERTIFICATE and SERVER_PRIVATEKEY, in that
// order. If either variable is not set at all, the process is terminated via
// log.Fatal with a message naming the missing variable. A variable that is
// set to the empty string counts as present.
func envConfig() (string, string) {
	// required fetches one variable, aborting with the supplied message
	// when it is absent from the environment.
	required := func(name, missing string) string {
		value, found := os.LookupEnv(name)
		if !found {
			log.Fatal(missing)
		}
		return value
	}

	certfile := required("SERVER_CERTIFICATE", "missing SERVER_CERTIFICATE environment variable")
	keyfile := required("SERVER_PRIVATEKEY", "missing SERVER_PRIVATEKEY environment variable")

	return certfile, keyfile
}

// inspectHandler answers every request with a successful text/gemini response
// whose body is the displayTLSState description of the request's TLS state,
// enclosed between ``` fence lines so it is shown as preformatted text.
var inspectHandler = sr.HandlerFunc(func(_ context.Context, req *sr.Request) *sr.Response {
	var page bytes.Buffer

	page.WriteString("```\n")
	page.WriteString(displayTLSState(req.TLSState))
	page.WriteString("\n```")

	return gemini.Success("text/gemini", &page)
})

// displayTLSState renders a human-readable, multi-line summary of a TLS
// connection state.
//
// The summary lists the protocol version, whether the handshake completed,
// whether the session was resumed, the cipher suite name, the negotiated
// application protocol, the server name, and the number of peer certificates
// followed by one fingerprint line per certificate. Every line, including the
// last, ends with a newline.
//
// A nil state yields a single explanatory line instead of panicking.
func displayTLSState(state *tls.ConnectionState) string {
	if state == nil {
		return "No TLS connection state available\n"
	}

	builder := &strings.Builder{}

	// Writes to a strings.Builder never fail, so the results of Fprintf are
	// intentionally not checked.
	fmt.Fprintf(builder, "Version:             %s\n", tlsVersionLabel(state.Version))
	fmt.Fprintf(builder, "Handshake complete:  %t\n", state.HandshakeComplete)
	fmt.Fprintf(builder, "Did resume:          %t\n", state.DidResume)
	// CipherSuiteName falls back to a hexadecimal rendering of the ID for
	// suites it does not know, so this line is never blank.
	fmt.Fprintf(builder, "Cipher suite:        %s\n", tls.CipherSuiteName(state.CipherSuite))
	fmt.Fprintf(builder, "Negotiated protocol: %q\n", state.NegotiatedProtocol)
	fmt.Fprintf(builder, "Server name:         %s\n", state.ServerName)

	fmt.Fprintf(builder, "Certificates (%d)\n", len(state.PeerCertificates))
	for i, cert := range state.PeerCertificates {
		fmt.Fprintf(builder, "  #%d:                %s\n", i+1, fingerprint(cert))
	}

	return builder.String()
}

// tlsVersionLabel maps a TLS protocol version number to a short display name,
// reporting unrecognised values explicitly rather than as an empty string.
func tlsVersionLabel(version uint16) string {
	switch version {
	case tls.VersionTLS13:
		return "TLSv1.3"
	case tls.VersionTLS12:
		return "TLSv1.2"
	case tls.VersionTLS11:
		return "TLSv1.1"
	case tls.VersionTLS10:
		return "TLSv1.0"
	case 0x0300:
		// Numeric value of tls.VersionSSL30; written out because that
		// constant is deprecated in crypto/tls.
		return "SSLv3"
	default:
		return fmt.Sprintf("unknown (0x%04x)", version)
	}
}

// fingerprint returns the SHA-256 digest of the certificate's raw bytes
// (cert.Raw), encoded as lowercase hexadecimal text.
func fingerprint(cert *x509.Certificate) []byte {
	digest := sha256.Sum256(cert.Raw)
	return []byte(hex.EncodeToString(digest[:]))
}