package kate

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"
)

// testPrivKey is the 64-character hex key used by every test that needs a
// working Auth; TestNew's "valid hex key" case shows that New accepts it.
const testPrivKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

// testData is the payload stored in the auth cookie during tests.
type testData struct {
	UserID int
	Name   string
}

// testSerDes encodes testData as "<decimal UserID>|<Name>".
type testSerDes struct{}

// Serialize writes data to w as "<decimal UserID>|<Name>".
func (ts testSerDes) Serialize(w io.Writer, data testData) error {
	// Decimal encoding round-trips every int. Encoding the ID as a single rune
	// would corrupt IDs >= 128 (multi-byte UTF-8), negative IDs, and ID 124,
	// which is the '|' separator itself.
	_, err := io.WriteString(w, strconv.Itoa(data.UserID)+"|"+data.Name)
	return err
}

// Deserialize parses the "<decimal UserID>|<Name>" form written by Serialize
// into data. It returns io.ErrUnexpectedEOF when the separator is missing and
// the strconv error when the UserID segment is not a valid integer; data is
// only modified on success.
func (ts testSerDes) Deserialize(r io.Reader, data *testData) error {
	buf, err := io.ReadAll(r)
	if err != nil {
		return err
	}
	// Split at the first '|' only, so names may themselves contain '|'.
	idPart, name, found := strings.Cut(string(buf), "|")
	if !found {
		return io.ErrUnexpectedEOF
	}
	id, err := strconv.Atoi(idPart)
	if err != nil {
		return err
	}
	data.UserID = id
	data.Name = name
	return nil
}

// createTestAuth returns an Auth using testSerDes with cookie name
// "test_auth", path "/", domain "example.com", HTTPSOnly disabled and a
// one-hour MaxAge.
func createTestAuth() Auth[testData] {
	config := AuthConfig[testData]{
		SerDes:     testSerDes{},
		CookieName: "test_auth",
		URLPath:    "/",
		URLDomain:  "example.com",
		HTTPSOnly:  false,
		MaxAge:     time.Hour,
	}
	return New(testPrivKey, config)
}

// issueCookie calls auth.Set for data and returns the single cookie it wrote,
// failing the test if Set errors or does not emit exactly one cookie.
func issueCookie(t *testing.T, auth Auth[testData], data testData) *http.Cookie {
	t.Helper()
	rec := httptest.NewRecorder()
	if err := auth.Set(rec, data); err != nil {
		t.Fatalf("Set() error = %v", err)
	}
	cookies := rec.Result().Cookies()
	if len(cookies) != 1 {
		t.Fatalf("Expected 1 cookie, got %d", len(cookies))
	}
	return cookies[0]
}

// newTestRequest returns a GET / request carrying cookie, or no cookie when
// cookie is nil.
func newTestRequest(cookie *http.Cookie) *http.Request {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	if cookie != nil {
		req.AddCookie(cookie)
	}
	return req
}

// observation records what the wrapped handler saw when a middleware let a
// request through. Handlers record instead of asserting so that failures are
// reported by the subtest that triggered them, and so tests can tell whether
// the handler ran at all (a ResponseRecorder reports 200 even if nothing was
// written).
type observation struct {
	called bool
	data   testData
	ok     bool
}

// observingHandler returns a handler that stores in obs that it was called and
// the result of auth.Get on the request context, then responds 200.
func observingHandler(auth Auth[testData], obs *observation) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		obs.called = true
		obs.data, obs.ok = auth.Get(r.Context())
		w.WriteHeader(http.StatusOK)
	})
}

// serveObserved resets obs, runs req through h and returns the recorded
// response.
func serveObserved(h http.Handler, obs *observation, req *http.Request) *httptest.ResponseRecorder {
	*obs = observation{}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)
	return w
}

// expectAuthData fails t unless the handler ran and Get returned want.
func expectAuthData(t *testing.T, obs observation, want testData) {
	t.Helper()
	switch {
	case !obs.called:
		t.Error("Expected handler to be called")
	case !obs.ok:
		t.Error("Expected to find auth data in context")
	case obs.data != want:
		t.Errorf("Expected UserID=%d, Name=%q, got UserID=%d, Name=%q",
			want.UserID, want.Name, obs.data.UserID, obs.data.Name)
	}
}

// expectAnonymous fails t unless the handler ran with no auth data in context.
func expectAnonymous(t *testing.T, obs observation) {
	t.Helper()
	if !obs.called {
		t.Error("Expected handler to be called")
	} else if obs.ok {
		t.Errorf("Expected no auth data in context, got %+v", obs.data)
	}
}

