summaryrefslogtreecommitdiff
path: root/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'contrib')
-rw-r--r--contrib/tlsauth/approver.go20
-rw-r--r--contrib/tlsauth/approver_test.go16
-rw-r--r--contrib/tlsauth/auth.go31
-rw-r--r--contrib/tlsauth/auth_test.go85
-rw-r--r--contrib/tlsauth/gemini.go8
5 files changed, 30 insertions, 130 deletions
diff --git a/contrib/tlsauth/approver.go b/contrib/tlsauth/approver.go
index 064056d..ed442ce 100644
--- a/contrib/tlsauth/approver.go
+++ b/contrib/tlsauth/approver.go
@@ -1,17 +1,27 @@
package tlsauth
-import "crypto/x509"
+import (
+ "context"
+ "crypto/x509"
+
+ "tildegit.org/tjp/sliderule"
+)
// Approver is a function that validates a certificate.
//
// It should not be have to handle a nil argument.
-type Approver func(*x509.Certificate) bool
+type Approver func(context.Context, *sliderule.Request) bool
// RequireSpecificIdentity builds an approver that demands one specific client certificate.
-func RequireSpecificIdentity(identity *x509.Certificate) Approver { return identity.Equal }
+func RequireSpecificIdentity(identity *x509.Certificate) Approver {
+ return func(_ context.Context, request *sliderule.Request) bool {
+ cert := Identity(request)
+ return cert != nil && identity.Equal(cert)
+ }
+}
// Allow is an approver which permits anything.
-func Allow(_ *x509.Certificate) bool { return true }
+func Allow(_ context.Context, _ *sliderule.Request) bool { return true }
// Reject is an approver which denies everything.
-func Reject(_ *x509.Certificate) bool { return false }
+func Reject(_ context.Context, _ *sliderule.Request) bool { return false }
diff --git a/contrib/tlsauth/approver_test.go b/contrib/tlsauth/approver_test.go
index d2f4f07..32f7c40 100644
--- a/contrib/tlsauth/approver_test.go
+++ b/contrib/tlsauth/approver_test.go
@@ -1,6 +1,7 @@
package tlsauth_test
import (
+ "context"
"crypto/tls"
"crypto/x509"
"errors"
@@ -8,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
+ "tildegit.org/tjp/sliderule"
"tildegit.org/tjp/sliderule/contrib/tlsauth"
)
@@ -15,18 +17,24 @@ func TestRequireSpecificIdentity(t *testing.T) {
cert1, err := leafCert("testdata/client1.crt", "testdata/client1.key")
assert.Nil(t, err)
+ req1 := &sliderule.Request{TLSState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{cert1}}}
+
cert2, err := leafCert("testdata/client2.crt", "testdata/client2.key")
assert.Nil(t, err)
+ req2 := &sliderule.Request{TLSState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{cert2}}}
+
+ ctx := context.Background()
+
assert.True(t, cert1.Equal(cert1))
assert.False(t, cert1.Equal(cert2))
assert.False(t, cert2.Equal(cert1))
assert.True(t, cert2.Equal(cert2))
- assert.True(t, tlsauth.RequireSpecificIdentity(cert1)(cert1))
- assert.False(t, tlsauth.RequireSpecificIdentity(cert1)(cert2))
- assert.False(t, tlsauth.RequireSpecificIdentity(cert2)(cert1))
- assert.True(t, tlsauth.RequireSpecificIdentity(cert2)(cert2))
+ assert.True(t, tlsauth.RequireSpecificIdentity(cert1)(ctx, req1))
+ assert.False(t, tlsauth.RequireSpecificIdentity(cert1)(ctx, req2))
+ assert.False(t, tlsauth.RequireSpecificIdentity(cert2)(ctx, req1))
+ assert.True(t, tlsauth.RequireSpecificIdentity(cert2)(ctx, req2))
}
func leafCert(certfile, keyfile string) (*x509.Certificate, error) {
diff --git a/contrib/tlsauth/auth.go b/contrib/tlsauth/auth.go
index 439d297..ff8529b 100644
--- a/contrib/tlsauth/auth.go
+++ b/contrib/tlsauth/auth.go
@@ -1,7 +1,6 @@
package tlsauth
import (
- "context"
"crypto/x509"
sr "tildegit.org/tjp/sliderule"
@@ -14,33 +13,3 @@ func Identity(request *sr.Request) *x509.Certificate {
}
return request.TLSState.PeerCertificates[0]
}
-
-// RequiredAuth produces an auth predicate.
-//
-// The check requires both that there is a client certificate associated with the
-// request and that it passes the provided approver.
-func RequiredAuth(approve Approver) func(context.Context, *sr.Request) bool {
- return func(_ context.Context, request *sr.Request) bool {
- identity := Identity(request)
- if identity == nil {
- return false
- }
-
- return approve(identity)
- }
-}
-
-// OptionalAuth produces an auth predicate.
-//
-// The check allows through any request with no client certificate, but if
-// there is one present then it requires that it pass the provided approver.
-func OptionalAuth(approve Approver) func(context.Context, *sr.Request) bool {
- return func(_ context.Context, request *sr.Request) bool {
- identity := Identity(request)
- if identity == nil {
- return true
- }
-
- return approve(identity)
- }
-}
diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go
index 2a95e1c..df67159 100644
--- a/contrib/tlsauth/auth_test.go
+++ b/contrib/tlsauth/auth_test.go
@@ -1,12 +1,10 @@
package tlsauth_test
import (
- "bytes"
"context"
"crypto/tls"
"crypto/x509"
"net/url"
- "strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -47,89 +45,6 @@ func TestIdentify(t *testing.T) {
assert.True(t, invoked)
}
-func TestRequiredAuth(t *testing.T) {
- invoked1 := false
- invoked2 := false
-
- handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
- invoked1 = true
- return gemini.Success("", &bytes.Buffer{})
- })
-
- handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
- invoked2 = true
- return gemini.Success("", &bytes.Buffer{})
- })
-
- authMiddleware := sr.Filter(tlsauth.RequiredAuth(tlsauth.Allow), nil)
-
- handler1 = sr.Filter(
- func(_ context.Context, req *sr.Request) bool {
- return strings.HasPrefix(req.Path, "/one")
- },
- nil,
- )(authMiddleware(handler1))
- handler2 = authMiddleware(handler2)
-
- server, client, _ := setup(t,
- "testdata/server.crt", "testdata/server.key",
- "testdata/client1.crt", "testdata/client1.key",
- sr.FallthroughHandler(handler1, handler2),
- )
-
- go func() {
- _ = server.Serve()
- }()
- defer server.Close()
-
- requestPath(t, client, server, "/one")
- assert.True(t, invoked1)
-
- client, _ = clientFor(t, server, "", "") // no client cert this time
- requestPath(t, client, server, "/two")
- assert.False(t, invoked2)
-}
-
-func TestOptionalAuth(t *testing.T) {
- invoked1 := false
- invoked2 := false
-
- handler1 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
- if !strings.HasPrefix(request.Path, "/one") {
- return nil
- }
-
- invoked1 = true
- return gemini.Success("", &bytes.Buffer{})
- })
-
- handler2 := sr.HandlerFunc(func(_ context.Context, request *sr.Request) *sr.Response {
- invoked2 = true
- return gemini.Success("", &bytes.Buffer{})
- })
-
- mw := sr.Filter(tlsauth.OptionalAuth(tlsauth.Reject), nil)
- handler := sr.FallthroughHandler(mw(handler1), mw(handler2))
-
- server, client, _ := setup(t,
- "testdata/server.crt", "testdata/server.key",
- "testdata/client1.crt", "testdata/client1.key",
- handler,
- )
-
- go func() {
- _ = server.Serve()
- }()
- defer server.Close()
-
- requestPath(t, client, server, "/one")
- assert.False(t, invoked1)
-
- client, _ = clientFor(t, server, "", "")
- requestPath(t, client, server, "/two")
- assert.True(t, invoked2)
-}
-
func setup(
t *testing.T,
serverCertPath string,
diff --git a/contrib/tlsauth/gemini.go b/contrib/tlsauth/gemini.go
index 9996595..9bf07bd 100644
--- a/contrib/tlsauth/gemini.go
+++ b/contrib/tlsauth/gemini.go
@@ -15,11 +15,10 @@ import (
func GeminiAuth(approver Approver) sr.Middleware {
return func(inner sr.Handler) sr.Handler {
return sr.HandlerFunc(func(ctx context.Context, request *sr.Request) *sr.Response {
- identity := Identity(request)
- if identity == nil {
+ if Identity(request) == nil {
return geminiMissingCert(ctx, request)
}
- if !approver(identity) {
+ if !approver(ctx, request) {
return geminiCertNotAuthorized(ctx, request)
}
@@ -36,8 +35,7 @@ func GeminiAuth(approver Approver) sr.Middleware {
func GeminiOptionalAuth(approver Approver) sr.Middleware {
return func(inner sr.Handler) sr.Handler {
return sr.HandlerFunc(func(ctx context.Context, request *sr.Request) *sr.Response {
- identity := Identity(request)
- if identity != nil && !approver(identity) {
+ if Identity(request) != nil && !approver(ctx, request) {
return geminiCertNotAuthorized(ctx, request)
}