diff options
author | T <t@tjp.lol> | 2025-06-26 11:42:17 -0600 |
---|---|---|
committer | T <t@tjp.lol> | 2025-07-01 17:50:49 -0600 |
commit | 639ad6a02cbb4b713434671ec09f309aa5410921 (patch) | |
tree | 7dde9cce8136636d11f2f7c961072984cfc705e7 /auth_test.go |
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'auth_test.go')
-rw-r--r-- | auth_test.go | 391 |
1 files changed, 391 insertions, 0 deletions
diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..131e132 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,391 @@ +package kate + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +type testData struct { + UserID int + Name string +} + +type testSerDes struct{} + +func (ts testSerDes) Serialize(w io.Writer, data testData) error { + _, err := w.Write([]byte(strings.Join([]string{string(rune(data.UserID)), data.Name}, "|"))) + return err +} + +func (ts testSerDes) Deserialize(r io.Reader, data *testData) error { + buf, err := io.ReadAll(r) + if err != nil { + return err + } + parts := strings.Split(string(buf), "|") + if len(parts) != 2 { + return io.ErrUnexpectedEOF + } + data.UserID = int(rune(parts[0][0])) + data.Name = parts[1] + return nil +} + +func createTestAuth() Auth[testData] { + config := AuthConfig[testData]{ + SerDes: testSerDes{}, + CookieName: "test_auth", + URLPath: "/", + URLDomain: "example.com", + HTTPSOnly: false, + MaxAge: time.Hour, + } + return New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + privkey string + wantErr bool + }{ + { + name: "valid hex key", + privkey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantErr: false, + }, + { + name: "invalid hex key - too short", + privkey: "0123456789abcdef", + wantErr: true, + }, + { + name: "invalid hex key - not hex", + privkey: "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: testSerDes{}, + CookieName: "test", + } + + defer func() { + if r := recover(); r != nil { + if !tt.wantErr { + t.Errorf("New() panicked unexpectedly: %v", r) + } + } else if tt.wantErr { + t.Errorf("New() should have panicked but didn't") + } + }() + + auth := New(tt.privkey, config) + if !tt.wantErr && auth.config.CookieName != "test" { + t.Errorf("New() failed to create auth instance properly") + } + }) + } +} + +func TestAuth_Set(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + w := httptest.NewRecorder() + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.Path != "/" { + t.Errorf("Expected cookie path '/', got %s", cookie.Path) + } + if cookie.Domain != "example.com" { + t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) + } + if !cookie.HttpOnly { + t.Error("Expected cookie to be HttpOnly") + } + if cookie.SameSite != http.SameSiteLaxMode { + t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) + } +} + +func TestAuth_Required(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if !ok { + t.Error("Expected to find auth data in context") + return + } + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := auth.Required(handler) + + t.Run("missing cookie", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + }) + + t.Run("invalid cookie", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_token", + }) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + }) + + t.Run("valid cookie", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) +} + +func TestAuth_Optional(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if ok { + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + } + w.WriteHeader(http.StatusOK) + }) + + optionalHandler := auth.Optional(handler) + + t.Run("missing cookie - should allow", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("invalid cookie - should allow", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_token", + }) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("valid cookie", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) +} + +func TestAuth_Get(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + t.Run("no data in context", func(t *testing.T) { + ctx := context.Background() + result, ok := auth.Get(ctx) + if ok { + t.Error("Expected Get to return false for empty context") + } + if result.UserID != 0 || result.Name != "" { + t.Error("Expected zero value for missing data") + } + }) + + t.Run("data in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), key, data) + result, ok := auth.Get(ctx) + if !ok { + t.Error("Expected Get to return true for context with data") + } + if result.UserID != 123 || result.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", result.UserID, result.Name) + } + }) + + t.Run("wrong type in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), key, "wrong_type") + result, ok := auth.Get(ctx) + if ok { + t.Error("Expected Get to return false for wrong type in context") + } + if result.UserID != 0 || result.Name != "" { + t.Error("Expected zero value for wrong type") + } + }) +} + +type failingSerDes struct { + failSerialize bool + failDeserialize bool +} + +func (f failingSerDes) Serialize(w io.Writer, data testData) error { + if f.failSerialize { + return io.ErrClosedPipe + } + return testSerDes{}.Serialize(w, data) +} + +func (f failingSerDes) Deserialize(r io.Reader, data *testData) error { + if f.failDeserialize { + return io.ErrUnexpectedEOF + } + return testSerDes{}.Deserialize(r, data) +} + +func TestAuth_Set_SerializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failSerialize: true}, + CookieName: "test_auth", + } + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + data := testData{UserID: 123, Name: "Alice"} + w := httptest.NewRecorder() + + err := auth.Set(w, data) + if err == nil { + t.Error("Expected Set to return error for serialization failure") + } +} + +func TestAuth_Required_DeserializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := badAuth.Required(handler) + + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) + } +} + +func TestAuth_Optional_DeserializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + optionalHandler := badAuth.Optional(handler) + + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) + } +} |