package kate

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"strings"
	"testing"
	"time"
)

// mockMagicLinkMailer is an in-memory mailer used by the magic-link tests.
// Fetch looks users up in a fixed map. SendEmail records every (user, token)
// pair in sentEmails and then delegates to the optional sendEmailFunc hook.
type mockMagicLinkMailer[T any] struct {
	users      map[string]T
	sentEmails []struct {
		user  T
		token string
	}
	sendEmailFunc func(T, string) error
}

// newMockMagicLinkMailer returns a mock backed by users. sendEmailFunc may be
// nil, in which case SendEmail always succeeds.
func newMockMagicLinkMailer[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkMailer[T] {
	return &mockMagicLinkMailer[T]{
		users:         users,
		sendEmailFunc: sendEmailFunc,
	}
}

// Fetch reports whether username is present in the mock's user map and
// returns the stored user. It never returns an error.
func (m *mockMagicLinkMailer[T]) Fetch(username string) (T, bool, error) {
	user, exists := m.users[username]
	return user, exists, nil
}

// SendEmail records the call, then returns the result of sendEmailFunc, or nil
// when no hook is configured. The call is recorded even if the hook fails.
func (m *mockMagicLinkMailer[T]) SendEmail(userData T, token string) error {
	m.sentEmails = append(m.sentEmails, struct {
		user  T
		token string
	}{userData, token})
	if m.sendEmailFunc != nil {
		return m.sendEmailFunc(userData, token)
	}
	return nil
}

// testUser is the user payload stored in session cookies during tests.
type testUser struct {
	Username string
	Hash     string
	ID       int
}

// testUserSerDes encodes a testUser as "username|hash|id".
type testUserSerDes struct{}

// Serialize writes data to w in "username|hash|id" form.
func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error {
	_, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID)
	return err
}

// Deserialize parses the "username|hash|id" form written by Serialize into
// data. A username or hash containing '|' produces the wrong number of
// fields and is rejected. data is modified only when parsing succeeds.
func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error {
	buf, err := io.ReadAll(r)
	if err != nil {
		return err
	}
	parts := strings.Split(string(buf), "|")
	if len(parts) != 3 {
		return errors.New("invalid format")
	}
	id, err := strconv.Atoi(parts[2])
	if err != nil {
		return fmt.Errorf("parsing user id %q: %w", parts[2], err)
	}
	data.Username = parts[0]
	data.Hash = parts[1]
	data.ID = id
	return nil
}

// TestMagicLinkHandler checks the login handler in three situations:
//   - Known users get exactly one email containing a token.
//   - Unknown users get the same success response with no email sent, so the
//     endpoint cannot be used to enumerate accounts.
//   - Missing input and non-POST methods are rejected.
func TestMagicLinkHandler(t *testing.T) {
	auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
		SerDes:     testUserSerDes{},
		CookieName: "test_session",
	})

	users := map[string]testUser{
		"john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
		"jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2},
	}
	mockMailer := newMockMagicLinkMailer(users, nil)

	config := MagicLinkConfig[testUser]{
		Mailer: mockMailer,
		Redirects: Redirects{
			Default:         "/dashboard",
			AllowedPrefixes: []string{"/app/", "/admin/"},
		},
	}
	handler := auth.MagicLinkLoginHandler(config)

	tests := []struct {
		name           string
		method         string
		formData       url.Values
		expectedStatus int
		expectedBody   string
		checkEmail     bool
	}{
		{
			name:           "successful magic link request",
			method:         http.MethodPost,
			formData:       url.Values{"email": {"john@example.com"}},
			expectedStatus: http.StatusOK,
			expectedBody:   "Magic link sent",
			checkEmail:     true,
		},
		{
			name:           "magic link with redirect",
			method:         http.MethodPost,
			formData:       url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}},
			expectedStatus: http.StatusOK,
			expectedBody:   "Magic link sent",
			checkEmail:     true,
		},
		{
			name:           "nonexistent user returns success (no user enumeration)",
			method:         http.MethodPost,
			formData:       url.Values{"email": {"nobody@example.com"}},
			expectedStatus: http.StatusOK,
			expectedBody:   "Magic link sent",
			checkEmail:     false,
		},
		{
			name:           "missing email",
			method:         http.MethodPost,
			formData:       url.Values{},
			expectedStatus: http.StatusBadRequest,
			checkEmail:     false,
		},
		{
			name:           "GET method not allowed",
			method:         http.MethodGet,
			formData:       url.Values{},
			expectedStatus: http.StatusMethodNotAllowed,
			checkEmail:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The mailer is shared by all subtests; drop what earlier cases recorded.
			mockMailer.sentEmails = nil

			var req *http.Request
			if tt.method == http.MethodPost {
				req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode()))
				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
			} else {
				req = httptest.NewRequest(tt.method, "/magic-link", nil)
			}

			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
			}

			if tt.expectedBody != "" {
				if body := strings.TrimSpace(rr.Body.String()); body != tt.expectedBody {
					t.Errorf("expected body %q, got %q", tt.expectedBody, body)
				}
			}

			if !tt.checkEmail {
				if n := len(mockMailer.sentEmails); n != 0 {
					t.Errorf("expected no emails sent, got %d", n)
				}
				return
			}

			if n := len(mockMailer.sentEmails); n != 1 {
				t.Fatalf("expected 1 email sent, got %d", n)
			}
			email := mockMailer.sentEmails[0]
			if email.token == "" {
				t.Error("expected non-empty token")
			}
			// The mock can only hand out users via Fetch, so the recipient must be
			// the account named in the form.
			if want := tt.formData.Get("email"); email.user.Username != want {
				t.Errorf("expected email sent to %q, got %q", want, email.user.Username)
			}
		})
	}
}

