summaryrefslogtreecommitdiff
path: root/contrib/tlsauth/auth_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/tlsauth/auth_test.go')
-rw-r--r--contrib/tlsauth/auth_test.go182
1 files changed, 182 insertions, 0 deletions
diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go
new file mode 100644
index 0000000..8361fc3
--- /dev/null
+++ b/contrib/tlsauth/auth_test.go
@@ -0,0 +1,182 @@
+package tlsauth_test
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "tildegit.org/tjp/gus"
+ "tildegit.org/tjp/gus/contrib/tlsauth"
+ "tildegit.org/tjp/gus/gemini"
+)
+
+func TestIdentify(t *testing.T) {
+ invoked := false
+
+ var leafCert *x509.Certificate
+ server, client, clientCert := setup(t,
+ "testdata/server.crt", "testdata/server.key",
+ "testdata/client1.crt", "testdata/client1.key",
+ func(_ context.Context, request *gus.Request) *gus.Response {
+ invoked = true
+
+ ident := tlsauth.Identity(request)
+ if assert.NotNil(t, ident) {
+ assert.True(t, ident.Equal(leafCert))
+ }
+
+ return nil
+ },
+ )
+ leafCert, err := x509.ParseCertificate(clientCert.Certificate[0])
+ require.Nil(t, err)
+
+ go server.Serve()
+ defer server.Close()
+
+ requestPath(t, client, server, "/")
+ assert.True(t, invoked)
+}
+
+func TestRequiredAuth(t *testing.T) {
+ invoked1 := false
+ invoked2 := false
+
+ handler1 := func(_ context.Context, request *gus.Request) *gus.Response {
+ invoked1 = true
+ return gemini.Success("", &bytes.Buffer{})
+ }
+
+ handler2 := func(_ context.Context, request *gus.Request) *gus.Response {
+ invoked2 = true
+ return gemini.Success("", &bytes.Buffer{})
+ }
+
+ authMiddleware := gus.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)
+
+ handler1 = gus.Filter(
+ func(_ context.Context, req *gus.Request) bool {
+ return strings.HasPrefix(req.Path, "/one")
+ },
+ nil,
+ )(authMiddleware(handler1))
+ handler2 = authMiddleware(handler2)
+
+ server, client, _ := setup(t,
+ "testdata/server.crt", "testdata/server.key",
+ "testdata/client1.crt", "testdata/client1.key",
+ gus.FallthroughHandler(handler1, handler2),
+ )
+
+ go server.Serve()
+ defer server.Close()
+
+ requestPath(t, client, server, "/one")
+ assert.True(t, invoked1)
+
+ client, _ = clientFor(t, server, "", "") // no client cert this time
+ requestPath(t, client, server, "/two")
+ assert.False(t, invoked2)
+}
+
+func TestOptionalAuth(t *testing.T) {
+ invoked1 := false
+ invoked2 := false
+
+ handler1 := func(_ context.Context, request *gus.Request) *gus.Response {
+ if !strings.HasPrefix(request.Path, "/one") {
+ return nil
+ }
+
+ invoked1 = true
+ return gemini.Success("", &bytes.Buffer{})
+ }
+
+ handler2 := func(_ context.Context, request *gus.Request) *gus.Response {
+ invoked2 = true
+ return gemini.Success("", &bytes.Buffer{})
+ }
+
+ mw := gus.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)
+ handler := gus.FallthroughHandler(mw(handler1), mw(handler2))
+
+ server, client, _ := setup(t,
+ "testdata/server.crt", "testdata/server.key",
+ "testdata/client1.crt", "testdata/client1.key",
+ handler,
+ )
+
+ go server.Serve()
+ defer server.Close()
+
+ requestPath(t, client, server, "/one")
+ assert.False(t, invoked1)
+
+ client, _ = clientFor(t, server, "", "")
+ requestPath(t, client, server, "/two")
+ assert.True(t, invoked2)
+}
+
+func setup(
+ t *testing.T,
+ serverCertPath string,
+ serverKeyPath string,
+ clientCertPath string,
+ clientKeyPath string,
+ handler gus.Handler,
+) (gus.Server, gemini.Client, tls.Certificate) {
+ serverTLS, err := gemini.FileTLS(serverCertPath, serverKeyPath)
+ require.Nil(t, err)
+
+ server, err := gemini.NewServer(
+ context.Background(),
+ serverTLS,
+ "tcp",
+ "127.0.0.1:0",
+ handler,
+ )
+ require.Nil(t, err)
+
+ client, clientCert := clientFor(t, server, clientCertPath, clientKeyPath)
+
+ return server, client, clientCert
+}
+
+func clientFor(
+ t *testing.T,
+ server gus.Server,
+ certPath string,
+ keyPath string,
+) (gemini.Client, tls.Certificate) {
+ var clientCert tls.Certificate
+ var certs []tls.Certificate
+ if certPath != "" {
+ c, err := tls.LoadX509KeyPair(certPath, keyPath)
+ require.Nil(t, err)
+
+ clientCert = c
+ certs = []tls.Certificate{c}
+ }
+
+ return gemini.NewClient(&tls.Config{
+ Certificates: certs,
+ InsecureSkipVerify: true,
+ }), clientCert
+}
+
+func requestPath(t *testing.T, client gemini.Client, server gus.Server, path string) *gus.Response {
+ u, err := url.Parse("gemini://" + server.Address() + path)
+ require.Nil(t, err)
+
+ response, err := client.RoundTrip(&gus.Request{URL: u})
+ require.Nil(t, err)
+
+ return response
+}