From cedcf58ea7d729acb6ed1a9ab7aec1ae38aed102 Mon Sep 17 00:00:00 2001 From: tjpcc Date: Mon, 9 Oct 2023 08:56:53 -0600 Subject: more useful tlsauth.Approver type the predicate function should be able to see the whole context and request --- contrib/tlsauth/approver.go | 20 +++++++--- contrib/tlsauth/approver_test.go | 16 ++++++-- contrib/tlsauth/auth.go | 31 --------------- contrib/tlsauth/auth_test.go | 85 ---------------------------------------- contrib/tlsauth/gemini.go | 8 ++-- 5 files changed, 30 insertions(+), 130 deletions(-) (limited to 'contrib') diff --git a/contrib/tlsauth/approver.go b/contrib/tlsauth/approver.go index 064056d..ed442ce 100644 --- a/contrib/tlsauth/approver.go +++ b/contrib/tlsauth/approver.go @@ -1,17 +1,27 @@ package tlsauth -import "crypto/x509" +import ( + "context" + "crypto/x509" + + "tildegit.org/tjp/sliderule" +) // Approver is a function that validates a certificate. // // It should not be have to handle a nil argument. -type Approver func(*x509.Certificate) bool +type Approver func(context.Context, *sliderule.Request) bool // RequireSpecificIdentity builds an approver that demands one specific client certificate. -func RequireSpecificIdentity(identity *x509.Certificate) Approver { return identity.Equal } +func RequireSpecificIdentity(identity *x509.Certificate) Approver { + return func(_ context.Context, request *sliderule.Request) bool { + cert := Identity(request) + return cert != nil && identity.Equal(cert) + } +} // Allow is an approver which permits anything. -func Allow(_ *x509.Certificate) bool { return true } +func Allow(_ context.Context, _ *sliderule.Request) bool { return true } // Reject is an approver which denies everything. -func Reject(_ *x509.Certificate) bool { return false } +func Reject(_ context.Context, _ *sliderule.Request) bool { return false } diff --git a/contrib/tlsauth/approver_test.go b/contrib/tlsauth/approver_test.go index d2f4f07..32f7c40 100644 --- a/contrib/tlsauth/approver_test.go +++ b/contrib/tlsauth/approver_test.go @@ -1,6 +1,7 @@ package tlsauth_test import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -8,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" + "tildegit.org/tjp/sliderule" "tildegit.org/tjp/sliderule/contrib/tlsauth" ) @@ -15,18 +17,24 @@ func TestRequireSpecificIdentity(t *testing.T) { cert1, err := leafCert("testdata/client1.crt", "testdata/client1.key") assert.Nil(t, err) + req1 := &sliderule.Request{TLSState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{cert1}}} + cert2, err := leafCert("testdata/client2.crt", "testdata/client2.key") assert.Nil(t, err) + req2 := &sliderule.Request{TLSState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{cert2}}} + + ctx := context.Background() + assert.True(t, cert1.Equal(cert1)) assert.False(t, cert1.Equal(cert2)) assert.False(t, cert2.Equal(cert1)) assert.True(t, cert2.Equal(cert2)) - assert.True(t, tlsauth.RequireSpecificIdentity(cert1)(cert1)) - assert.False(t, tlsauth.RequireSpecificIdentity(cert1)(cert2)) - assert.False(t, tlsauth.RequireSpecificIdentity(cert2)(cert1)) - assert.True(t, tlsauth.RequireSpecificIdentity(cert2)(cert2)) + assert.True(t, tlsauth.RequireSpecificIdentity(cert1)(ctx, req1)) + assert.False(t, tlsauth.RequireSpecificIdentity(cert1)(ctx, req2)) + assert.False(t, tlsauth.RequireSpecificIdentity(cert2)(ctx, req1)) + assert.True(t, tlsauth.RequireSpecificIdentity(cert2)(ctx, req2)) } func leafCert(certfile, keyfile string) (*x509.Certificate, error) { diff --git a/contrib/tlsauth/auth.go b/contrib/tlsauth/auth.go index 439d297..ff8529b 100644 --- a/contrib/tlsauth/auth.go +++ b/contrib/tlsauth/auth.go @@ -1,7 +1,6 @@ package tlsauth import ( - "context" "crypto/x509" sr "tildegit.org/tjp/sliderule" @@ -14,33 +13,3 @@ func Identity(request *sr.Request) *x509.Certificate { } return request.TLSState.PeerCertificates[0] } - -// RequiredAuth produces an auth predicate. -// -// The check requires both that there is a client certificate associated with the -// request and that it passes the provided approver. -func RequiredAuth(approve Approver) func(context.Context, *sr.Request) bool { - return func(_ context.Context, request *sr.Request) bool { - identity := Identity(request) - if identity == nil { - return false - } - - return approve(identity) - } -} - -// OptionalAuth produces an auth predicate. -// -// The check allows through any request with no client certificate, but if -// there is one present then it requires that it pass the provided approver. -func OptionalAuth(approve Approver) func(context.Context, *sr.Request) bool { - return func(_ context.Context, request *sr.Request) bool { - identity := Identity(request) - if identity == nil { - return true - } - - return approve(identity) - } -} diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go index 2a95e1c..df67159 100644 --- a/contrib/tlsauth/auth_test.go +++ b/contrib/tlsauth/auth_test.go @@ -1,12 +1,10 @@ package tlsauth_test import ( - "bytes" "context" "crypto/tls" "crypto/x509" "net/url" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -47,89 +45,6 @@ func TestIdentify(t *testing.T) { assert.True(t, invoked) } -func TestRequiredAuth(t *testing.T) { - invoked1 := false - invoked2 := false - - handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response { - invoked1 = true - return gemini.Success("", &bytes.Buffer{}) - }) - - handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response { - invoked2 = true - return gemini.Success("", &bytes.Buffer{}) - }) - - authMiddleware := sr.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil) - - handler1 = sr.Filter( - func(_ context.Context, req *sr.Request) bool { - return strings.HasPrefix(req.Path, "/one") - }, - nil, - )(authMiddleware(handler1)) - handler2 = authMiddleware(handler2) - - server, client, _ := setup(t, - "testdata/server.crt", "testdata/server.key", - "testdata/client1.crt", "testdata/client1.key", - sr.FallthroughHandler(handler1, handler2), - ) - - go func() { - _ = server.Serve() - }() - defer server.Close() - - requestPath(t, client, server, "/one") - assert.True(t, invoked1) - - client, _ = clientFor(t, server, "", "") // no client cert this time - requestPath(t, client, server, "/two") - assert.False(t, invoked2) -} - -func TestOptionalAuth(t *testing.T) { - invoked1 := false - invoked2 := false - - handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response { - if !strings.HasPrefix(request.Path, "/one") { - return nil - } - - invoked1 = true - return gemini.Success("", &bytes.Buffer{}) - }) - - handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response { - invoked2 = true - return gemini.Success("", &bytes.Buffer{}) - }) - - mw := sr.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil) - handler := sr.FallthroughHandler(mw(handler1), mw(handler2)) - - server, client, _ := setup(t, - "testdata/server.crt", "testdata/server.key", - "testdata/client1.crt", "testdata/client1.key", - handler, - ) - - go func() { - _ = server.Serve() - }() - defer server.Close() - - requestPath(t, client, server, "/one") - assert.False(t, invoked1) - - client, _ = clientFor(t, server, "", "") - requestPath(t, client, server, "/two") - assert.True(t, invoked2) -} - func setup( t *testing.T, serverCertPath string, diff --git a/contrib/tlsauth/gemini.go b/contrib/tlsauth/gemini.go index 9996595..9bf07bd 100644 --- a/contrib/tlsauth/gemini.go +++ b/contrib/tlsauth/gemini.go @@ -15,11 +15,10 @@ import ( func GeminiAuth(approver Approver) sr.Middleware { return func(inner sr.Handler) sr.Handler { return sr.HandlerFunc(func(ctx context.Context, request *sr.Request) *sr.Response { - identity := Identity(request) - if identity == nil { + if Identity(request) == nil { return geminiMissingCert(ctx, request) } - if !approver(identity) { + if !approver(ctx, request) { return geminiCertNotAuthorized(ctx, request) } @@ -36,8 +35,7 @@ func GeminiAuth(approver Approver) sr.Middleware { func GeminiOptionalAuth(approver Approver) sr.Middleware { return func(inner sr.Handler) sr.Handler { return sr.HandlerFunc(func(ctx context.Context, request *sr.Request) *sr.Response { - identity := Identity(request) - if identity != nil && !approver(identity) { + if Identity(request) != nil && !approver(ctx, request) { return geminiCertNotAuthorized(ctx, request) } -- cgit v1.2.3