diff options
Diffstat (limited to 'auth_test.go')
-rw-r--r-- | auth_test.go | 440 |
1 files changed, 440 insertions, 0 deletions
diff --git a/auth_test.go b/auth_test.go index 131e132..9b679bf 100644 --- a/auth_test.go +++ b/auth_test.go @@ -389,3 +389,443 @@ func TestAuth_Optional_DeserializationError(t *testing.T) { t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) } } + +func TestAuth_Clear(t *testing.T) { + auth := createTestAuth() + + t.Run("clear cookie properties", func(t *testing.T) { + w := httptest.NewRecorder() + auth.Clear(w) + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.Value != "" { + t.Errorf("Expected empty cookie value, got %s", cookie.Value) + } + if cookie.Path != "/" { + t.Errorf("Expected cookie path '/', got %s", cookie.Path) + } + if cookie.Domain != "example.com" { + t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + if !cookie.HttpOnly { + t.Error("Expected cookie to be HttpOnly") + } + if cookie.SameSite != http.SameSiteLaxMode { + t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) + } + }) +} + +func TestAuth_Clear_RemovesExistingHeaders(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + t.Run("clear removes existing Set-Cookie headers for same cookie", func(t *testing.T) { + w := httptest.NewRecorder() + + // Set the auth cookie first + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + // Verify we have one Set-Cookie header + setCookieHeaders := w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 { + t.Fatalf("Expected 1 Set-Cookie header after Set(), got %d", len(setCookieHeaders)) + } + + // Clear the auth cookie + auth.Clear(w) + + // Verify we still have only one Set-Cookie header (the clear one) + setCookieHeaders = w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 { + t.Fatalf("Expected 1 Set-Cookie header after Clear(), got %d", len(setCookieHeaders)) + } + + // Parse the remaining cookie to verify it's the clear cookie + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 parseable cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + if cookie.Value != "" { + t.Errorf("Expected empty cookie value, got %s", cookie.Value) + } + }) + + t.Run("clear does not affect other cookies", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add a different cookie manually + otherCookie := &http.Cookie{ + Name: "other_cookie", + Value: "other_value", + } + w.Header().Add("Set-Cookie", otherCookie.String()) + + // Set the auth cookie + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + // Clear the auth cookie + auth.Clear(w) + + // Verify we have two cookies: the other cookie and the clear auth cookie + cookies := w.Result().Cookies() + if len(cookies) != 2 { + t.Fatalf("Expected 2 cookies after Clear(), got %d", len(cookies)) + } + + // Find the other cookie + var foundOtherCookie bool + var foundClearCookie bool + for _, cookie := range cookies { + if cookie.Name == "other_cookie" && cookie.Value == "other_value" { + foundOtherCookie = true + } + if cookie.Name == "test_auth" && cookie.MaxAge == -1 { + foundClearCookie = true + } + } + + if !foundOtherCookie { + t.Error("Expected to find other_cookie in response") + } + if !foundClearCookie { + t.Error("Expected to find clear auth cookie in response") + } + }) + + t.Run("clear works when no existing headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Clear without setting first + auth.Clear(w) + + // Verify we have one Set-Cookie header (the clear one) + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie after Clear(), got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestRemoveSetCookieHeaders(t *testing.T) { + t.Run("removes matching cookie headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add multiple cookies + w.Header().Add("Set-Cookie", "test_auth=value1; Path=/") + w.Header().Add("Set-Cookie", "other_cookie=value2; Path=/") + w.Header().Add("Set-Cookie", "test_auth=value3; Path=/; Domain=example.com") + + removeSetCookieHeaders(w, "test_auth") + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie after removal, got %d", len(cookies)) + } + + if cookies[0].Name != "other_cookie" { + t.Errorf("Expected other_cookie to remain, got %s", cookies[0].Name) + } + }) + + t.Run("handles empty headers", func(t *testing.T) { + w := httptest.NewRecorder() + + removeSetCookieHeaders(w, "test_auth") + + cookies := w.Result().Cookies() + if len(cookies) != 0 { + t.Fatalf("Expected 0 cookies, got %d", len(cookies)) + } + }) + + t.Run("handles malformed headers", func(t *testing.T) { + w := httptest.NewRecorder() + + // Add a malformed header (no = sign) + w.Header().Add("Set-Cookie", "malformed_header") + w.Header().Add("Set-Cookie", "test_auth=value") + + removeSetCookieHeaders(w, "test_auth") + + // The malformed header should be preserved (check raw headers) + setCookieHeaders := w.Header()["Set-Cookie"] + if len(setCookieHeaders) != 1 || setCookieHeaders[0] != "malformed_header" { + t.Errorf("Expected malformed header to be preserved, got %v", setCookieHeaders) + } + + // The malformed header won't be parsed into a valid cookie by Go's parser + cookies := w.Result().Cookies() + if len(cookies) != 0 { + t.Errorf("Expected 0 parseable cookies (malformed header), got %d", len(cookies)) + } + }) +} + +func TestAuth_Required_ClearsInvalidCookies(t *testing.T) { + auth := createTestAuth() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + protectedHandler := auth.Required(handler) + + t.Run("clears invalid cookie on decryption failure", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_encrypted_value", + }) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) + + t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { + // Create a cookie with a different auth instance that will deserialize differently + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + // Set a cookie with the good auth + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + validCookie := setW.Result().Cookies()[0] + + // Try to use it with the bad auth (will fail deserialization) + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(validCookie) + w := httptest.NewRecorder() + + badAuth.Required(handler).ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestAuth_Optional_ClearsInvalidCookies(t *testing.T) { + auth := createTestAuth() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + optionalHandler := auth.Optional(handler) + + t.Run("clears invalid cookie on decryption failure and continues", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_encrypted_value", + }) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + // Should continue processing (200 OK) + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) + + t.Run("clears invalid cookie on deserialization failure", func(t *testing.T) { + // Create a cookie with a different auth instance that will deserialize differently + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + // Set a cookie with the good auth + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + validCookie := setW.Result().Cookies()[0] + + // Try to use it with the bad auth (will fail deserialization) + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(validCookie) + w := httptest.NewRecorder() + + badAuth.Optional(handler).ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", w.Code) + } + + // Verify that a clear cookie was set + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 clear cookie to be set, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie MaxAge -1, got %d", cookie.MaxAge) + } + }) +} + +func TestAuth_Clear_Integration(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if !ok { + t.Error("Expected to find auth data in context") + return + } + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := auth.Required(handler) + + t.Run("authentication works before clear", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("authentication fails after clear", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + originalCookie := setW.Result().Cookies()[0] + + clearW := httptest.NewRecorder() + auth.Clear(clearW) + clearCookie := clearW.Result().Cookies()[0] + + if clearCookie.MaxAge != -1 { + t.Errorf("Expected clear cookie to have MaxAge -1, got %d", clearCookie.MaxAge) + } + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401 when no cookie is present (simulating browser behavior after clear), got %d", w.Code) + } + + req2 := httptest.NewRequest("GET", "/", nil) + req2.AddCookie(originalCookie) + w2 := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w2, req2) + + if w2.Code != http.StatusOK { + t.Errorf("Expected original cookie to still work, got %d", w2.Code) + } + }) +} |