summaryrefslogtreecommitdiff
path: root/magic_link_login_test.go
diff options
context:
space:
mode:
authorT <t@tjp.lol>2025-06-26 11:42:17 -0600
committerT <t@tjp.lol>2025-07-01 17:50:49 -0600
commit639ad6a02cbb4b713434671ec09f309aa5410921 (patch)
tree7dde9cce8136636d11f2f7c961072984cfc705e7 /magic_link_login_test.go
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'magic_link_login_test.go')
-rw-r--r--magic_link_login_test.go473
1 files changed, 473 insertions, 0 deletions
diff --git a/magic_link_login_test.go b/magic_link_login_test.go
new file mode 100644
index 0000000..da4bfa6
--- /dev/null
+++ b/magic_link_login_test.go
@@ -0,0 +1,473 @@
+package kate
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Mock implementation of MagicLinkMailer for testing
+type mockMagicLinkMailer[T any] struct {
+ users map[string]T
+ sentEmails []struct {
+ user T
+ token string
+ }
+ sendEmailFunc func(T, string) error
+}
+
+func newMockMagicLinkMailer[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkMailer[T] {
+ return &mockMagicLinkMailer[T]{
+ users: users,
+ sendEmailFunc: sendEmailFunc,
+ }
+}
+
+func (m *mockMagicLinkMailer[T]) Fetch(username string) (T, bool, error) {
+ user, exists := m.users[username]
+ return user, exists, nil
+}
+
+func (m *mockMagicLinkMailer[T]) SendEmail(userData T, token string) error {
+ m.sentEmails = append(m.sentEmails, struct {
+ user T
+ token string
+ }{userData, token})
+ if m.sendEmailFunc != nil {
+ return m.sendEmailFunc(userData, token)
+ }
+ return nil
+}
+
+type testUser struct {
+ Username string
+ Hash string
+ ID int
+}
+
+type testUserSerDes struct{}
+
+func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error {
+ _, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID)
+ return err
+}
+
+func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ parts := strings.Split(string(buf), "|")
+ if len(parts) != 3 {
+ return errors.New("invalid format")
+ }
+ data.Username = parts[0]
+ data.Hash = parts[1]
+ id, err := strconv.Atoi(parts[2])
+ if err != nil {
+ return err
+ }
+ data.ID = id
+ return nil
+}
+
+func TestMagicLinkHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
+ "jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{
+ Default: "/dashboard",
+ AllowedPrefixes: []string{"/app/", "/admin/"},
+ },
+ }
+
+ handler := auth.MagicLinkLoginHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ formData url.Values
+ expectedStatus int
+ expectedBody string
+ checkEmail bool
+ }{
+ {
+ name: "successful magic link request",
+ method: "POST",
+ formData: url.Values{"email": {"john@example.com"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: true,
+ },
+ {
+ name: "magic link with redirect",
+ method: "POST",
+ formData: url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: true,
+ },
+ {
+ name: "nonexistent user returns success (no user enumeration)",
+ method: "POST",
+ formData: url.Values{"email": {"nobody@example.com"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: false,
+ },
+ {
+ name: "missing email",
+ method: "POST",
+ formData: url.Values{},
+ expectedStatus: http.StatusBadRequest,
+ checkEmail: false,
+ },
+ {
+ name: "GET method not allowed",
+ method: "GET",
+ formData: url.Values{},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkEmail: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockMailer.sentEmails = nil // Reset
+
+ var req *http.Request
+ if tt.method == "POST" {
+ req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ } else {
+ req = httptest.NewRequest(tt.method, "/magic-link", nil)
+ }
+
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.expectedBody != "" {
+ body := strings.TrimSpace(rr.Body.String())
+ if body != tt.expectedBody {
+ t.Errorf("expected body %q, got %q", tt.expectedBody, body)
+ }
+ }
+
+ if tt.checkEmail {
+ if len(mockMailer.sentEmails) != 1 {
+ t.Errorf("expected 1 email sent, got %d", len(mockMailer.sentEmails))
+ } else {
+ email := mockMailer.sentEmails[0]
+ if email.token == "" {
+ t.Error("expected non-empty token")
+ }
+ }
+ } else {
+ if len(mockMailer.sentEmails) != 0 {
+ t.Errorf("expected no emails sent, got %d", len(mockMailer.sentEmails))
+ }
+ }
+ })
+ }
+}
+
+func TestMagicLinkVerifyHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{Default: "/dashbaord"},
+ TokenExpiry: time.Minute,
+ }
+
+ token := magicLinkToken{
+ Username: "john@example.com",
+ Redirect: "/app/settings",
+ ExpiresAt: time.Now().Add(time.Minute),
+ }
+ tokenData := []byte(token.serialize())
+ validToken := auth.enc.Encrypt(tokenData)
+
+ expiredToken := magicLinkToken{
+ Username: "john@example.com",
+ Redirect: "/app/settings",
+ ExpiresAt: time.Now().Add(-time.Minute),
+ }
+ expiredTokenData := []byte(expiredToken.serialize())
+ expiredTokenStr := auth.enc.Encrypt(expiredTokenData)
+
+ handler := auth.MagicLinkVerifyHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ token string
+ expectedStatus int
+ expectedRedirect string
+ checkCookie bool
+ }{
+ {
+ name: "valid token with redirect",
+ method: "GET",
+ token: validToken,
+ expectedStatus: http.StatusSeeOther,
+ expectedRedirect: "/app/settings",
+ checkCookie: true,
+ },
+ {
+ name: "missing token",
+ method: "GET",
+ token: "",
+ expectedStatus: http.StatusBadRequest,
+ checkCookie: false,
+ },
+ {
+ name: "invalid token",
+ method: "GET",
+ token: "invalid-token",
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "expired token",
+ method: "GET",
+ token: expiredTokenStr,
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ token: validToken,
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkCookie: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ url := "/verify"
+ if tt.token != "" {
+ url += "?token=" + tt.token
+ }
+
+ req := httptest.NewRequest(tt.method, url, nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.expectedRedirect != "" {
+ location := rr.Header().Get("Location")
+ if location != tt.expectedRedirect {
+ t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location)
+ }
+ }
+
+ if tt.checkCookie {
+ cookies := rr.Result().Cookies()
+ found := false
+ for _, cookie := range cookies {
+ if cookie.Name == "test_session" && cookie.Value != "" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected authentication cookie to be set")
+ }
+ }
+ })
+ }
+}
+
+func TestMagicLinkConfigDefaults(t *testing.T) {
+ config := MagicLinkConfig[testUser]{}
+ config.setDefaults()
+
+ if config.UsernameField != "email" {
+ t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField)
+ }
+ if config.TokenField != "token" {
+ t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
+ }
+ if config.TokenExpiry != 15*time.Minute {
+ t.Errorf("expected TokenExpiry to be 15 minutes, got %v", config.TokenExpiry)
+ }
+ if config.TokenLocation != TokenLocationQuery {
+ t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation)
+ }
+}
+
+func TestMagicLinkTokenSerialization(t *testing.T) {
+ now := time.Now().Truncate(time.Second) // Remove nanoseconds for consistent comparison
+
+ tests := []struct {
+ name string
+ token magicLinkToken
+ expected string
+ }{
+ {
+ name: "simple token",
+ token: magicLinkToken{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "redirect with pipes",
+ token: magicLinkToken{
+ Redirect: "/app|section|page",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "/app|section|page\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "empty redirect",
+ token: magicLinkToken{
+ Redirect: "",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "username with pipes",
+ token: magicLinkToken{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ Username: "user|with|pipes@example.com",
+ },
+ expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user|with|pipes@example.com",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ serialized := tt.token.serialize()
+ if serialized != tt.expected {
+ t.Errorf("serialize() = %q, want %q", serialized, tt.expected)
+ }
+
+ parsed, err := parseMagicLinkToken(serialized)
+ if err != nil {
+ t.Errorf("parseMagicLinkToken() error = %v", err)
+ return
+ }
+
+ if parsed.Redirect != tt.token.Redirect {
+ t.Errorf("parsed Redirect = %q, want %q", parsed.Redirect, tt.token.Redirect)
+ }
+ if parsed.Username != tt.token.Username {
+ t.Errorf("parsed Username = %q, want %q", parsed.Username, tt.token.Username)
+ }
+ if !parsed.ExpiresAt.Equal(tt.token.ExpiresAt) {
+ t.Errorf("parsed ExpiresAt = %v, want %v", parsed.ExpiresAt, tt.token.ExpiresAt)
+ }
+ })
+ }
+}
+
+func TestParseMagicLinkTokenErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"empty string", ""},
+ {"single part", "onlyonepart"},
+ {"two parts", "two\x00parts"},
+ {"invalid time", "redirect\x00invalid-time\x00username"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := parseMagicLinkToken(tt.input)
+ if err == nil {
+ t.Errorf("parseMagicLinkToken(%q) expected error, got nil", tt.input)
+ }
+ })
+ }
+}
+
+func TestMagicLinkVerifyHandlerPathToken(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "test": {Username: "test", ID: 1},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{Default: "/dashboard"},
+ TokenExpiry: time.Minute,
+ TokenLocation: TokenLocationPath,
+ TokenField: "token",
+ }
+
+ handler := auth.MagicLinkVerifyHandler(config)
+
+ if config.TokenLocation != TokenLocationPath {
+ t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation)
+ }
+ if config.TokenField != "token" {
+ t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
+ }
+
+ req := httptest.NewRequest("GET", "/verify", nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code)
+ }
+}
+
+func TestTokenLocationEnum(t *testing.T) {
+ if TokenLocationQuery.location != 0 {
+ t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location)
+ }
+ if TokenLocationPath.location != 1 {
+ t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location)
+ }
+}