diff options
author | T <t@tjp.lol> | 2025-07-05 09:41:47 -0600 |
---|---|---|
committer | T <t@tjp.lol> | 2025-07-05 09:55:47 -0600 |
commit | caf5bb2ee84079365996a622ab8fc5ed510ef9a7 (patch) | |
tree | 5caf8bcbcda5ab5c8d70782fb733923c0ac70af3 | |
parent | 639ad6a02cbb4b713434671ec09f309aa5410921 (diff) |
- Add Auth.Clear() method that creates expired cookies (MaxAge: -1) to log out users
- Enhance Clear() to remove existing Set-Cookie headers to comply with RFC 6265
- Update Required() and Optional() middlewares to automatically clear invalid cookies
- Add comprehensive tests for all new functionality
- Update documentation with Auth.Clear() usage examples
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | auth.go | 62 | ||||
-rw-r--r-- | auth_test.go | 440 | ||||
-rw-r--r-- | doc.go | 3 |
4 files changed, 508 insertions, 0 deletions
@@ -42,6 +42,9 @@ http.Handle("GET /protected", auth.Required(protectedHandler)) user := UserData{ID: 123, Name: "John"} auth.Set(w, user) +// Clear authentication cookie (logout) +auth.Clear(w) + // Get authenticated data user, ok := auth.Get(r.Context()) ``` @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" ) @@ -78,12 +79,14 @@ func (a Auth[T]) Required(handler http.Handler) http.Handler { cleartext, ok := a.enc.Decrypt(cookie.Value) if !ok { + a.Clear(w) http.Error(w, "Authentication failed", http.StatusUnauthorized) return } var data T if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil { + a.Clear(w) http.Error(w, "Server error", http.StatusInternalServerError) return } @@ -108,12 +111,14 @@ func (a Auth[T]) Optional(handler http.Handler) http.Handler { cleartext, ok := a.enc.Decrypt(cookie.Value) if !ok { + a.Clear(w) handler.ServeHTTP(w, r) return } var data T if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil { + a.Clear(w) http.Error(w, "Server error", http.StatusInternalServerError) return } @@ -145,6 +150,63 @@ func (a Auth[T]) Set(w http.ResponseWriter, data T) error { return nil } +// removeSetCookieHeaders removes any existing Set-Cookie headers that match the given cookie name. +// This ensures we don't send multiple Set-Cookie headers with the same cookie name, which +// violates RFC 6265 recommendations. +func removeSetCookieHeaders(w http.ResponseWriter, cookieName string) { + headers := w.Header() + setCookieHeaders := headers["Set-Cookie"] + if len(setCookieHeaders) == 0 { + return + } + + // Filter out headers that match our cookie name + var filteredHeaders []string + for _, header := range setCookieHeaders { + // Parse the cookie name from the Set-Cookie header + // Format: "name=value; other=attributes" + if idx := strings.Index(header, "="); idx > 0 { + headerCookieName := strings.TrimSpace(header[:idx]) + if headerCookieName != cookieName { + filteredHeaders = append(filteredHeaders, header) + } + } else { + // Keep malformed headers as-is + filteredHeaders = append(filteredHeaders, header) + } + } + + // Replace the Set-Cookie headers with the filtered list + if len(filteredHeaders) == 0 { + headers.Del("Set-Cookie") + } else { + headers["Set-Cookie"] = filteredHeaders + } +} + +// Clear removes the authentication cookie by setting it to expire immediately. +// +// This effectively logs out the user by invalidating their authentication cookie. +// If there are existing Set-Cookie headers for the same cookie name, they are removed +// to comply with RFC 6265 recommendations against multiple Set-Cookie headers with +// the same cookie name. +func (a Auth[T]) Clear(w http.ResponseWriter) { + // Remove any existing Set-Cookie headers for this cookie name + removeSetCookieHeaders(w, a.config.CookieName) + + cookie := &http.Cookie{ + Name: a.config.CookieName, + Value: "", + Path: a.config.URLPath, + Domain: a.config.URLDomain, + MaxAge: -1, + Secure: a.config.HTTPSOnly, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + w.Header().Add("Set-Cookie", cookie.String()) +} + // Get retrieves authentication data from the request context. // // Returns the data and true if authentication data is present and valid, diff --git a/auth_test.go b/auth_test.go index 131e132..9b679bf 100644 --- a/auth_test.go +++ b/auth_test.go @@ -389,3 +389,443 @@ func TestAuth_Optional_DeserializationError(t *testing.T) { t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) } } + +func TestAuth_Clear(t *testing.T) { + auth := createTestAuth() + + t.Run("clear cookie properties", func(t *testing.T) { + w := httptest.NewRecorder() + auth.Clear(w) + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.Value != "" { + t.Errorf("Expected empty cookie value, got %s", cookie.Value) + } + if cookie.Path != "/" { + t.Errorf("Expected cookie path '/', got %s", cookie.Path) + } + if cookie.Domain != "example.com" { + t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + if !cookie.HttpOnly { + t.Error("Expected cookie to be HttpOnly") + } + if cookie.SameSite != http.SameSiteLaxMode { + t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) + } + }) +} + +func TestAuth_Clear_RemovesExistingHeaders(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + t.Run("clear removes existing Set-Cookie headers for same cookie", func(t *testing.T) { + w := httptest.NewRecorder() + + // Set the auth cookie first + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + // Verify we have one Set-Cookie header + setCookieHeaders := w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 { + t.Fatalf("Expected 1 Set-Cookie header after Set(), got %d", len(setCookieHeaders)) + } + + // Clear the auth cookie + auth.Clear(w) + + // Verify we still have only one Set-Cookie header (the clear one) + setCookieHeaders = w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 { + t.Fatalf("Expected 1 Set-Cookie header after Clear(), got %d", len(setCookieHeaders)) + } + + // Parse the remaining cookie to verify it's the clear cookie + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 parseable cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + if cookie.Value != "" { + t.Errorf("Expected empty cookie value, got %s", cookie.Value) + } + }) + + t.Run("clear does not affect other cookies", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add a different cookie manually + otherCookie := &http.Cookie{ + Name: "other_cookie", + Value: "other_value", + } + w.Header().Add("Set-Cookie", otherCookie.String()) + + // Set the auth cookie + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + // Clear the auth cookie + auth.Clear(w) + + // Verify we have two cookies: the other cookie and the clear auth cookie + cookies := w.Result().Cookies() + if len(cookies) != 2 { + t.Fatalf("Expected 2 cookies after Clear(), got %d", len(cookies)) + } + + // Find the other cookie + var foundOtherCookie bool + var foundClearCookie bool + for _, cookie := range cookies { + if cookie.Name == "other_cookie" && cookie.Value == "other_value" { + foundOtherCookie = true + } + if cookie.Name == "test_auth" && cookie.MaxAge == -1 { + foundClearCookie = true + } + } + + if !foundOtherCookie { + t.Error("Expected to find other_cookie in response") + } + if !foundClearCookie { + t.Error("Expected to find clear auth cookie in response") + } + }) + + t.Run("clear works when no existing headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Clear without setting first + auth.Clear(w) + + // Verify we have one Set-Cookie header (the clear one) + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie after Clear(), got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestRemoveSetCookieHeaders(t *testing.T) { + t.Run("removes matching cookie headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add multiple cookies + w.Header().Add("Set-Cookie", "test_auth=value1; Path=/") + w.Header().Add("Set-Cookie", "other_cookie=value2; Path=/") + w.Header().Add("Set-Cookie", "test_auth=value3; Path=/; Domain=example.com") + + removeSetCookieHeaders(w, "test_auth") + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie after removal, got %d", len(cookies)) + } + + if cookies[0].Name != "other_cookie" { + t.Errorf("Expected other_cookie to remain, got %s", cookies[0].Name) + } + }) + + t.Run("handles empty headers", func(t *testing.T) { + w := httptest.NewRecorder() + + removeSetCookieHeaders(w, "test_auth") + + cookies := w.Result().Cookies() + if len(cookies) != 0 { + t.Fatalf("Expected 0 cookies, got %d", len(cookies)) + } + }) + + t.Run("handles malformed headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add a malformed header (no = sign) + w.Header().Add("Set-Cookie", "malformed_header") + w.Header().Add("Set-Cookie", "test_auth=value") + + removeSetCookieHeaders(w, "test_auth") + + // The malformed header should be preserved (check raw headers) + setCookieHeaders := w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 || setCookieHeaders[0] != "malformed_header" { + t.Errorf("Expected malformed header to be preserved, got %v", setCookieHeaders) + } + + // The malformed header won't be parsed into a valid cookie by Go's parser + cookies := w.Result().Cookies() + if len(cookies) != 0 { + t.Errorf("Expected 0 parseable cookies (malformed header), got %d", len(cookies)) + } + }) +} + +func TestAuth_Required_ClearsInvalidCookies(t *testing.T) { + auth := createTestAuth() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + protectedHandler := auth.Required(handler) + + t.Run("clears invalid cookie on decryption failure", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_encrypted_value", + }) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) + + t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { + // Create a cookie with a different auth instance that will deserialize differently + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + // Set a cookie with the good auth + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + validCookie := setW.Result().Cookies()[0] + + // Try to use it with the bad auth (will fail deserialization) + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(validCookie) + w := httptest.NewRecorder() + + badAuth.Required(handler).ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestAuth_Optional_ClearsInvalidCookies(t *testing.T) { + auth := createTestAuth() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + optionalHandler := auth.Optional(handler) + + t.Run("clears invalid cookie on decryption failure and continues", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_encrypted_value", + }) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + // Should continue processing (200 OK) + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) + + t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { + // Create a cookie with a different auth instance that will deserialize differently + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + // Set a cookie with the good auth + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + validCookie := setW.Result().Cookies()[0] + + // Try to use it with the bad auth (will fail deserialization) + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(validCookie) + w := httptest.NewRecorder() + + badAuth.Optional(handler).ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestAuth_Clear_Integration(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if !ok { + t.Error("Expected to find auth data in context") + return + } + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := auth.Required(handler) + + t.Run("authentication works before clear", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("authentication fails after clear", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + originalCookie := setW.Result().Cookies()[0] + + clearW := httptest.NewRecorder() + auth.Clear(clearW) + clearCookie := clearW.Result().Cookies()[0] + + if clearCookie.MaxAge != -1 { + t.Errorf("Expected clear cookie to have MaxAge -1, got %d", clearCookie.MaxAge) + } + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401 when no cookie is present (simulating browser behavior after clear), got %d", w.Code) + } + + req2 := httptest.NewRequest("GET", "/", nil) + req2.AddCookie(originalCookie) + w2 := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Errorf("Expected original cookie to still work, got %d", w2.Code) + } + }) +} @@ -39,6 +39,9 @@ // user := UserData{ID: 123, Name: "John"} // auth.Set(w, user) // +// // Clear authentication cookie (logout) +// auth.Clear(w) +// // // Get authenticated data // user, ok := auth.Get(r.Context()) // |