package tlsauth_test

import (
	"bytes"
	"context"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"

	sr "tildegit.org/tjp/sliderule"
	"tildegit.org/tjp/sliderule/contrib/tlsauth"
	"tildegit.org/tjp/sliderule/gemini"
)

// TestGeminiAuth exercises tlsauth.GeminiAuth wrapped around each handler in a
// fallthrough chain, checking the response status for authenticated and
// certificate-less clients.
//
// The first three inner handlers only answer for their own path prefix and
// return nil otherwise; the fourth answers every path.
func TestGeminiAuth(t *testing.T) {
	// pathHandler answers with an empty success response for requests whose
	// path starts with prefix, and returns nil for any other path.
	pathHandler := func(prefix string) sr.Handler {
		return sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
			if !strings.HasPrefix(request.Path, prefix) {
				return nil
			}
			return gemini.Success("", &bytes.Buffer{})
		})
	}

	// catchAll answers every request with an empty success response.
	catchAll := sr.HandlerFunc(func(_ context.Context, _ *sr.Request) *sr.Response {
		return gemini.Success("", &bytes.Buffer{})
	})

	handler := sr.FallthroughHandler(
		tlsauth.GeminiAuth(tlsauth.Allow)(pathHandler("/one")),
		tlsauth.GeminiAuth(tlsauth.Allow)(pathHandler("/two")),
		tlsauth.GeminiAuth(tlsauth.Reject)(pathHandler("/three")),
		tlsauth.GeminiAuth(tlsauth.Reject)(catchAll),
	)

	server, authClient, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		handler,
	)

	// A client built with empty cert/key paths, used to send requests that
	// carry no client certificate.
	authlessClient, _ := clientFor(t, server, "", "")

	// NOTE(review): Serve's error is discarded; presumably it only reports the
	// shutdown caused by the deferred Close below — confirm.
	go func() {
		_ = server.Serve()
	}()
	defer server.Close()

	resp := requestPath(t, authClient, server, "/one")
	assert.Equal(t, gemini.StatusSuccess, resp.Status,
		"authenticated request to Allow-wrapped /one")

	resp = requestPath(t, authlessClient, server, "/two")
	assert.Equal(t, gemini.StatusClientCertificateRequired, resp.Status,
		"certificate-less request to /two")

	resp = requestPath(t, authClient, server, "/three")
	assert.Equal(t, gemini.StatusCertificateNotAuthorized, resp.Status,
		"authenticated request to Reject-wrapped /three")

	resp = requestPath(t, authlessClient, server, "/four")
	assert.Equal(t, gemini.StatusClientCertificateRequired, resp.Status,
		"certificate-less request to /four")
}

// TestGeminiOptionalAuth exercises tlsauth.GeminiOptionalAuth in a fallthrough
// chain. It expects certificate-less requests to succeed on both Allow- and
// Reject-wrapped handlers. It expects an authenticated request to a
// Reject-wrapped handler to get StatusCertificateNotAuthorized.
func TestGeminiOptionalAuth(t *testing.T) {
	// servePrefix builds a handler that succeeds only for paths beginning with
	// prefix and yields nil (no response) for anything else.
	servePrefix := func(prefix string) sr.Handler {
		return sr.HandlerFunc(func(_ context.Context, req *sr.Request) *sr.Response {
			if strings.HasPrefix(req.Path, prefix) {
				return gemini.Success("", &bytes.Buffer{})
			}
			return nil
		})
	}

	allow := tlsauth.GeminiOptionalAuth(tlsauth.Allow)
	reject := tlsauth.GeminiOptionalAuth(tlsauth.Reject)

	handler := sr.FallthroughHandler(
		allow(servePrefix("/one")),
		allow(servePrefix("/two")),
		reject(servePrefix("/three")),
		reject(servePrefix("/four")),
	)

	server, authClient, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		handler,
	)
	// Empty cert/key paths: this client sends no client certificate.
	authlessClient, _ := clientFor(t, server, "", "")

	go func() { _ = server.Serve() }()
	defer server.Close()

	assert.Equal(t, gemini.StatusSuccess,
		requestPath(t, authClient, server, "/one").Status)
	assert.Equal(t, gemini.StatusSuccess,
		requestPath(t, authlessClient, server, "/two").Status)
	assert.Equal(t, gemini.StatusCertificateNotAuthorized,
		requestPath(t, authClient, server, "/three").Status)
	assert.Equal(t, gemini.StatusSuccess,
		requestPath(t, authlessClient, server, "/four").Status)
}