summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorT <t@tjp.lol>2025-07-05 09:41:47 -0600
committerT <t@tjp.lol>2025-07-05 09:55:47 -0600
commitcaf5bb2ee84079365996a622ab8fc5ed510ef9a7 (patch)
tree5caf8bcbcda5ab5c8d70782fb733923c0ac70af3
parent639ad6a02cbb4b713434671ec09f309aa5410921 (diff)
Add Auth.Clear() method and enhance middlewares to clear invalid cookiesHEADmain
- Add Auth.Clear() method that creates expired cookies (MaxAge: -1) to log out users - Enhance Clear() to remove existing Set-Cookie headers to comply with RFC 6265 - Update Required() and Optional() middlewares to automatically clear invalid cookies - Add comprehensive tests for all new functionality - Update documentation with Auth.Clear() usage examples
-rw-r--r--README.md3
-rw-r--r--auth.go62
-rw-r--r--auth_test.go440
-rw-r--r--doc.go3
4 files changed, 508 insertions, 0 deletions
diff --git a/README.md b/README.md
index d5d7882..da36cad 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,9 @@ http.Handle("GET /protected", auth.Required(protectedHandler))
user := UserData{ID: 123, Name: "John"}
auth.Set(w, user)
+// Clear authentication cookie (logout)
+auth.Clear(w)
+
// Get authenticated data
user, ok := auth.Get(r.Context())
```
diff --git a/auth.go b/auth.go
index 1066d9f..ac1fc57 100644
--- a/auth.go
+++ b/auth.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
+ "strings"
"time"
)
@@ -78,12 +79,14 @@ func (a Auth[T]) Required(handler http.Handler) http.Handler {
cleartext, ok := a.enc.Decrypt(cookie.Value)
if !ok {
+ a.Clear(w)
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}
var data T
if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil {
+ a.Clear(w)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
@@ -108,12 +111,14 @@ func (a Auth[T]) Optional(handler http.Handler) http.Handler {
cleartext, ok := a.enc.Decrypt(cookie.Value)
if !ok {
+ a.Clear(w)
handler.ServeHTTP(w, r)
return
}
var data T
if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil {
+ a.Clear(w)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
@@ -145,6 +150,63 @@ func (a Auth[T]) Set(w http.ResponseWriter, data T) error {
return nil
}
+// removeSetCookieHeaders removes any existing Set-Cookie headers that match the given cookie name.
+// This ensures we don't send multiple Set-Cookie headers with the same cookie name, which
+// violates RFC 6265 recommendations.
+func removeSetCookieHeaders(w http.ResponseWriter, cookieName string) {
+ headers := w.Header()
+ setCookieHeaders := headers["Set-Cookie"]
+ if len(setCookieHeaders) == 0 {
+ return
+ }
+
+ // Filter out headers that match our cookie name
+ var filteredHeaders []string
+ for _, header := range setCookieHeaders {
+ // Parse the cookie name from the Set-Cookie header
+ // Format: "name=value; other=attributes"
+ if idx := strings.Index(header, "="); idx > 0 {
+ headerCookieName := strings.TrimSpace(header[:idx])
+ if headerCookieName != cookieName {
+ filteredHeaders = append(filteredHeaders, header)
+ }
+ } else {
+ // Keep malformed headers as-is
+ filteredHeaders = append(filteredHeaders, header)
+ }
+ }
+
+ // Replace the Set-Cookie headers with the filtered list
+ if len(filteredHeaders) == 0 {
+ headers.Del("Set-Cookie")
+ } else {
+ headers["Set-Cookie"] = filteredHeaders
+ }
+}
+
+// Clear removes the authentication cookie by setting it to expire immediately.
+//
+// This effectively logs out the user by invalidating their authentication cookie.
+// If there are existing Set-Cookie headers for the same cookie name, they are removed
+// to comply with RFC 6265 recommendations against multiple Set-Cookie headers with
+// the same cookie name.
+func (a Auth[T]) Clear(w http.ResponseWriter) {
+ // Remove any existing Set-Cookie headers for this cookie name
+ removeSetCookieHeaders(w, a.config.CookieName)
+
+ cookie := &http.Cookie{
+ Name: a.config.CookieName,
+ Value: "",
+ Path: a.config.URLPath,
+ Domain: a.config.URLDomain,
+ MaxAge: -1,
+ Secure: a.config.HTTPSOnly,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }
+ w.Header().Add("Set-Cookie", cookie.String())
+}
+
// Get retrieves authentication data from the request context.
//
// Returns the data and true if authentication data is present and valid,
diff --git a/auth_test.go b/auth_test.go
index 131e132..9b679bf 100644
--- a/auth_test.go
+++ b/auth_test.go
@@ -389,3 +389,443 @@ func TestAuth_Optional_DeserializationError(t *testing.T) {
t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
}
}
+
+func TestAuth_Clear(t *testing.T) {
+ auth := createTestAuth()
+
+ t.Run("clear cookie properties", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ auth.Clear(w)
+
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 cookie, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.Value != "" {
+ t.Errorf("Expected empty cookie value, got %s", cookie.Value)
+ }
+ if cookie.Path != "/" {
+ t.Errorf("Expected cookie path '/', got %s", cookie.Path)
+ }
+ if cookie.Domain != "example.com" {
+ t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ if !cookie.HttpOnly {
+ t.Error("Expected cookie to be HttpOnly")
+ }
+ if cookie.SameSite != http.SameSiteLaxMode {
+ t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
+ }
+ })
+}
+
+func TestAuth_Clear_RemovesExistingHeaders(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ t.Run("clear removes existing Set-Cookie headers for same cookie", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Set the auth cookie first
+ err := auth.Set(w, data)
+ if err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+
+ // Verify we have one Set-Cookie header
+ setCookieHeaders := w.Header()["Set-Cookie"]
+ if len(setCookieHeaders) != 1 {
+ t.Fatalf("Expected 1 Set-Cookie header after Set(), got %d", len(setCookieHeaders))
+ }
+
+ // Clear the auth cookie
+ auth.Clear(w)
+
+ // Verify we still have only one Set-Cookie header (the clear one)
+ setCookieHeaders = w.Header()["Set-Cookie"]
+ if len(setCookieHeaders) != 1 {
+ t.Fatalf("Expected 1 Set-Cookie header after Clear(), got %d", len(setCookieHeaders))
+ }
+
+ // Parse the remaining cookie to verify it's the clear cookie
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 parseable cookie, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ if cookie.Value != "" {
+ t.Errorf("Expected empty cookie value, got %s", cookie.Value)
+ }
+ })
+
+ t.Run("clear does not affect other cookies", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Add a different cookie manually
+ otherCookie := &http.Cookie{
+ Name: "other_cookie",
+ Value: "other_value",
+ }
+ w.Header().Add("Set-Cookie", otherCookie.String())
+
+ // Set the auth cookie
+ err := auth.Set(w, data)
+ if err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+
+ // Clear the auth cookie
+ auth.Clear(w)
+
+ // Verify we have two cookies: the other cookie and the clear auth cookie
+ cookies := w.Result().Cookies()
+ if len(cookies) != 2 {
+ t.Fatalf("Expected 2 cookies after Clear(), got %d", len(cookies))
+ }
+
+ // Find the other cookie
+ var foundOtherCookie bool
+ var foundClearCookie bool
+ for _, cookie := range cookies {
+ if cookie.Name == "other_cookie" && cookie.Value == "other_value" {
+ foundOtherCookie = true
+ }
+ if cookie.Name == "test_auth" && cookie.MaxAge == -1 {
+ foundClearCookie = true
+ }
+ }
+
+ if !foundOtherCookie {
+ t.Error("Expected to find other_cookie in response")
+ }
+ if !foundClearCookie {
+ t.Error("Expected to find clear auth cookie in response")
+ }
+ })
+
+ t.Run("clear works when no existing headers", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Clear without setting first
+ auth.Clear(w)
+
+ // Verify we have one Set-Cookie header (the clear one)
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 cookie after Clear(), got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ })
+}
+
+func TestRemoveSetCookieHeaders(t *testing.T) {
+ t.Run("removes matching cookie headers", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Add multiple cookies
+ w.Header().Add("Set-Cookie", "test_auth=value1; Path=/")
+ w.Header().Add("Set-Cookie", "other_cookie=value2; Path=/")
+ w.Header().Add("Set-Cookie", "test_auth=value3; Path=/; Domain=example.com")
+
+ removeSetCookieHeaders(w, "test_auth")
+
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 cookie after removal, got %d", len(cookies))
+ }
+
+ if cookies[0].Name != "other_cookie" {
+ t.Errorf("Expected other_cookie to remain, got %s", cookies[0].Name)
+ }
+ })
+
+ t.Run("handles empty headers", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ removeSetCookieHeaders(w, "test_auth")
+
+ cookies := w.Result().Cookies()
+ if len(cookies) != 0 {
+ t.Fatalf("Expected 0 cookies, got %d", len(cookies))
+ }
+ })
+
+ t.Run("handles malformed headers", func(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Add a malformed header (no = sign)
+ w.Header().Add("Set-Cookie", "malformed_header")
+ w.Header().Add("Set-Cookie", "test_auth=value")
+
+ removeSetCookieHeaders(w, "test_auth")
+
+ // The malformed header should be preserved (check raw headers)
+ setCookieHeaders := w.Header()["Set-Cookie"]
+ if len(setCookieHeaders) != 1 || setCookieHeaders[0] != "malformed_header" {
+ t.Errorf("Expected malformed header to be preserved, got %v", setCookieHeaders)
+ }
+
+ // The malformed header won't be parsed into a valid cookie by Go's parser
+ cookies := w.Result().Cookies()
+ if len(cookies) != 0 {
+ t.Errorf("Expected 0 parseable cookies (malformed header), got %d", len(cookies))
+ }
+ })
+}
+
+func TestAuth_Required_ClearsInvalidCookies(t *testing.T) {
+ auth := createTestAuth()
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+ protectedHandler := auth.Required(handler)
+
+ t.Run("clears invalid cookie on decryption failure", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_encrypted_value",
+ })
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401, got %d", w.Code)
+ }
+
+ // Verify that a clear cookie was set
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ })
+
+ t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) {
+ // Create a cookie with a different auth instance that will deserialize differently
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ // Set a cookie with the good auth
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ validCookie := setW.Result().Cookies()[0]
+
+ // Try to use it with the bad auth (will fail deserialization)
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(validCookie)
+ w := httptest.NewRecorder()
+
+ badAuth.Required(handler).ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500, got %d", w.Code)
+ }
+
+ // Verify that a clear cookie was set
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ })
+}
+
+func TestAuth_Optional_ClearsInvalidCookies(t *testing.T) {
+ auth := createTestAuth()
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+ optionalHandler := auth.Optional(handler)
+
+ t.Run("clears invalid cookie on decryption failure and continues", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_encrypted_value",
+ })
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ // Should continue processing (200 OK)
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+
+ // Verify that a clear cookie was set
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ })
+
+ t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) {
+ // Create a cookie with a different auth instance that will deserialize differently
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ // Set a cookie with the good auth
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ validCookie := setW.Result().Cookies()[0]
+
+ // Try to use it with the bad auth (will fail deserialization)
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(validCookie)
+ w := httptest.NewRecorder()
+
+ badAuth.Optional(handler).ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500, got %d", w.Code)
+ }
+
+ // Verify that a clear cookie was set
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.MaxAge != -1 {
+ t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge)
+ }
+ })
+}
+
+func TestAuth_Clear_Integration(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authData, ok := auth.Get(r.Context())
+ if !ok {
+ t.Error("Expected to find auth data in context")
+ return
+ }
+ if authData.UserID != 123 || authData.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name)
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ protectedHandler := auth.Required(handler)
+
+ t.Run("authentication works before clear", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+
+ t.Run("authentication fails after clear", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ originalCookie := setW.Result().Cookies()[0]
+
+ clearW := httptest.NewRecorder()
+ auth.Clear(clearW)
+ clearCookie := clearW.Result().Cookies()[0]
+
+ if clearCookie.MaxAge != -1 {
+ t.Errorf("Expected clear cookie to have MaxAge -1, got %d", clearCookie.MaxAge)
+ }
+
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401 when no cookie is present (simulating browser behavior after clear), got %d", w.Code)
+ }
+
+ req2 := httptest.NewRequest("GET", "/", nil)
+ req2.AddCookie(originalCookie)
+ w2 := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w2, req2)
+
+ if w2.Code != http.StatusOK {
+ t.Errorf("Expected original cookie to still work, got %d", w2.Code)
+ }
+ })
+}
diff --git a/doc.go b/doc.go
index 8d60f2a..12c314d 100644
--- a/doc.go
+++ b/doc.go
@@ -39,6 +39,9 @@
// user := UserData{ID: 123, Name: "John"}
// auth.Set(w, user)
//
+// // Clear authentication cookie (logout)
+// auth.Clear(w)
+//
// // Get authenticated data
// user, ok := auth.Get(r.Context())
//