diff options
Diffstat (limited to 'contrib/tlsauth/auth_test.go')
-rw-r--r-- | contrib/tlsauth/auth_test.go | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go new file mode 100644 index 0000000..8361fc3 --- /dev/null +++ b/contrib/tlsauth/auth_test.go @@ -0,0 +1,182 @@ +package tlsauth_test + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "tildegit.org/tjp/gus" + "tildegit.org/tjp/gus/contrib/tlsauth" + "tildegit.org/tjp/gus/gemini" +) + +func TestIdentify(t *testing.T) { + invoked := false + + var leafCert *x509.Certificate + server, client, clientCert := setup(t, + "testdata/server.crt", "testdata/server.key", + "testdata/client1.crt", "testdata/client1.key", + func(_ context.Context, request *gus.Request) *gus.Response { + invoked = true + + ident := tlsauth.Identity(request) + if assert.NotNil(t, ident) { + assert.True(t, ident.Equal(leafCert)) + } + + return nil + }, + ) + leafCert, err := x509.ParseCertificate(clientCert.Certificate[0]) + require.Nil(t, err) + + go server.Serve() + defer server.Close() + + requestPath(t, client, server, "/") + assert.True(t, invoked) +} + +func TestRequiredAuth(t *testing.T) { + invoked1 := false + invoked2 := false + + handler1 := func(_ context.Context, request *gus.Request) *gus.Response { + invoked1 = true + return gemini.Success("", &bytes.Buffer{}) + } + + handler2 := func(_ context.Context, request *gus.Request) *gus.Response { + invoked2 = true + return gemini.Success("", &bytes.Buffer{}) + } + + authMiddleware := gus.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil) + + handler1 = gus.Filter( + func(_ context.Context, req *gus.Request) bool { + return strings.HasPrefix(req.Path, "/one") + }, + nil, + )(authMiddleware(handler1)) + handler2 = authMiddleware(handler2) + + server, client, _ := setup(t, + "testdata/server.crt", "testdata/server.key", + "testdata/client1.crt", "testdata/client1.key", + gus.FallthroughHandler(handler1, handler2), + ) + + go server.Serve() + defer server.Close() + + requestPath(t, client, server, "/one") + assert.True(t, invoked1) + + client, _ = clientFor(t, server, "", "") // no client cert this time + requestPath(t, client, server, "/two") + assert.False(t, invoked2) +} + +func TestOptionalAuth(t *testing.T) { + invoked1 := false + invoked2 := false + + handler1 := func(_ context.Context, request *gus.Request) *gus.Response { + if !strings.HasPrefix(request.Path, "/one") { + return nil + } + + invoked1 = true + return gemini.Success("", &bytes.Buffer{}) + } + + handler2 := func(_ context.Context, request *gus.Request) *gus.Response { + invoked2 = true + return gemini.Success("", &bytes.Buffer{}) + } + + mw := gus.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil) + handler := gus.FallthroughHandler(mw(handler1), mw(handler2)) + + server, client, _ := setup(t, + "testdata/server.crt", "testdata/server.key", + "testdata/client1.crt", "testdata/client1.key", + handler, + ) + + go server.Serve() + defer server.Close() + + requestPath(t, client, server, "/one") + assert.False(t, invoked1) + + client, _ = clientFor(t, server, "", "") + requestPath(t, client, server, "/two") + assert.True(t, invoked2) +} + +func setup( + t *testing.T, + serverCertPath string, + serverKeyPath string, + clientCertPath string, + clientKeyPath string, + handler gus.Handler, +) (gus.Server, gemini.Client, tls.Certificate) { + serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath) + require.Nil(t, err) + + server, err := gemini.NewServer( + context.Background(), + serverTLS, + "tcp", + "127.0.0.1:0", + handler, + ) + require.Nil(t, err) + + client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath) + + return server, client, clientCert +} + +func clientFor( + t *testing.T, + server gus.Server, + certPath string, + keyPath string, +) (gemini.Client, tls.Certificate) { + var clientCert tls.Certificate + var certs []tls.Certificate + if certPath != "" { + c, err := tls.LoadX509KeyPair(certPath, keyPath) + require.Nil(t, err) + + clientCert = c + certs = []tls.Certificate{c} + } + + return gemini.NewClient(&tls.Config{ + Certificates: certs, + InsecureSkipVerify: true, + }), clientCert +} + +func requestPath(t *testing.T, client gemini.Client, server gus.Server, path string) *gus.Response { + u, err := url.Parse("gemini://" + server.Address() + path) + require.Nil(t, err) + + response, err := client.RoundTrip(&gus.Request{URL: u}) + require.Nil(t, err) + + return response +} |