package tlsauth_test

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	sr "tildegit.org/tjp/sliderule"
	"tildegit.org/tjp/sliderule/contrib/tlsauth"
	"tildegit.org/tjp/sliderule/gemini"
)

// TestIdentify checks that tlsauth.Identity, called from inside a request
// handler, returns a certificate equal to the one the client was configured
// with.
func TestIdentify(t *testing.T) {
	invoked := false

	// leafCert is declared before the handler so the closure can capture it.
	// It is filled in below, before the server goroutine is started, so the
	// handler never observes the nil placeholder.
	var leafCert *x509.Certificate
	server, client, clientCert := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
			invoked = true

			ident := tlsauth.Identity(request)
			if assert.NotNil(t, ident) {
				assert.True(t, ident.Equal(leafCert))
			}

			// The response content is irrelevant here; the test only checks
			// that the handler ran and saw the expected identity.
			return nil
		}),
	)

	// Assign with "=" (not ":=") so it is unmistakable that the variable
	// captured by the handler is the one being populated.
	var err error
	leafCert, err = x509.ParseCertificate(clientCert.Certificate[0])
	require.NoError(t, err)

	go func() {
		// Serve's result is intentionally ignored: the outcome of this test
		// is decided by the assertions on the request made below.
		_ = server.Serve()
	}()
	defer server.Close()

	requestPath(t, client, server, "/")
	assert.True(t, invoked)
}

// setup builds the fixtures shared by the tests in this file: a gemini server
// using the TLS material at serverCertPath/serverKeyPath and dispatching to
// handler, plus a client created by clientFor from clientCertPath and
// clientKeyPath.
//
// It returns the server (not yet serving; the caller starts and closes it),
// the client, and the client's loaded certificate. Any failure aborts the
// calling test immediately.
func setup(
	t *testing.T,
	serverCertPath string,
	serverKeyPath string,
	clientCertPath string,
	clientKeyPath string,
	handler sr.Handler,
) (sr.Server, gemini.Client, tls.Certificate) {
	// Attribute failures raised here to the calling test's line.
	t.Helper()

	serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
	require.NoError(t, err)

	// "127.0.0.1:0" presumably asks the OS for a free loopback port, which
	// would keep runs from colliding on a fixed port (confirm against
	// gemini.NewServer). requestPath later reads the actual address back
	// from the server.
	server, err := gemini.NewServer(
		context.Background(),
		"localhost",
		"tcp",
		"127.0.0.1:0",
		handler,
		nil,
		serverTLS,
	)
	require.NoError(t, err)

	client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)

	return server, client, clientCert
}

// clientFor creates a gemini client for use against a test server.
//
// When certPath is non-empty, the key pair at certPath/keyPath is loaded and
// set as the client's certificate; a load failure aborts the calling test.
// When certPath is empty, the client presents no certificate and the returned
// tls.Certificate is the zero value.
//
// The server parameter is currently not used by the body; it is kept so the
// signature stays stable for existing callers.
func clientFor(
	t *testing.T,
	server sr.Server,
	certPath string,
	keyPath string,
) (gemini.Client, tls.Certificate) {
	// Attribute failures raised here to the calling test's line.
	t.Helper()

	var clientCert tls.Certificate
	var certs []tls.Certificate
	if certPath != "" {
		c, err := tls.LoadX509KeyPair(certPath, keyPath)
		require.NoError(t, err)

		clientCert = c
		certs = []tls.Certificate{c}
	}

	return gemini.NewClient(&tls.Config{
		Certificates: certs,
		// Server certificate verification is skipped; acceptable only
		// because this client is used exclusively in tests.
		InsecureSkipVerify: true,
	}), clientCert
}

// requestPath issues a gemini request for path against server using client
// and returns the response. The URL is built from the address the server
// reports, so path should begin with "/". A URL parse error or a round-trip
// error aborts the calling test.
func requestPath(t *testing.T, client gemini.Client, server sr.Server, path string) *sr.Response {
	// Attribute failures raised here to the calling test's line.
	t.Helper()

	u, err := url.Parse("gemini://" + server.Address() + path)
	require.NoError(t, err)

	response, err := client.RoundTrip(&sr.Request{URL: u})
	require.NoError(t, err)

	return response
}