// TestNew verifies that New builds an Auth from a valid hex key and panics on
// keys that are too short or not hex.
func TestNew(t *testing.T) {
	tests := []struct {
		name      string
		privkey   string
		wantPanic bool
	}{
		{
			name:      "valid hex key",
			privkey:   testPrivKey,
			wantPanic: false,
		},
		{
			name:      "invalid hex key - too short",
			privkey:   "0123456789abcdef",
			wantPanic: true,
		},
		{
			name:      "invalid hex key - not hex",
			privkey:   "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
			wantPanic: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := AuthConfig[testData]{
				SerDes:     testSerDes{},
				CookieName: "test",
			}

			// New signals a bad key by panicking, so the outcome is checked in
			// a deferred recover rather than through a returned error.
			defer func() {
				r := recover()
				switch {
				case r != nil && !tt.wantPanic:
					t.Errorf("New() panicked unexpectedly: %v", r)
				case r == nil && tt.wantPanic:
					t.Errorf("New() should have panicked but didn't")
				}
			}()

			auth := New(tt.privkey, config)
			if !tt.wantPanic && auth.config.CookieName != "test" {
				t.Errorf("New() failed to create auth instance properly")
			}
		})
	}
}

// TestAuth_Set verifies that Set emits exactly one non-empty cookie with the
// name, path and domain from createTestAuth, marked HttpOnly and SameSite=Lax.
func TestAuth_Set(t *testing.T) {
	auth := createTestAuth()
	cookie := issueCookie(t, auth, testData{UserID: 123, Name: "Alice"})

	if cookie.Name != "test_auth" {
		t.Errorf("Expected cookie name 'test_auth', got %q", cookie.Name)
	}
	if cookie.Path != "/" {
		t.Errorf("Expected cookie path '/', got %q", cookie.Path)
	}
	if cookie.Domain != "example.com" {
		t.Errorf("Expected cookie domain 'example.com', got %q", cookie.Domain)
	}
	if !cookie.HttpOnly {
		t.Error("Expected cookie to be HttpOnly")
	}
	if cookie.SameSite != http.SameSiteLaxMode {
		t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
	}
	if cookie.Value == "" {
		t.Error("Expected cookie to carry a non-empty value")
	}
}

// TestAuth_Required verifies that Required answers 401 without calling the
// wrapped handler when the auth cookie is missing or invalid, and calls it
// with the decoded data available via Get when the cookie is valid.
func TestAuth_Required(t *testing.T) {
	auth := createTestAuth()
	data := testData{UserID: 123, Name: "Alice"}

	var obs observation
	protectedHandler := auth.Required(observingHandler(auth, &obs))

	t.Run("missing cookie", func(t *testing.T) {
		w := serveObserved(protectedHandler, &obs, newTestRequest(nil))
		if w.Code != http.StatusUnauthorized {
			t.Errorf("Expected status 401, got %d", w.Code)
		}
		if obs.called {
			t.Error("Expected handler not to be called without a cookie")
		}
	})

	t.Run("invalid cookie", func(t *testing.T) {
		req := newTestRequest(&http.Cookie{Name: "test_auth", Value: "invalid_token"})
		w := serveObserved(protectedHandler, &obs, req)
		if w.Code != http.StatusUnauthorized {
			t.Errorf("Expected status 401, got %d", w.Code)
		}
		if obs.called {
			t.Error("Expected handler not to be called with an invalid cookie")
		}
	})

	t.Run("valid cookie", func(t *testing.T) {
		req := newTestRequest(issueCookie(t, auth, data))
		w := serveObserved(protectedHandler, &obs, req)
		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		expectAuthData(t, obs, data)
	})
}

// TestAuth_Optional verifies that Optional always calls the wrapped handler,
// exposing auth data via Get only when the request carries a valid cookie.
func TestAuth_Optional(t *testing.T) {
	auth := createTestAuth()
	data := testData{UserID: 123, Name: "Alice"}

	var obs observation
	optionalHandler := auth.Optional(observingHandler(auth, &obs))

	t.Run("missing cookie - should allow", func(t *testing.T) {
		w := serveObserved(optionalHandler, &obs, newTestRequest(nil))
		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		expectAnonymous(t, obs)
	})

	t.Run("invalid cookie - should allow", func(t *testing.T) {
		req := newTestRequest(&http.Cookie{Name: "test_auth", Value: "invalid_token"})
		w := serveObserved(optionalHandler, &obs, req)
		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		expectAnonymous(t, obs)
	})

	t.Run("valid cookie", func(t *testing.T) {
		req := newTestRequest(issueCookie(t, auth, data))
		w := serveObserved(optionalHandler, &obs, req)
		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		expectAuthData(t, obs, data)
	})
}

