package tlsauth_test

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	sr "tildegit.org/tjp/sliderule"
	"tildegit.org/tjp/sliderule/contrib/tlsauth"
	"tildegit.org/tjp/sliderule/gemini"
)

// TestIdentify checks that tlsauth.Identity, called from inside a handler,
// returns a certificate equal to the leaf of the certificate the client was
// configured to present (client1).
func TestIdentify(t *testing.T) {
	invoked := false

	// leafCert is captured by the handler below. It can only be parsed once
	// setup has loaded the client certificate, but it is assigned before the
	// server starts serving, so the handler always sees the parsed value.
	var leafCert *x509.Certificate
	server, client, clientCert := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
			invoked = true

			ident := tlsauth.Identity(request)
			if assert.NotNil(t, ident) {
				assert.True(t, ident.Equal(leafCert))
			}

			return nil
		}),
	)

	// Assign (rather than redeclare) so it is obvious this sets the variable
	// the handler closure captured.
	var err error
	leafCert, err = x509.ParseCertificate(clientCert.Certificate[0])
	require.NoError(t, err)

	go func() {
		// Serve's error is deliberately ignored: it is expected to return once
		// the deferred Close runs at the end of the test.
		_ = server.Serve()
	}()
	defer server.Close()

	requestPath(t, client, server, "/")
	assert.True(t, invoked)
}

// TestRequiredAuth exercises tlsauth.RequiredAuth with the tlsauth.Allow
// approver. A request carrying client1's certificate must reach the "/one"
// handler. A request from a client with no certificate must not reach the
// catch-all handler.
func TestRequiredAuth(t *testing.T) {
	var oneHit, twoHit bool

	one := sr.HandlerFunc(func(_ context.Context, _ *sr.Request) *sr.Response {
		oneHit = true
		return gemini.Success("", &bytes.Buffer{})
	})
	two := sr.HandlerFunc(func(_ context.Context, _ *sr.Request) *sr.Response {
		twoHit = true
		return gemini.Success("", &bytes.Buffer{})
	})

	requireCert := sr.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)
	onlyOnePaths := sr.Filter(func(_ context.Context, req *sr.Request) bool {
		return strings.HasPrefix(req.Path, "/one")
	}, nil)

	one = onlyOnePaths(requireCert(one))
	two = requireCert(two)

	server, client, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.FallthroughHandler(one, two),
	)

	go func() { _ = server.Serve() }()
	defer server.Close()

	// The client from setup presents client1's certificate.
	requestPath(t, client, server, "/one")
	assert.True(t, oneHit)

	// A fresh client that presents no certificate at all.
	anonymous, _ := clientFor(t, server, "", "")
	requestPath(t, anonymous, server, "/two")
	assert.False(t, twoHit)
}

// TestOptionalAuth exercises tlsauth.OptionalAuth with the tlsauth.Reject
// approver. A request carrying client1's certificate must not reach the
// "/one" handler. A request from a client with no certificate must reach the
// catch-all handler.
func TestOptionalAuth(t *testing.T) {
	var oneHit, twoHit bool

	one := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
		if strings.HasPrefix(request.Path, "/one") {
			oneHit = true
			return gemini.Success("", &bytes.Buffer{})
		}
		return nil
	})
	two := sr.HandlerFunc(func(_ context.Context, _ *sr.Request) *sr.Response {
		twoHit = true
		return gemini.Success("", &bytes.Buffer{})
	})

	optional := sr.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)

	server, client, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.FallthroughHandler(optional(one), optional(two)),
	)

	go func() { _ = server.Serve() }()
	defer server.Close()

	// Certificate-bearing request: the Reject approver must keep it out.
	requestPath(t, client, server, "/one")
	assert.False(t, oneHit)

	// Certificate-less request: must fall through to the second handler.
	anonymous, _ := clientFor(t, server, "", "")
	requestPath(t, anonymous, server, "/two")
	assert.True(t, twoHit)
}

// setup builds a Gemini server on 127.0.0.1 using the given server
// certificate/key pair and handler, and a client that presents the
// certificate at clientCertPath/clientKeyPath.
//
// It returns the server, which is not yet serving (callers start Serve
// themselves), the client, and the client certificate that was loaded. Any
// failure aborts the test.
func setup(
	t *testing.T,
	serverCertPath string,
	serverKeyPath string,
	clientCertPath string,
	clientKeyPath string,
	handler sr.Handler,
) (sr.Server, gemini.Client, tls.Certificate) {
	// Report failures at the calling test's line, not inside this helper.
	t.Helper()

	serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
	require.NoError(t, err)

	// Port 0 asks the OS for a free port. requestPath later reads the actual
	// address back via server.Address().
	server, err := gemini.NewServer(
		context.Background(),
		"localhost",
		"tcp",
		"127.0.0.1:0",
		handler,
		nil,
		serverTLS,
	)
	require.NoError(t, err)

	client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)

	return server, client, clientCert
}

// clientFor returns a Gemini client. When certPath is non-empty, the client
// presents the key pair loaded from certPath/keyPath, and that certificate is
// also returned. When certPath is empty, the client presents no certificate
// and the returned tls.Certificate is the zero value.
//
// The server parameter is currently unused; it is kept so existing call
// sites keep compiling.
func clientFor(
	t *testing.T,
	server sr.Server,
	certPath string,
	keyPath string,
) (gemini.Client, tls.Certificate) {
	// Report failures at the calling test's line, not inside this helper.
	t.Helper()

	var clientCert tls.Certificate
	var certs []tls.Certificate
	if certPath != "" {
		c, err := tls.LoadX509KeyPair(certPath, keyPath)
		require.NoError(t, err)

		clientCert = c
		certs = []tls.Certificate{c}
	}

	// Test-only: server certificate verification is disabled. Presumably the
	// testdata server certificate does not chain to a trusted root — confirm
	// before reusing this config anywhere else.
	return gemini.NewClient(&tls.Config{
		Certificates:       certs,
		InsecureSkipVerify: true,
	}), clientCert
}

// requestPath sends a request for gemini://<server address><path> through
// client and returns the response. A URL parse error or a round-trip error
// aborts the test.
func requestPath(t *testing.T, client gemini.Client, server sr.Server, path string) *sr.Response {
	// Report failures at the calling test's line, not inside this helper.
	t.Helper()

	u, err := url.Parse("gemini://" + server.Address() + path)
	require.NoError(t, err, "parsing URL for path %q", path)

	response, err := client.RoundTrip(&sr.Request{URL: u})
	require.NoError(t, err, "requesting %s", u)

	return response
}