// TestMagicLinkVerifyHandler covers the verification endpoint:
//   - A valid token redirects to the embedded target and sets the session
//     cookie.
//   - Missing, malformed or expired tokens, and non-GET methods, are rejected.
func TestMagicLinkVerifyHandler(t *testing.T) {
	auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
		SerDes:     testUserSerDes{},
		CookieName: "test_session",
	})

	users := map[string]testUser{
		"john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
	}
	mockMailer := newMockMagicLinkMailer(users, nil)

	config := MagicLinkConfig[testUser]{
		Mailer:      mockMailer,
		Redirects:   Redirects{Default: "/dashboard"},
		TokenExpiry: time.Minute,
	}

	token := magicLinkToken{
		Username:  "john@example.com",
		Redirect:  "/app/settings",
		ExpiresAt: time.Now().Add(time.Minute),
	}
	validToken := auth.enc.Encrypt([]byte(token.serialize()))

	expiredToken := magicLinkToken{
		Username:  "john@example.com",
		Redirect:  "/app/settings",
		ExpiresAt: time.Now().Add(-time.Minute),
	}
	expiredTokenStr := auth.enc.Encrypt([]byte(expiredToken.serialize()))

	handler := auth.MagicLinkVerifyHandler(config)

	tests := []struct {
		name             string
		method           string
		token            string
		expectedStatus   int
		expectedRedirect string
		checkCookie      bool
	}{
		{
			name:             "valid token with redirect",
			method:           http.MethodGet,
			token:            validToken,
			expectedStatus:   http.StatusSeeOther,
			expectedRedirect: "/app/settings",
			checkCookie:      true,
		},
		{
			name:           "missing token",
			method:         http.MethodGet,
			token:          "",
			expectedStatus: http.StatusBadRequest,
			checkCookie:    false,
		},
		{
			name:           "invalid token",
			method:         http.MethodGet,
			token:          "invalid-token",
			expectedStatus: http.StatusUnauthorized,
			checkCookie:    false,
		},
		{
			name:           "expired token",
			method:         http.MethodGet,
			token:          expiredTokenStr,
			expectedStatus: http.StatusUnauthorized,
			checkCookie:    false,
		},
		{
			name:           "POST method not allowed",
			method:         http.MethodPost,
			token:          validToken,
			expectedStatus: http.StatusMethodNotAllowed,
			checkCookie:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Named target (not url) so the net/url package stays visible here.
			target := "/verify"
			if tt.token != "" {
				// Escape the token so characters such as '+', '/' or '=' in the
				// encrypted value survive query decoding unchanged.
				target += "?" + url.Values{"token": {tt.token}}.Encode()
			}
			req := httptest.NewRequest(tt.method, target, nil)

			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
			}

			if tt.expectedRedirect != "" {
				if location := rr.Header().Get("Location"); location != tt.expectedRedirect {
					t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location)
				}
			}

			if tt.checkCookie {
				res := rr.Result()
				defer res.Body.Close()

				found := false
				for _, cookie := range res.Cookies() {
					if cookie.Name == "test_session" && cookie.Value != "" {
						found = true
						break
					}
				}
				if !found {
					t.Error("expected authentication cookie to be set")
				}
			}
		})
	}
}