// TestAuth_Get verifies that Get reports data only when the context holds a
// testData under the package's context key, and returns the zero value
// otherwise.
func TestAuth_Get(t *testing.T) {
	auth := createTestAuth()
	data := testData{UserID: 123, Name: "Alice"}

	t.Run("no data in context", func(t *testing.T) {
		result, ok := auth.Get(context.Background())
		if ok {
			t.Error("Expected Get to return false for empty context")
		}
		if result != (testData{}) {
			t.Error("Expected zero value for missing data")
		}
	})

	t.Run("data in context", func(t *testing.T) {
		ctx := context.WithValue(context.Background(), key, data)
		result, ok := auth.Get(ctx)
		if !ok {
			t.Error("Expected Get to return true for context with data")
		}
		if result != data {
			t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%q", result.UserID, result.Name)
		}
	})

	t.Run("wrong type in context", func(t *testing.T) {
		ctx := context.WithValue(context.Background(), key, "wrong_type")
		result, ok := auth.Get(ctx)
		if ok {
			t.Error("Expected Get to return false for wrong type in context")
		}
		if result != (testData{}) {
			t.Error("Expected zero value for wrong type")
		}
	})
}

// failingSerDes wraps testSerDes and can be configured to fail either
// direction: Serialize returns io.ErrClosedPipe when failSerialize is set, and
// Deserialize returns io.ErrUnexpectedEOF when failDeserialize is set.
type failingSerDes struct {
	failSerialize   bool
	failDeserialize bool
}

// Serialize fails with io.ErrClosedPipe if configured to, otherwise delegates
// to testSerDes.
func (f failingSerDes) Serialize(w io.Writer, data testData) error {
	if f.failSerialize {
		return io.ErrClosedPipe
	}
	return testSerDes{}.Serialize(w, data)
}

// Deserialize fails with io.ErrUnexpectedEOF if configured to, otherwise
// delegates to testSerDes.
func (f failingSerDes) Deserialize(r io.Reader, data *testData) error {
	if f.failDeserialize {
		return io.ErrUnexpectedEOF
	}
	return testSerDes{}.Deserialize(r, data)
}

// TestAuth_Set_SerializationError verifies that Set surfaces a SerDes
// serialization failure as an error.
func TestAuth_Set_SerializationError(t *testing.T) {
	config := AuthConfig[testData]{
		SerDes:     failingSerDes{failSerialize: true},
		CookieName: "test_auth",
	}
	auth := New(testPrivKey, config)

	w := httptest.NewRecorder()
	if err := auth.Set(w, testData{UserID: 123, Name: "Alice"}); err == nil {
		t.Error("Expected Set to return error for serialization failure")
	}
}

// TestAuth_Required_DeserializationError verifies that Required answers 500,
// without calling the wrapped handler, when the SerDes fails to decode a
// cookie issued with the same key.
func TestAuth_Required_DeserializationError(t *testing.T) {
	config := AuthConfig[testData]{
		SerDes:     failingSerDes{failDeserialize: true},
		CookieName: "test_auth",
	}
	badAuth := New(testPrivKey, config)

	var obs observation
	protectedHandler := badAuth.Required(observingHandler(badAuth, &obs))

	// The cookie is minted by an Auth with a working SerDes and the same key,
	// so the request is expected to reach badAuth's failing Deserialize.
	cookie := issueCookie(t, createTestAuth(), testData{UserID: 123, Name: "Alice"})

	w := serveObserved(protectedHandler, &obs, newTestRequest(cookie))
	if w.Code != http.StatusInternalServerError {
		t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
	}
	if obs.called {
		t.Error("Expected handler not to be called after a deserialization error")
	}
}

// TestAuth_Optional_DeserializationError verifies that Optional also answers
// 500, without calling the wrapped handler, when the SerDes fails to decode a
// cookie issued with the same key.
func TestAuth_Optional_DeserializationError(t *testing.T) {
	config := AuthConfig[testData]{
		SerDes:     failingSerDes{failDeserialize: true},
		CookieName: "test_auth",
	}
	badAuth := New(testPrivKey, config)

	var obs observation
	optionalHandler := badAuth.Optional(observingHandler(badAuth, &obs))

	// See TestAuth_Required_DeserializationError: only decoding should fail.
	cookie := issueCookie(t, createTestAuth(), testData{UserID: 123, Name: "Alice"})

	w := serveObserved(optionalHandler, &obs, newTestRequest(cookie))
	if w.Code != http.StatusInternalServerError {
		t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
	}
	if obs.called {
		t.Error("Expected handler not to be called after a deserialization error")
	}
}