package tlsauth_test

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	sr "tildegit.org/tjp/sliderule"
	"tildegit.org/tjp/sliderule/contrib/tlsauth"
	"tildegit.org/tjp/sliderule/gemini"
)

// TestIdentify checks that, when the client presents a certificate, the
// handler sees a non-nil tlsauth.Identity that is equal to the client's
// parsed leaf certificate.
func TestIdentify(t *testing.T) {
	invoked := false

	// leafCert is filled in after setup returns. The handler closure only
	// reads it while serving a request, which does not happen until
	// requestPath is called below.
	var leafCert *x509.Certificate

	server, client, clientCert := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
			invoked = true

			ident := tlsauth.Identity(request)
			if assert.NotNil(t, ident) {
				assert.True(t, ident.Equal(leafCert))
			}

			return nil
		}),
	)

	// Plain assignment (not :=) so it is obvious that the variable captured
	// by the handler above is the one being set.
	var err error
	leafCert, err = x509.ParseCertificate(clientCert.Certificate[0])
	require.NoError(t, err)

	go func() {
		// The error from Serve is intentionally discarded: the test only
		// needs the server running until the deferred Close.
		// NOTE(review): presumably Serve returns once the server is
		// closed - confirm.
		_ = server.Serve()
	}()
	defer server.Close()

	requestPath(t, client, server, "/")
	assert.True(t, invoked)
}

// TestRequiredAuth wraps two handlers in a tlsauth.RequiredAuth(tlsauth.Allow)
// filter. It expects the first handler to run for "/one" when the client
// presents a certificate, and the second handler not to run for "/two" when
// the client presents none.
func TestRequiredAuth(t *testing.T) {
	invoked1 := false
	invoked2 := false

	handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
		invoked1 = true
		return gemini.Success("", &bytes.Buffer{})
	})

	handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
		invoked2 = true
		return gemini.Success("", &bytes.Buffer{})
	})

	authMiddleware := sr.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)

	// handler1 is additionally guarded by a path-prefix filter placed
	// outside the auth middleware.
	handler1 = sr.Filter(
		func(_ context.Context, req *sr.Request) bool {
			return strings.HasPrefix(req.Path, "/one")
		},
		nil,
	)(authMiddleware(handler1))
	handler2 = authMiddleware(handler2)

	server, client, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		sr.FallthroughHandler(handler1, handler2),
	)

	go func() {
		// The error from Serve is intentionally discarded: the test only
		// needs the server running until the deferred Close.
		// NOTE(review): presumably Serve returns once the server is
		// closed - confirm.
		_ = server.Serve()
	}()
	defer server.Close()

	requestPath(t, client, server, "/one")
	assert.True(t, invoked1)

	client, _ = clientFor(t, server, "", "") // no client cert this time
	requestPath(t, client, server, "/two")
	assert.False(t, invoked2)
}

// TestOptionalAuth wraps two handlers in a tlsauth.OptionalAuth(tlsauth.Reject)
// filter. It expects the first handler not to run for "/one" when the client
// presents a certificate, and the second handler to run for "/two" when the
// client presents none.
func TestOptionalAuth(t *testing.T) {
	invoked1 := false
	invoked2 := false

	handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
		// Decline anything outside "/one" by returning nil, before
		// recording an invocation.
		if !strings.HasPrefix(request.Path, "/one") {
			return nil
		}

		invoked1 = true
		return gemini.Success("", &bytes.Buffer{})
	})

	handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
		invoked2 = true
		return gemini.Success("", &bytes.Buffer{})
	})

	mw := sr.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)
	handler := sr.FallthroughHandler(mw(handler1), mw(handler2))

	server, client, _ := setup(t,
		"testdata/server.crt", "testdata/server.key",
		"testdata/client1.crt", "testdata/client1.key",
		handler,
	)

	go func() {
		// The error from Serve is intentionally discarded: the test only
		// needs the server running until the deferred Close.
		// NOTE(review): presumably Serve returns once the server is
		// closed - confirm.
		_ = server.Serve()
	}()
	defer server.Close()

	requestPath(t, client, server, "/one")
	assert.False(t, invoked1)

	client, _ = clientFor(t, server, "", "") // no client cert this time
	requestPath(t, client, server, "/two")
	assert.True(t, invoked2)
}

// setup builds a gemini server on 127.0.0.1 with an OS-assigned port, using
// the TLS material at serverCertPath/serverKeyPath and the given handler, and
// a client configured with the certificate at clientCertPath/clientKeyPath.
//
// It returns the server (not yet serving), the client, and the loaded client
// certificate. Any failure aborts the calling test.
func setup(
	t *testing.T,
	serverCertPath string,
	serverKeyPath string,
	clientCertPath string,
	clientKeyPath string,
	handler sr.Handler,
) (sr.Server, gemini.Client, tls.Certificate) {
	t.Helper()

	serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
	require.NoError(t, err)

	server, err := gemini.NewServer(
		context.Background(),
		"localhost",
		"tcp",
		"127.0.0.1:0",
		handler,
		nil,
		serverTLS,
	)
	require.NoError(t, err)

	client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)

	return server, client, clientCert
}

// clientFor builds a gemini client. When certPath is non-empty, the key pair
// at certPath/keyPath is loaded and offered as the client certificate;
// otherwise the client has no certificates and the returned tls.Certificate
// is the zero value.
//
// The server parameter is not referenced in the body; it is kept so existing
// call sites keep compiling.
func clientFor(
	t *testing.T,
	server sr.Server,
	certPath string,
	keyPath string,
) (gemini.Client, tls.Certificate) {
	t.Helper()

	var clientCert tls.Certificate
	var certs []tls.Certificate
	if certPath != "" {
		c, err := tls.LoadX509KeyPair(certPath, keyPath)
		require.NoError(t, err)

		clientCert = c
		certs = []tls.Certificate{c}
	}

	return gemini.NewClient(&tls.Config{
		Certificates: certs,
		// The client does not verify the server's certificate chain or
		// host name.
		InsecureSkipVerify: true,
	}), clientCert
}

// requestPath sends a gemini request for path to server through client and
// returns the response. A URL parse error or a round-trip error aborts the
// calling test.
func requestPath(t *testing.T, client gemini.Client, server sr.Server, path string) *sr.Response {
	t.Helper()

	u, err := url.Parse("gemini://" + server.Address() + path)
	require.NoError(t, err)

	response, err := client.RoundTrip(&sr.Request{URL: u})
	require.NoError(t, err)

	return response
}