diff options
author | T <t@tjp.lol> | 2025-06-26 11:42:17 -0600 |
---|---|---|
committer | T <t@tjp.lol> | 2025-07-01 17:50:49 -0600 |
commit | 639ad6a02cbb4b713434671ec09f309aa5410921 (patch) | |
tree | 7dde9cce8136636d11f2f7c961072984cfc705e7 /magic_link_login_test.go |
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'magic_link_login_test.go')
-rw-r--r-- | magic_link_login_test.go | 473 |
1 files changed, 473 insertions, 0 deletions
diff --git a/magic_link_login_test.go b/magic_link_login_test.go new file mode 100644 index 0000000..da4bfa6 --- /dev/null +++ b/magic_link_login_test.go @@ -0,0 +1,473 @@ +package kate + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" +) + +// Mock implementation of MagicLinkMailer for testing +type mockMagicLinkMailer[T any] struct { + users map[string]T + sentEmails []struct { + user T + token string + } + sendEmailFunc func(T, string) error +} + +func newMockMagicLinkMailer[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkMailer[T] { + return &mockMagicLinkMailer[T]{ + users: users, + sendEmailFunc: sendEmailFunc, + } +} + +func (m *mockMagicLinkMailer[T]) Fetch(username string) (T, bool, error) { + user, exists := m.users[username] + return user, exists, nil +} + +func (m *mockMagicLinkMailer[T]) SendEmail(userData T, token string) error { + m.sentEmails = append(m.sentEmails, struct { + user T + token string + }{userData, token}) + if m.sendEmailFunc != nil { + return m.sendEmailFunc(userData, token) + } + return nil +} + +type testUser struct { + Username string + Hash string + ID int +} + +type testUserSerDes struct{} + +func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error { + _, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID) + return err +} + +func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error { + buf, err := io.ReadAll(r) + if err != nil { + return err + } + parts := strings.Split(string(buf), "|") + if len(parts) != 3 { + return errors.New("invalid format") + } + data.Username = parts[0] + data.Hash = parts[1] + id, err := strconv.Atoi(parts[2]) + if err != nil { + return err + } + data.ID = id + return nil +} + +func TestMagicLinkHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, + "jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{ + Default: "/dashboard", + AllowedPrefixes: []string{"/app/", "/admin/"}, + }, + } + + handler := auth.MagicLinkLoginHandler(config) + + tests := []struct { + name string + method string + formData url.Values + expectedStatus int + expectedBody string + checkEmail bool + }{ + { + name: "successful magic link request", + method: "POST", + formData: url.Values{"email": {"john@example.com"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: true, + }, + { + name: "magic link with redirect", + method: "POST", + formData: url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: true, + }, + { + name: "nonexistent user returns success (no user enumeration)", + method: "POST", + formData: url.Values{"email": {"nobody@example.com"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: false, + }, + { + name: "missing email", + method: "POST", + formData: url.Values{}, + expectedStatus: http.StatusBadRequest, + checkEmail: false, + }, + { + name: "GET method not allowed", + method: "GET", + formData: url.Values{}, + expectedStatus: http.StatusMethodNotAllowed, + checkEmail: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMailer.sentEmails = nil // Reset + + var req *http.Request + if tt.method == "POST" { + req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } else { + req = httptest.NewRequest(tt.method, "/magic-link", nil) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedBody != "" { + body := strings.TrimSpace(rr.Body.String()) + if body != tt.expectedBody { + t.Errorf("expected body %q, got %q", tt.expectedBody, body) + } + } + + if tt.checkEmail { + if len(mockMailer.sentEmails) != 1 { + t.Errorf("expected 1 email sent, got %d", len(mockMailer.sentEmails)) + } else { + email := mockMailer.sentEmails[0] + if email.token == "" { + t.Error("expected non-empty token") + } + } + } else { + if len(mockMailer.sentEmails) != 0 { + t.Errorf("expected no emails sent, got %d", len(mockMailer.sentEmails)) + } + } + }) + } +} + +func TestMagicLinkVerifyHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{Default: "/dashbaord"}, + TokenExpiry: time.Minute, + } + + token := magicLinkToken{ + Username: "john@example.com", + Redirect: "/app/settings", + ExpiresAt: time.Now().Add(time.Minute), + } + tokenData := []byte(token.serialize()) + validToken := auth.enc.Encrypt(tokenData) + + expiredToken := magicLinkToken{ + Username: "john@example.com", + Redirect: "/app/settings", + ExpiresAt: time.Now().Add(-time.Minute), + } + expiredTokenData := []byte(expiredToken.serialize()) + expiredTokenStr := auth.enc.Encrypt(expiredTokenData) + + handler := auth.MagicLinkVerifyHandler(config) + + tests := []struct { + name string + method string + token string + expectedStatus int + expectedRedirect string + checkCookie bool + }{ + { + name: "valid token with redirect", + method: "GET", + token: validToken, + expectedStatus: http.StatusSeeOther, + expectedRedirect: "/app/settings", + checkCookie: true, + }, + { + name: "missing token", + method: "GET", + token: "", + expectedStatus: http.StatusBadRequest, + checkCookie: false, + }, + { + name: "invalid token", + method: "GET", + token: "invalid-token", + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "expired token", + method: "GET", + token: expiredTokenStr, + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "POST method not allowed", + method: "POST", + token: validToken, + expectedStatus: http.StatusMethodNotAllowed, + checkCookie: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/verify" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(tt.method, url, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedRedirect != "" { + location := rr.Header().Get("Location") + if location != tt.expectedRedirect { + t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location) + } + } + + if tt.checkCookie { + cookies := rr.Result().Cookies() + found := false + for _, cookie := range cookies { + if cookie.Name == "test_session" && cookie.Value != "" { + found = true + break + } + } + if !found { + t.Error("expected authentication cookie to be set") + } + } + }) + } +} + +func TestMagicLinkConfigDefaults(t *testing.T) { + config := MagicLinkConfig[testUser]{} + config.setDefaults() + + if config.UsernameField != "email" { + t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField) + } + if config.TokenField != "token" { + t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) + } + if config.TokenExpiry != 15*time.Minute { + t.Errorf("expected TokenExpiry to be 15 minutes, got %v", config.TokenExpiry) + } + if config.TokenLocation != TokenLocationQuery { + t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation) + } +} + +func TestMagicLinkTokenSerialization(t *testing.T) { + now := time.Now().Truncate(time.Second) // Remove nanoseconds for consistent comparison + + tests := []struct { + name string + token magicLinkToken + expected string + }{ + { + name: "simple token", + token: magicLinkToken{ + Redirect: "/dashboard", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "redirect with pipes", + token: magicLinkToken{ + Redirect: "/app|section|page", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "/app|section|page\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "empty redirect", + token: magicLinkToken{ + Redirect: "", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "username with pipes", + token: magicLinkToken{ + Redirect: "/dashboard", + ExpiresAt: now, + Username: "user|with|pipes@example.com", + }, + expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user|with|pipes@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serialized := tt.token.serialize() + if serialized != tt.expected { + t.Errorf("serialize() = %q, want %q", serialized, tt.expected) + } + + parsed, err := parseMagicLinkToken(serialized) + if err != nil { + t.Errorf("parseMagicLinkToken() error = %v", err) + return + } + + if parsed.Redirect != tt.token.Redirect { + t.Errorf("parsed Redirect = %q, want %q", parsed.Redirect, tt.token.Redirect) + } + if parsed.Username != tt.token.Username { + t.Errorf("parsed Username = %q, want %q", parsed.Username, tt.token.Username) + } + if !parsed.ExpiresAt.Equal(tt.token.ExpiresAt) { + t.Errorf("parsed ExpiresAt = %v, want %v", parsed.ExpiresAt, tt.token.ExpiresAt) + } + }) + } +} + +func TestParseMagicLinkTokenErrors(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"empty string", ""}, + {"single part", "onlyonepart"}, + {"two parts", "two\x00parts"}, + {"invalid time", "redirect\x00invalid-time\x00username"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseMagicLinkToken(tt.input) + if err == nil { + t.Errorf("parseMagicLinkToken(%q) expected error, got nil", tt.input) + } + }) + } +} + +func TestMagicLinkVerifyHandlerPathToken(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "test": {Username: "test", ID: 1}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{Default: "/dashboard"}, + TokenExpiry: time.Minute, + TokenLocation: TokenLocationPath, + TokenField: "token", + } + + handler := auth.MagicLinkVerifyHandler(config) + + if config.TokenLocation != TokenLocationPath { + t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation) + } + if config.TokenField != "token" { + t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) + } + + req := httptest.NewRequest("GET", "/verify", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code) + } +} + +func TestTokenLocationEnum(t *testing.T) { + if TokenLocationQuery.location != 0 { + t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location) + } + if TokenLocationPath.location != 1 { + t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location) + } +} |