summaryrefslogtreecommitdiff
path: root/auth_test.go
diff options
context:
space:
mode:
authorT <t@tjp.lol>2025-06-26 11:42:17 -0600
committerT <t@tjp.lol>2025-07-01 17:50:49 -0600
commit639ad6a02cbb4b713434671ec09f309aa5410921 (patch)
tree7dde9cce8136636d11f2f7c961072984cfc705e7 /auth_test.go
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'auth_test.go')
-rw-r--r--auth_test.go391
1 files changed, 391 insertions, 0 deletions
diff --git a/auth_test.go b/auth_test.go
new file mode 100644
index 0000000..131e132
--- /dev/null
+++ b/auth_test.go
@@ -0,0 +1,391 @@
+package kate
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+type testData struct {
+ UserID int
+ Name string
+}
+
+type testSerDes struct{}
+
+func (ts testSerDes) Serialize(w io.Writer, data testData) error {
+ _, err := w.Write([]byte(strings.Join([]string{string(rune(data.UserID)), data.Name}, "|")))
+ return err
+}
+
+func (ts testSerDes) Deserialize(r io.Reader, data *testData) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ parts := strings.Split(string(buf), "|")
+ if len(parts) != 2 {
+ return io.ErrUnexpectedEOF
+ }
+ data.UserID = int(rune(parts[0][0]))
+ data.Name = parts[1]
+ return nil
+}
+
+func createTestAuth() Auth[testData] {
+ config := AuthConfig[testData]{
+ SerDes: testSerDes{},
+ CookieName: "test_auth",
+ URLPath: "/",
+ URLDomain: "example.com",
+ HTTPSOnly: false,
+ MaxAge: time.Hour,
+ }
+ return New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+}
+
+func TestNew(t *testing.T) {
+ tests := []struct {
+ name string
+ privkey string
+ wantErr bool
+ }{
+ {
+ name: "valid hex key",
+ privkey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ wantErr: false,
+ },
+ {
+ name: "invalid hex key - too short",
+ privkey: "0123456789abcdef",
+ wantErr: true,
+ },
+ {
+ name: "invalid hex key - not hex",
+ privkey: "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: testSerDes{},
+ CookieName: "test",
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if !tt.wantErr {
+ t.Errorf("New() panicked unexpectedly: %v", r)
+ }
+ } else if tt.wantErr {
+ t.Errorf("New() should have panicked but didn't")
+ }
+ }()
+
+ auth := New(tt.privkey, config)
+ if !tt.wantErr && auth.config.CookieName != "test" {
+ t.Errorf("New() failed to create auth instance properly")
+ }
+ })
+ }
+}
+
+func TestAuth_Set(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ w := httptest.NewRecorder()
+ err := auth.Set(w, data)
+ if err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 cookie, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.Path != "/" {
+ t.Errorf("Expected cookie path '/', got %s", cookie.Path)
+ }
+ if cookie.Domain != "example.com" {
+ t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain)
+ }
+ if !cookie.HttpOnly {
+ t.Error("Expected cookie to be HttpOnly")
+ }
+ if cookie.SameSite != http.SameSiteLaxMode {
+ t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
+ }
+}
+
+func TestAuth_Required(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authData, ok := auth.Get(r.Context())
+ if !ok {
+ t.Error("Expected to find auth data in context")
+ return
+ }
+ if authData.UserID != 123 || authData.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name)
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ protectedHandler := auth.Required(handler)
+
+ t.Run("missing cookie", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401, got %d", w.Code)
+ }
+ })
+
+ t.Run("invalid cookie", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_token",
+ })
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401, got %d", w.Code)
+ }
+ })
+
+ t.Run("valid cookie", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+}
+
+func TestAuth_Optional(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authData, ok := auth.Get(r.Context())
+ if ok {
+ if authData.UserID != 123 || authData.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name)
+ }
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ optionalHandler := auth.Optional(handler)
+
+ t.Run("missing cookie - should allow", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+
+ t.Run("invalid cookie - should allow", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_token",
+ })
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+
+ t.Run("valid cookie", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+}
+
+func TestAuth_Get(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ t.Run("no data in context", func(t *testing.T) {
+ ctx := context.Background()
+ result, ok := auth.Get(ctx)
+ if ok {
+ t.Error("Expected Get to return false for empty context")
+ }
+ if result.UserID != 0 || result.Name != "" {
+ t.Error("Expected zero value for missing data")
+ }
+ })
+
+ t.Run("data in context", func(t *testing.T) {
+ ctx := context.WithValue(context.Background(), key, data)
+ result, ok := auth.Get(ctx)
+ if !ok {
+ t.Error("Expected Get to return true for context with data")
+ }
+ if result.UserID != 123 || result.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", result.UserID, result.Name)
+ }
+ })
+
+ t.Run("wrong type in context", func(t *testing.T) {
+ ctx := context.WithValue(context.Background(), key, "wrong_type")
+ result, ok := auth.Get(ctx)
+ if ok {
+ t.Error("Expected Get to return false for wrong type in context")
+ }
+ if result.UserID != 0 || result.Name != "" {
+ t.Error("Expected zero value for wrong type")
+ }
+ })
+}
+
+type failingSerDes struct {
+ failSerialize bool
+ failDeserialize bool
+}
+
+func (f failingSerDes) Serialize(w io.Writer, data testData) error {
+ if f.failSerialize {
+ return io.ErrClosedPipe
+ }
+ return testSerDes{}.Serialize(w, data)
+}
+
+func (f failingSerDes) Deserialize(r io.Reader, data *testData) error {
+ if f.failDeserialize {
+ return io.ErrUnexpectedEOF
+ }
+ return testSerDes{}.Deserialize(r, data)
+}
+
+func TestAuth_Set_SerializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failSerialize: true},
+ CookieName: "test_auth",
+ }
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ w := httptest.NewRecorder()
+
+ err := auth.Set(w, data)
+ if err == nil {
+ t.Error("Expected Set to return error for serialization failure")
+ }
+}
+
+func TestAuth_Required_DeserializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ protectedHandler := badAuth.Required(handler)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
+ }
+}
+
+func TestAuth_Optional_DeserializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ optionalHandler := badAuth.Optional(handler)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
+ }
+}