package kate import ( "context" "io" "net/http" "net/http/httptest" "strings" "testing" "time" ) type testData struct { UserID int Name string } type testSerDes struct{} func (ts testSerDes) Serialize(w io.Writer, data testData) error { _, err := w.Write([]byte(strings.Join([]string{string(rune(data.UserID)), data.Name}, "|"))) return err } func (ts testSerDes) Deserialize(r io.Reader, data *testData) error { buf, err := io.ReadAll(r) if err != nil { return err } parts := strings.Split(string(buf), "|") if len(parts) != 2 { return io.ErrUnexpectedEOF } data.UserID = int(rune(parts[0][0])) data.Name = parts[1] return nil } func createTestAuth() Auth[testData] { config := AuthConfig[testData]{ SerDes: testSerDes{}, CookieName: "test_auth", URLPath: "/", URLDomain: "example.com", HTTPSOnly: false, MaxAge: time.Hour, } return New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) } func TestNew(t *testing.T) { tests := []struct { name string privkey string wantErr bool }{ { name: "valid hex key", privkey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", wantErr: false, }, { name: "invalid hex key - too short", privkey: "0123456789abcdef", wantErr: true, }, { name: "invalid hex key - not hex", privkey: "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := AuthConfig[testData]{ SerDes: testSerDes{}, CookieName: "test", } defer func() { if r := recover(); r != nil { if !tt.wantErr { t.Errorf("New() panicked unexpectedly: %v", r) } } else if tt.wantErr { t.Errorf("New() should have panicked but didn't") } }() auth := New(tt.privkey, config) if !tt.wantErr && auth.config.CookieName != "test" { t.Errorf("New() failed to create auth instance properly") } }) } } func TestAuth_Set(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} w := httptest.NewRecorder() err := auth.Set(w, data) if err != nil { t.Fatalf("Set() error = %v", err) } cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.Path != "/" { t.Errorf("Expected cookie path '/', got %s", cookie.Path) } if cookie.Domain != "example.com" { t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) } if !cookie.HttpOnly { t.Error("Expected cookie to be HttpOnly") } if cookie.SameSite != http.SameSiteLaxMode { t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) } } func TestAuth_Required(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authData, ok := auth.Get(r.Context()) if !ok { t.Error("Expected to find auth data in context") return } if authData.UserID != 123 || authData.Name != "Alice" { t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) } w.WriteHeader(http.StatusOK) }) protectedHandler := auth.Required(handler) t.Run("missing cookie", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status 401, got %d", w.Code) } }) t.Run("invalid cookie", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{ Name: "test_auth", Value: "invalid_token", }) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status 401, got %d", w.Code) } }) t.Run("valid cookie", func(t *testing.T) { setW := httptest.NewRecorder() if err := auth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } cookie := setW.Result().Cookies()[0] req := httptest.NewRequest("GET", "/", nil) req.AddCookie(cookie) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } }) } func TestAuth_Optional(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authData, ok := auth.Get(r.Context()) if ok { if authData.UserID != 123 || authData.Name != "Alice" { t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) } } w.WriteHeader(http.StatusOK) }) optionalHandler := auth.Optional(handler) t.Run("missing cookie - should allow", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() optionalHandler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } }) t.Run("invalid cookie - should allow", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{ Name: "test_auth", Value: "invalid_token", }) w := httptest.NewRecorder() optionalHandler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } }) t.Run("valid cookie", func(t *testing.T) { setW := httptest.NewRecorder() if err := auth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } cookie := setW.Result().Cookies()[0] req := httptest.NewRequest("GET", "/", nil) req.AddCookie(cookie) w := httptest.NewRecorder() optionalHandler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } }) } func TestAuth_Get(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} t.Run("no data in context", func(t *testing.T) { ctx := context.Background() result, ok := auth.Get(ctx) if ok { t.Error("Expected Get to return false for empty context") } if result.UserID != 0 || result.Name != "" { t.Error("Expected zero value for missing data") } }) t.Run("data in context", func(t *testing.T) { ctx := context.WithValue(context.Background(), key, data) result, ok := auth.Get(ctx) if !ok { t.Error("Expected Get to return true for context with data") } if result.UserID != 123 || result.Name != "Alice" { t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", result.UserID, result.Name) } }) t.Run("wrong type in context", func(t *testing.T) { ctx := context.WithValue(context.Background(), key, "wrong_type") result, ok := auth.Get(ctx) if ok { t.Error("Expected Get to return false for wrong type in context") } if result.UserID != 0 || result.Name != "" { t.Error("Expected zero value for wrong type") } }) } type failingSerDes struct { failSerialize bool failDeserialize bool } func (f failingSerDes) Serialize(w io.Writer, data testData) error { if f.failSerialize { return io.ErrClosedPipe } return testSerDes{}.Serialize(w, data) } func (f failingSerDes) Deserialize(r io.Reader, data *testData) error { if f.failDeserialize { return io.ErrUnexpectedEOF } return testSerDes{}.Deserialize(r, data) } func TestAuth_Set_SerializationError(t *testing.T) { config := AuthConfig[testData]{ SerDes: failingSerDes{failSerialize: true}, CookieName: "test_auth", } auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) data := testData{UserID: 123, Name: "Alice"} w := httptest.NewRecorder() err := auth.Set(w, data) if err == nil { t.Error("Expected Set to return error for serialization failure") } } func TestAuth_Required_DeserializationError(t *testing.T) { config := AuthConfig[testData]{ SerDes: failingSerDes{failDeserialize: true}, CookieName: "test_auth", } goodAuth := createTestAuth() badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) protectedHandler := badAuth.Required(handler) data := testData{UserID: 123, Name: "Alice"} setW := httptest.NewRecorder() if err := goodAuth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } cookie := setW.Result().Cookies()[0] req := httptest.NewRequest("GET", "/", nil) req.AddCookie(cookie) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) } } func TestAuth_Optional_DeserializationError(t *testing.T) { config := AuthConfig[testData]{ SerDes: failingSerDes{failDeserialize: true}, CookieName: "test_auth", } goodAuth := createTestAuth() badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) optionalHandler := badAuth.Optional(handler) data := testData{UserID: 123, Name: "Alice"} setW := httptest.NewRecorder() if err := goodAuth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } cookie := setW.Result().Cookies()[0] req := httptest.NewRequest("GET", "/", nil) req.AddCookie(cookie) w := httptest.NewRecorder() optionalHandler.ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) } } func TestAuth_Clear(t *testing.T) { auth := createTestAuth() t.Run("clear cookie properties", func(t *testing.T) { w := httptest.NewRecorder() auth.Clear(w) cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.Value != "" { t.Errorf("Expected empty cookie value, got %s", cookie.Value) } if cookie.Path != "/" { t.Errorf("Expected cookie path '/', got %s", cookie.Path) } if cookie.Domain != "example.com" { t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } if !cookie.HttpOnly { t.Error("Expected cookie to be HttpOnly") } if cookie.SameSite != http.SameSiteLaxMode { t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) } }) } func TestAuth_Clear_RemovesExistingHeaders(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} t.Run("clear removes existing Set-Cookie headers for same cookie", func(t *testing.T) { w := httptest.NewRecorder() // Set the auth cookie first err := auth.Set(w, data) if err != nil { t.Fatalf("Set() error = %v", err) } // Verify we have one Set-Cookie header setCookieHeaders := w.Header()["Set-Cookie"] if len(setCookieHeaders) != 1 { t.Fatalf("Expected 1 Set-Cookie header after Set(), got %d", len(setCookieHeaders)) } // Clear the auth cookie auth.Clear(w) // Verify we still have only one Set-Cookie header (the clear one) setCookieHeaders = w.Header()["Set-Cookie"] if len(setCookieHeaders) != 1 { t.Fatalf("Expected 1 Set-Cookie header after Clear(), got %d", len(setCookieHeaders)) } // Parse the remaining cookie to verify it's the clear cookie cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 parseable cookie, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } if cookie.Value != "" { t.Errorf("Expected empty cookie value, got %s", cookie.Value) } }) t.Run("clear does not affect other cookies", func(t *testing.T) { w := httptest.NewRecorder() // Add a different cookie manually otherCookie := &http.Cookie{ Name: "other_cookie", Value: "other_value", } w.Header().Add("Set-Cookie", otherCookie.String()) // Set the auth cookie err := auth.Set(w, data) if err != nil { t.Fatalf("Set() error = %v", err) } // Clear the auth cookie auth.Clear(w) // Verify we have two cookies: the other cookie and the clear auth cookie cookies := w.Result().Cookies() if len(cookies) != 2 { t.Fatalf("Expected 2 cookies after Clear(), got %d", len(cookies)) } // Find the other cookie var foundOtherCookie bool var foundClearCookie bool for _, cookie := range cookies { if cookie.Name == "other_cookie" && cookie.Value == "other_value" { foundOtherCookie = true } if cookie.Name == "test_auth" && cookie.MaxAge == -1 { foundClearCookie = true } } if !foundOtherCookie { t.Error("Expected to find other_cookie in response") } if !foundClearCookie { t.Error("Expected to find clear auth cookie in response") } }) t.Run("clear works when no existing headers", func(t *testing.T) { w := httptest.NewRecorder() // Clear without setting first auth.Clear(w) // Verify we have one Set-Cookie header (the clear one) cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie after Clear(), got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } }) } func TestRemoveSetCookieHeaders(t *testing.T) { t.Run("removes matching cookie headers", func(t *testing.T) { w := httptest.NewRecorder() // Add multiple cookies w.Header().Add("Set-Cookie", "test_auth=value1; Path=/") w.Header().Add("Set-Cookie", "other_cookie=value2; Path=/") w.Header().Add("Set-Cookie", "test_auth=value3; Path=/; Domain=example.com") removeSetCookieHeaders(w, "test_auth") cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie after removal, got %d", len(cookies)) } if cookies[0].Name != "other_cookie" { t.Errorf("Expected other_cookie to remain, got %s", cookies[0].Name) } }) t.Run("handles empty headers", func(t *testing.T) { w := httptest.NewRecorder() removeSetCookieHeaders(w, "test_auth") cookies := w.Result().Cookies() if len(cookies) != 0 { t.Fatalf("Expected 0 cookies, got %d", len(cookies)) } }) t.Run("handles malformed headers", func(t *testing.T) { w := httptest.NewRecorder() // Add a malformed header (no = sign) w.Header().Add("Set-Cookie", "malformed_header") w.Header().Add("Set-Cookie", "test_auth=value") removeSetCookieHeaders(w, "test_auth") // The malformed header should be preserved (check raw headers) setCookieHeaders := w.Header()["Set-Cookie"] if len(setCookieHeaders) != 1 || setCookieHeaders[0] != "malformed_header" { t.Errorf("Expected malformed header to be preserved, got %v", setCookieHeaders) } // The malformed header won't be parsed into a valid cookie by Go's parser cookies := w.Result().Cookies() if len(cookies) != 0 { t.Errorf("Expected 0 parseable cookies (malformed header), got %d", len(cookies)) } }) } func TestAuth_Required_ClearsInvalidCookies(t *testing.T) { auth := createTestAuth() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) protectedHandler := auth.Required(handler) t.Run("clears invalid cookie on decryption failure", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{ Name: "test_auth", Value: "invalid_encrypted_value", }) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status 401, got %d", w.Code) } // Verify that a clear cookie was set cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } }) t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { // Create a cookie with a different auth instance that will deserialize differently config := AuthConfig[testData]{ SerDes: failingSerDes{failDeserialize: true}, CookieName: "test_auth", } goodAuth := createTestAuth() badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) // Set a cookie with the good auth data := testData{UserID: 123, Name: "Alice"} setW := httptest.NewRecorder() if err := goodAuth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } validCookie := setW.Result().Cookies()[0] // Try to use it with the bad auth (will fail deserialization) req := httptest.NewRequest("GET", "/", nil) req.AddCookie(validCookie) w := httptest.NewRecorder() badAuth.Required(handler).ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Errorf("Expected status 500, got %d", w.Code) } // Verify that a clear cookie was set cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } }) } func TestAuth_Optional_ClearsInvalidCookies(t *testing.T) { auth := createTestAuth() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) optionalHandler := auth.Optional(handler) t.Run("clears invalid cookie on decryption failure and continues", func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.AddCookie(&http.Cookie{ Name: "test_auth", Value: "invalid_encrypted_value", }) w := httptest.NewRecorder() optionalHandler.ServeHTTP(w, req) // Should continue processing (200 OK) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } // Verify that a clear cookie was set cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } }) t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { // Create a cookie with a different auth instance that will deserialize differently config := AuthConfig[testData]{ SerDes: failingSerDes{failDeserialize: true}, CookieName: "test_auth", } goodAuth := createTestAuth() badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) // Set a cookie with the good auth data := testData{UserID: 123, Name: "Alice"} setW := httptest.NewRecorder() if err := goodAuth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } validCookie := setW.Result().Cookies()[0] // Try to use it with the bad auth (will fail deserialization) req := httptest.NewRequest("GET", "/", nil) req.AddCookie(validCookie) w := httptest.NewRecorder() badAuth.Optional(handler).ServeHTTP(w, req) if w.Code != http.StatusInternalServerError { t.Errorf("Expected status 500, got %d", w.Code) } // Verify that a clear cookie was set cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "test_auth" { t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) } if cookie.MaxAge != -1 { t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) } }) } func TestAuth_Clear_Integration(t *testing.T) { auth := createTestAuth() data := testData{UserID: 123, Name: "Alice"} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authData, ok := auth.Get(r.Context()) if !ok { t.Error("Expected to find auth data in context") return } if authData.UserID != 123 || authData.Name != "Alice" { t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) } w.WriteHeader(http.StatusOK) }) protectedHandler := auth.Required(handler) t.Run("authentication works before clear", func(t *testing.T) { setW := httptest.NewRecorder() if err := auth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } cookie := setW.Result().Cookies()[0] req := httptest.NewRequest("GET", "/", nil) req.AddCookie(cookie) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status 200, got %d", w.Code) } }) t.Run("authentication fails after clear", func(t *testing.T) { setW := httptest.NewRecorder() if err := auth.Set(setW, data); err != nil { t.Fatalf("Set() error = %v", err) } originalCookie := setW.Result().Cookies()[0] clearW := httptest.NewRecorder() auth.Clear(clearW) clearCookie := clearW.Result().Cookies()[0] if clearCookie.MaxAge != -1 { t.Errorf("Expected clear cookie to have MaxAge -1, got %d", clearCookie.MaxAge) } req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() protectedHandler.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status 401 when no cookie is present (simulating browser behavior after clear), got %d", w.Code) } req2 := httptest.NewRequest("GET", "/", nil) req2.AddCookie(originalCookie) w2 := httptest.NewRecorder() protectedHandler.ServeHTTP(w2, req2) if w2.Code != http.StatusOK { t.Errorf("Expected original cookie to still work, got %d", w2.Code) } }) }