// TestMagicLinkConfigDefaults checks the values setDefaults fills into a zero
// MagicLinkConfig.
func TestMagicLinkConfigDefaults(t *testing.T) {
	config := MagicLinkConfig[testUser]{}
	config.setDefaults()

	if config.UsernameField != "email" {
		t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField)
	}
	if config.TokenField != "token" {
		t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
	}
	if config.TokenExpiry != 15*time.Minute {
		t.Errorf("expected TokenExpiry to be 15 minutes, got %v", config.TokenExpiry)
	}
	if config.TokenLocation != TokenLocationQuery {
		t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation)
	}
}

// TestMagicLinkTokenSerialization pins the NUL-separated
// "redirect\x00RFC3339-expiry\x00username" wire format. It also checks that
// parseMagicLinkToken round-trips it, including fields that contain '|'.
func TestMagicLinkTokenSerialization(t *testing.T) {
	// RFC3339 has second precision, so drop sub-second parts to allow an exact round-trip.
	now := time.Now().Truncate(time.Second)

	tests := []struct {
		name     string
		token    magicLinkToken
		expected string
	}{
		{
			name: "simple token",
			token: magicLinkToken{
				Redirect:  "/dashboard",
				ExpiresAt: now,
				Username:  "user@example.com",
			},
			expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
		},
		{
			name: "redirect with pipes",
			token: magicLinkToken{
				Redirect:  "/app|section|page",
				ExpiresAt: now,
				Username:  "user@example.com",
			},
			expected: "/app|section|page\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
		},
		{
			name: "empty redirect",
			token: magicLinkToken{
				Redirect:  "",
				ExpiresAt: now,
				Username:  "user@example.com",
			},
			expected: "\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
		},
		{
			name: "username with pipes",
			token: magicLinkToken{
				Redirect:  "/dashboard",
				ExpiresAt: now,
				Username:  "user|with|pipes@example.com",
			},
			expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user|with|pipes@example.com",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			serialized := tt.token.serialize()
			if serialized != tt.expected {
				t.Errorf("serialize() = %q, want %q", serialized, tt.expected)
			}

			parsed, err := parseMagicLinkToken(serialized)
			if err != nil {
				t.Fatalf("parseMagicLinkToken() error = %v", err)
			}

			if parsed.Redirect != tt.token.Redirect {
				t.Errorf("parsed Redirect = %q, want %q", parsed.Redirect, tt.token.Redirect)
			}
			if parsed.Username != tt.token.Username {
				t.Errorf("parsed Username = %q, want %q", parsed.Username, tt.token.Username)
			}
			if !parsed.ExpiresAt.Equal(tt.token.ExpiresAt) {
				t.Errorf("parsed ExpiresAt = %v, want %v", parsed.ExpiresAt, tt.token.ExpiresAt)
			}
		})
	}
}

// TestParseMagicLinkTokenErrors checks that malformed serialized tokens are
// rejected: too few NUL-separated fields, or an unparseable expiry.
func TestParseMagicLinkTokenErrors(t *testing.T) {
	tests := []struct {
		name  string
		input string
	}{
		{"empty string", ""},
		{"single part", "onlyonepart"},
		{"two parts", "two\x00parts"},
		{"invalid time", "redirect\x00invalid-time\x00username"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if _, err := parseMagicLinkToken(tt.input); err == nil {
				t.Errorf("parseMagicLinkToken(%q) expected error, got nil", tt.input)
			}
		})
	}
}

// TestMagicLinkVerifyHandlerPathToken checks that a verify handler configured
// for path tokens rejects a request whose path carries no token.
func TestMagicLinkVerifyHandlerPathToken(t *testing.T) {
	auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
		SerDes:     testUserSerDes{},
		CookieName: "test_session",
	})

	users := map[string]testUser{
		"test": {Username: "test", ID: 1},
	}
	mockMailer := newMockMagicLinkMailer(users, nil)

	config := MagicLinkConfig[testUser]{
		Mailer:        mockMailer,
		Redirects:     Redirects{Default: "/dashboard"},
		TokenExpiry:   time.Minute,
		TokenLocation: TokenLocationPath,
		TokenField:    "token",
	}
	handler := auth.MagicLinkVerifyHandler(config)

	// Sanity-check the configuration under test.
	if config.TokenLocation != TokenLocationPath {
		t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation)
	}
	if config.TokenField != "token" {
		t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
	}

	req := httptest.NewRequest(http.MethodGet, "/verify", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusBadRequest {
		t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code)
	}
}

// TestTokenLocationEnum pins the underlying values of the TokenLocation
// constants.
func TestTokenLocationEnum(t *testing.T) {
	if TokenLocationQuery.location != 0 {
		t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location)
	}
	if TokenLocationPath.location != 1 {
		t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location)
	}
}