summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorT <t@tjp.lol>2025-06-26 11:42:17 -0600
committerT <t@tjp.lol>2025-07-01 17:50:49 -0600
commit639ad6a02cbb4b713434671ec09f309aa5410921 (patch)
tree7dde9cce8136636d11f2f7c961072984cfc705e7
Create authentic_kate: user authentication for go HTTP applications
-rw-r--r--README.md203
-rw-r--r--auth.go169
-rw-r--r--auth_test.go391
-rw-r--r--compat-test/README.md44
-rw-r--r--compat-test/compat_test.go167
-rw-r--r--compat-test/go.mod17
-rw-r--r--compat-test/go.sum10
-rw-r--r--doc.go150
-rw-r--r--encryption.go136
-rw-r--r--encryption_test.go151
-rw-r--r--go.mod13
-rw-r--r--go.sum8
-rw-r--r--login_helpers.go64
-rw-r--r--magic_link_login.go233
-rw-r--r--magic_link_login_test.go473
-rw-r--r--oauth_login.go257
-rw-r--r--oauth_login_test.go369
-rw-r--r--password.go131
-rw-r--r--password_login.go105
-rw-r--r--password_login_test.go385
-rw-r--r--password_test.go230
21 files changed, 3706 insertions, 0 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d5d7882
--- /dev/null
+++ b/README.md
@@ -0,0 +1,203 @@
+# authentic_kate
+
+Package kate provides secure cookie-based authentication for HTTP applications.
+
+This package implements encrypted cookies with authenticated encryption.
+It supports generic types to allow storage of any serializable data in encrypted HTTP cookies.
+
+## Key features
+
+- Type-safe authentication data storage using Go generics
+- Authenticated encryption for secure cookie storage
+- HTTP middlewares for required and optional authentication
+- Ready-to-use login handlers for password, oauth, and magic link authentication
+- Configurable cookie properties (domain, path, HTTPS-only, max age)
+- Secure password hashing with industry-standard algorithms
+
+## Example usage
+
+```go
+type UserData struct {
+ ID int
+ Name string
+}
+
+// Implement SerDes interface for UserData
+type UserSerDes struct{}
+func (s UserSerDes) Serialize(w io.Writer, data UserData) error { /* ... */ }
+func (s UserSerDes) Deserialize(r io.Reader, data *UserData) error { /* ... */ }
+
+// Create authentication instance
+auth := kate.New("hex-encoded-32-byte-key", kate.AuthConfig[UserData]{
+ SerDes: UserSerDes{},
+ CookieName: "session",
+ HTTPSOnly: true,
+ MaxAge: 24 * time.Hour,
+})
+
+// Use as middleware
+http.Handle("GET /protected", auth.Required(protectedHandler))
+
+// Set authentication cookie
+user := UserData{ID: 123, Name: "John"}
+auth.Set(w, user)
+
+// Get authenticated data
+user, ok := auth.Get(r.Context())
+```
+
+## Login handlers
+
+The package provides ready-to-use HTTP handlers for several authentication methods:
+
+### Password-based login
+
+```go
+type User struct {
+ Username string
+ Hash string
+ ID int
+}
+
+// Implement PasswordUserDataStore interface
+type UserStore struct{}
+
+func (us UserStore) Fetch(username string) (User, bool, error) {
+ // Look up user in your database
+ user, exists := database.FindUser(username)
+ return user, exists, nil
+}
+
+func (us UserStore) GetPassHash(user User) string {
+ return user.Hash
+}
+
+// Configure password login
+passwordConfig := kate.PasswordLoginConfig[User]{
+ UserData: UserStore{},
+ Redirects: kate.Redirects{
+ Default: "/dashboard",
+ AllowedPrefixes: []string{"/app/", "/admin/"},
+ FieldName: "redirect",
+ },
+}
+
+// Create handler
+http.Handle("POST /login", auth.PasswordLoginHandler(passwordConfig))
+```
+
+### Magic link authentication
+
+```go
+// Implement MagicLinkMailer interface
+type UserMailer struct{}
+
+func (um UserMailer) Fetch(email string) (User, bool, error) {
+ // Look up user by email
+ user, exists := database.FindUserByEmail(email)
+ return user, exists, nil
+}
+
+func (um UserMailer) SendEmail(user User, token string) error {
+ // Send magic link email with the token
+ return emailService.SendMagicLink(user.Email, token)
+}
+
+// Configure magic link login
+magicConfig := kate.MagicLinkConfig[User]{
+ Mailer: UserMailer{},
+ Redirects: kate.Redirects{ Default: "/dashboard" },
+ TokenExpiry: 15 * time.Minute,
+ TokenLocation: kate.TokenLocationQuery, // or TokenLocationPath
+}
+
+// Create handlers
+http.Handle("POST /magic-link", auth.MagicLinkLoginHandler(magicConfig))
+http.Handle("GET /verify", auth.MagicLinkVerifyHandler(magicConfig))
+```
+
+### OAuth2 authentication
+
+```go
+// Implement the OAuthDataStore interface
+type UserDataStore struct{}
+
+func (ds UserDataStore) GetOrCreateUser(email string) (User, error) {
+ // Look up user by email, create if doesn't exist
+ if user, exists := database.FindUserByEmail(email); exists {
+ return user, nil
+ }
+ return database.CreateUser(email)
+}
+
+func (ds UserDataStore) StoreState(state kate.OAuthState) (string, error) {
+ // Generate unique state ID and store temporarily
+ id := generateUniqueID()
+ return id, cache.Set(id, state, 10*time.Minute)
+}
+
+func (ds UserDataStore) GetAndClearState(id string) (*kate.OAuthState, error) {
+ // Retrieve and delete state data
+ var state kate.OAuthState
+ exists, err := cache.GetAndDelete(id, &state)
+ if !exists {
+ return nil, err
+ }
+ return &state, err
+}
+
+// Configure Google OAuth2 login
+oauthConfig := kate.GoogleOAuthConfig[User](
+ "your-google-client-id",
+ "your-google-client-secret",
+ "https://yourapp.com/auth/callback",
+ UserDataStore{},
+)
+
+// Customize redirects and security settings
+oauthConfig.Redirects = kate.Redirects{
+ Default: "/dashboard",
+ AllowedPrefixes: []string{"/app/", "/admin/"},
+ FieldName: "redirect",
+}
+
+// Create handlers
+http.Handle("GET /auth/login", auth.OAuthLoginHandler(oauthConfig))
+http.Handle("GET /auth/callback", auth.OAuthCallbackHandler(oauthConfig))
+
+// GitHub OAuth2 is also supported
+githubConfig := kate.GitHubOAuthConfig[User](
+ "your-github-client-id",
+ "your-github-client-secret",
+ "https://yourapp.com/auth/github/callback",
+ UserDataStore{},
+)
+```
+
+## Password hashing
+
+The package provides secure password hashing functions using Argon2id:
+
+```go
+// Hash a password with secure defaults
+hash, err := kate.HashPassword("user-password", nil)
+if err != nil {
+ // Handle error
+}
+
+// Verify a password
+match, err := kate.ComparePassword("user-password", hash)
+if err != nil {
+ // Handle malformed hash error
+}
+if match {
+ // Password is correct
+}
+```
+
+## Cryptographic algorithms
+
+This package uses the following cryptographic algorithms:
+
+- **Token encryption**: XSalsa20 stream cipher with Poly1305 MAC for authenticated encryption (compatible with libsodium secretbox)
+- **Password hashing**: Argon2id with secure default parameters and PHC format storage
diff --git a/auth.go b/auth.go
new file mode 100644
index 0000000..1066d9f
--- /dev/null
+++ b/auth.go
@@ -0,0 +1,169 @@
+package kate
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+)
+
+// Auth provides secure cookie-based authentication for HTTP applications.
+//
+// It uses generic type T to allow storage of any serializable data in encrypted cookies.
+type Auth[T any] struct {
+ enc encryption
+ config AuthConfig[T]
+}
+
+// AuthConfig holds configuration settings for the Auth instance.
+//
+// It specifies how data is serialized, cookie properties, and security settings.
+type AuthConfig[T any] struct {
+ // SerDes handles serialization and deserialization of authentication data
+ SerDes SerDes[T]
+
+ // CookieName is the name of the HTTP cookie used for authentication
+ CookieName string
+
+ // URLPath restricts the cookie to a specific path on the server (optional)
+ URLPath string
+
+ // URLDomain restricts the cookie to a specific domain (optional)
+ URLDomain string
+
+ // HTTPSOnly when true, requires cookies to be sent only over HTTPS connections
+ HTTPSOnly bool
+
+ // MaxAge determines how long in seconds the authentication cookie remains valid
+ MaxAge time.Duration
+}
+
+// SerDes defines the interface for serializing and deserializing authentication data.
+//
+// Implementations must handle conversion between type T and byte streams.
+type SerDes[T any] interface {
+ // Serialize writes the data of type T to the provided writer
+ Serialize(io.Writer, T) error
+
+ // Deserialize reads data from the reader and populates the provided pointer
+ Deserialize(io.Reader, *T) error
+}
+
+// New creates a new Auth instance with the given private key and configuration.
+//
+// The private key must be a hex-encoded string used for cookie encryption.
+// Panics if the private key is invalid.
+func New[T any](privkey string, config AuthConfig[T]) Auth[T] {
+ enc, err := encryptionFromHexKey(privkey)
+ if err != nil {
+ panic(err.Error())
+ }
+ return Auth[T]{enc: enc, config: config}
+}
+
+// Required is an HTTP middleware that enforces authentication.
+//
+// It checks for a valid authentication cookie and makes the authenticated data
+// available in the request context. Returns 401 Unauthorized if authentication fails.
+func (a Auth[T]) Required(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie(a.config.CookieName)
+ if errors.Is(err, http.ErrNoCookie) {
+ http.Error(w, "Authentication missing", http.StatusUnauthorized)
+ return
+ }
+
+ cleartext, ok := a.enc.Decrypt(cookie.Value)
+ if !ok {
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+
+ var data T
+ if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil {
+ http.Error(w, "Server error", http.StatusInternalServerError)
+ return
+ }
+
+ handler.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key, data)))
+ })
+}
+
+// Optional returns an HTTP middleware that allows optional authentication.
+//
+// It checks for a valid authentication cookie and makes the authenticated data
+// available in the request context if present. Unlike Required, this middleware
+// allows requests to proceed even when authentication is missing or invalid.
+// Returns 500 Internal Server Error only if deserialization fails on valid authentication data.
+func (a Auth[T]) Optional(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie(a.config.CookieName)
+ if errors.Is(err, http.ErrNoCookie) {
+ handler.ServeHTTP(w, r)
+ return
+ }
+
+ cleartext, ok := a.enc.Decrypt(cookie.Value)
+ if !ok {
+ handler.ServeHTTP(w, r)
+ return
+ }
+
+ var data T
+ if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil {
+ http.Error(w, "Server error", http.StatusInternalServerError)
+ return
+ }
+
+ handler.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key, data)))
+ })
+}
+
+// Set creates and sets an authentication cookie containing the provided data.
+//
+// The data is serialized, encrypted, and stored in an HTTP cookie.
+// Returns an error if serialization fails.
+func (a Auth[T]) Set(w http.ResponseWriter, data T) error {
+ buf := &bytes.Buffer{}
+ if err := a.config.SerDes.Serialize(buf, data); err != nil {
+ return fmt.Errorf("Auth.Set: %w", err)
+ }
+ cookie := &http.Cookie{
+ Name: a.config.CookieName,
+ Value: a.enc.Encrypt(buf.Bytes()),
+ Path: a.config.URLPath,
+ Domain: a.config.URLDomain,
+ MaxAge: int(a.config.MaxAge / time.Second),
+ Secure: a.config.HTTPSOnly,
+ HttpOnly: true,
+ SameSite: http.SameSiteLaxMode,
+ }
+ w.Header().Add("Set-Cookie", cookie.String())
+ return nil
+}
+
+// Get retrieves authentication data from the request context.
+//
+// Returns the data and true if authentication data is present and valid,
+// otherwise returns the zero value and false. Should be called within handlers
+// protected by the Required middleware.
+func (a Auth[T]) Get(ctx context.Context) (T, bool) {
+ var zero T
+ val := ctx.Value(key)
+ if val == nil {
+ return zero, false
+ }
+ switch v := val.(type) {
+ case T:
+ return v, true
+ default:
+ return zero, false
+ }
+}
+
+type keyt struct{}
+
+var key = keyt{}
diff --git a/auth_test.go b/auth_test.go
new file mode 100644
index 0000000..131e132
--- /dev/null
+++ b/auth_test.go
@@ -0,0 +1,391 @@
+package kate
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+type testData struct {
+ UserID int
+ Name string
+}
+
+type testSerDes struct{}
+
+func (ts testSerDes) Serialize(w io.Writer, data testData) error {
+ _, err := w.Write([]byte(strings.Join([]string{string(rune(data.UserID)), data.Name}, "|")))
+ return err
+}
+
+func (ts testSerDes) Deserialize(r io.Reader, data *testData) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ parts := strings.Split(string(buf), "|")
+ if len(parts) != 2 {
+ return io.ErrUnexpectedEOF
+ }
+ data.UserID = int(rune(parts[0][0]))
+ data.Name = parts[1]
+ return nil
+}
+
+func createTestAuth() Auth[testData] {
+ config := AuthConfig[testData]{
+ SerDes: testSerDes{},
+ CookieName: "test_auth",
+ URLPath: "/",
+ URLDomain: "example.com",
+ HTTPSOnly: false,
+ MaxAge: time.Hour,
+ }
+ return New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+}
+
+func TestNew(t *testing.T) {
+ tests := []struct {
+ name string
+ privkey string
+ wantErr bool
+ }{
+ {
+ name: "valid hex key",
+ privkey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ wantErr: false,
+ },
+ {
+ name: "invalid hex key - too short",
+ privkey: "0123456789abcdef",
+ wantErr: true,
+ },
+ {
+ name: "invalid hex key - not hex",
+ privkey: "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: testSerDes{},
+ CookieName: "test",
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if !tt.wantErr {
+ t.Errorf("New() panicked unexpectedly: %v", r)
+ }
+ } else if tt.wantErr {
+ t.Errorf("New() should have panicked but didn't")
+ }
+ }()
+
+ auth := New(tt.privkey, config)
+ if !tt.wantErr && auth.config.CookieName != "test" {
+ t.Errorf("New() failed to create auth instance properly")
+ }
+ })
+ }
+}
+
+func TestAuth_Set(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ w := httptest.NewRecorder()
+ err := auth.Set(w, data)
+ if err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+
+ cookies := w.Result().Cookies()
+ if len(cookies) != 1 {
+ t.Fatalf("Expected 1 cookie, got %d", len(cookies))
+ }
+
+ cookie := cookies[0]
+ if cookie.Name != "test_auth" {
+ t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name)
+ }
+ if cookie.Path != "/" {
+ t.Errorf("Expected cookie path '/', got %s", cookie.Path)
+ }
+ if cookie.Domain != "example.com" {
+ t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain)
+ }
+ if !cookie.HttpOnly {
+ t.Error("Expected cookie to be HttpOnly")
+ }
+ if cookie.SameSite != http.SameSiteLaxMode {
+ t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
+ }
+}
+
+func TestAuth_Required(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authData, ok := auth.Get(r.Context())
+ if !ok {
+ t.Error("Expected to find auth data in context")
+ return
+ }
+ if authData.UserID != 123 || authData.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name)
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ protectedHandler := auth.Required(handler)
+
+ t.Run("missing cookie", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401, got %d", w.Code)
+ }
+ })
+
+ t.Run("invalid cookie", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_token",
+ })
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status 401, got %d", w.Code)
+ }
+ })
+
+ t.Run("valid cookie", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+}
+
+func TestAuth_Optional(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authData, ok := auth.Get(r.Context())
+ if ok {
+ if authData.UserID != 123 || authData.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name)
+ }
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+
+ optionalHandler := auth.Optional(handler)
+
+ t.Run("missing cookie - should allow", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+
+ t.Run("invalid cookie - should allow", func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{
+ Name: "test_auth",
+ Value: "invalid_token",
+ })
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+
+ t.Run("valid cookie", func(t *testing.T) {
+ setW := httptest.NewRecorder()
+ if err := auth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+ })
+}
+
+func TestAuth_Get(t *testing.T) {
+ auth := createTestAuth()
+ data := testData{UserID: 123, Name: "Alice"}
+
+ t.Run("no data in context", func(t *testing.T) {
+ ctx := context.Background()
+ result, ok := auth.Get(ctx)
+ if ok {
+ t.Error("Expected Get to return false for empty context")
+ }
+ if result.UserID != 0 || result.Name != "" {
+ t.Error("Expected zero value for missing data")
+ }
+ })
+
+ t.Run("data in context", func(t *testing.T) {
+ ctx := context.WithValue(context.Background(), key, data)
+ result, ok := auth.Get(ctx)
+ if !ok {
+ t.Error("Expected Get to return true for context with data")
+ }
+ if result.UserID != 123 || result.Name != "Alice" {
+ t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", result.UserID, result.Name)
+ }
+ })
+
+ t.Run("wrong type in context", func(t *testing.T) {
+ ctx := context.WithValue(context.Background(), key, "wrong_type")
+ result, ok := auth.Get(ctx)
+ if ok {
+ t.Error("Expected Get to return false for wrong type in context")
+ }
+ if result.UserID != 0 || result.Name != "" {
+ t.Error("Expected zero value for wrong type")
+ }
+ })
+}
+
+type failingSerDes struct {
+ failSerialize bool
+ failDeserialize bool
+}
+
+func (f failingSerDes) Serialize(w io.Writer, data testData) error {
+ if f.failSerialize {
+ return io.ErrClosedPipe
+ }
+ return testSerDes{}.Serialize(w, data)
+}
+
+func (f failingSerDes) Deserialize(r io.Reader, data *testData) error {
+ if f.failDeserialize {
+ return io.ErrUnexpectedEOF
+ }
+ return testSerDes{}.Deserialize(r, data)
+}
+
+func TestAuth_Set_SerializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failSerialize: true},
+ CookieName: "test_auth",
+ }
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ w := httptest.NewRecorder()
+
+ err := auth.Set(w, data)
+ if err == nil {
+ t.Error("Expected Set to return error for serialization failure")
+ }
+}
+
+func TestAuth_Required_DeserializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ protectedHandler := badAuth.Required(handler)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ protectedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
+ }
+}
+
+func TestAuth_Optional_DeserializationError(t *testing.T) {
+ config := AuthConfig[testData]{
+ SerDes: failingSerDes{failDeserialize: true},
+ CookieName: "test_auth",
+ }
+ goodAuth := createTestAuth()
+ badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config)
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ optionalHandler := badAuth.Optional(handler)
+
+ data := testData{UserID: 123, Name: "Alice"}
+ setW := httptest.NewRecorder()
+ if err := goodAuth.Set(setW, data); err != nil {
+ t.Fatalf("Set() error = %v", err)
+ }
+ cookie := setW.Result().Cookies()[0]
+
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(cookie)
+ w := httptest.NewRecorder()
+
+ optionalHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusInternalServerError {
+ t.Errorf("Expected status 500 for deserialization error, got %d", w.Code)
+ }
+}
diff --git a/compat-test/README.md b/compat-test/README.md
new file mode 100644
index 0000000..c1e91b8
--- /dev/null
+++ b/compat-test/README.md
@@ -0,0 +1,44 @@
+# Kate libsodium Compatibility Tests
+
+This directory contains tests to verify that the `kate` package's `seal` and `open` functions are compatible with the real libsodium implementation.
+
+## Why?
+
+The `kate` library is designed to be CGO-free to avoid the complexity and deployment challenges that come with C dependencies. Rather than requiring applications to link against libsodium, `kate` implements a pure Go version of the NaCl secretbox encryption scheme.
+
+This separate compatibility test project verifies that `kate`'s implementation produces ciphertext that is fully interoperable with libsodium while keeping the CGO dependency isolated from the main package. This way, applications get the benefits of a pure Go library while maintaining confidence in libsodium compatibility.
+
+## Setup
+
+The tests require libsodium to be installed on your system. The project uses the `github.com/jamesruan/sodium` Go bindings for libsodium.
+
+### Installing libsodium
+
+On macOS:
+```bash
+brew install libsodium
+```
+
+On Ubuntu/Debian:
+```bash
+sudo apt-get install libsodium-dev
+```
+
+On CentOS/RHEL:
+```bash
+sudo yum install libsodium-devel
+```
+
+## Running the Tests
+
+```bash
+cd compat-test
+go test
+```
+
+## Test Cases
+
+The test suite includes:
+
+1. **KateEncrypt_LibsodiumDecrypt**: Encrypts data with kate's library and decrypts with libsodium
+2. **LibsodiumEncrypt_KateDecrypt**: Encrypts data with libsodium and decrypts with kate's library
diff --git a/compat-test/compat_test.go b/compat-test/compat_test.go
new file mode 100644
index 0000000..2b0bfd7
--- /dev/null
+++ b/compat-test/compat_test.go
@@ -0,0 +1,167 @@
+package main
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "testing/quick"
+
+ "git.tjp.lol/authentic_kate"
+ "github.com/jamesruan/sodium"
+)
+
+type ByteSerDes struct{}
+
+func (ByteSerDes) Serialize(w io.Writer, data []byte) error {
+ _, err := w.Write(data)
+ return err
+}
+
+func (ByteSerDes) Deserialize(r io.Reader, data *[]byte) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ *data = buf
+ return nil
+}
+
+func kateSeal(key [32]byte, plaintext []byte) []byte {
+ keyHex := hex.EncodeToString(key[:])
+
+ auth := kate.New(keyHex, kate.AuthConfig[[]byte]{
+ SerDes: ByteSerDes{},
+ CookieName: "test",
+ })
+
+ w := httptest.NewRecorder()
+ if err := auth.Set(w, plaintext); err != nil {
+ panic(err)
+ }
+
+ cookies := w.Result().Cookies()
+ if len(cookies) == 0 {
+ panic("No cookie set")
+ }
+
+ encryptedBytes, err := hex.DecodeString(cookies[0].Value)
+ if err != nil {
+ panic(err)
+ }
+
+ return encryptedBytes
+}
+
+func libsodiumSeal(key [32]byte, plaintext []byte) []byte {
+ var nonce [24]byte
+ if _, err := rand.Read(nonce[:]); err != nil {
+ panic(err)
+ }
+
+ ciphertextAndMac := sodium.Bytes(plaintext).SecretBox(
+ sodium.SecretBoxNonce{Bytes: nonce[:]},
+ sodium.SecretBoxKey{Bytes: key[:]},
+ )
+
+ result := make([]byte, 24+len(ciphertextAndMac))
+ copy(result[:24], nonce[:])
+ copy(result[24:], ciphertextAndMac)
+
+ return result
+}
+
+func kateOpen(key [32]byte, box []byte) ([]byte, bool) {
+ keyHex := hex.EncodeToString(key[:])
+
+ auth := kate.New(keyHex, kate.AuthConfig[[]byte]{
+ SerDes: ByteSerDes{},
+ CookieName: "test",
+ })
+
+ cookieValue := hex.EncodeToString(box)
+ req := httptest.NewRequest("GET", "/", nil)
+ req.AddCookie(&http.Cookie{Name: "test", Value: cookieValue})
+
+ var decryptedData []byte
+ var success bool
+ handler := auth.Optional(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ data, ok := auth.Get(r.Context())
+ if ok {
+ decryptedData = data
+ success = true
+ }
+ }))
+
+ testW := httptest.NewRecorder()
+ handler.ServeHTTP(testW, req)
+
+ return decryptedData, success
+}
+
+func libsodiumOpen(key [32]byte, box []byte) ([]byte, bool) {
+ if len(box) < 24 {
+ return nil, false
+ }
+
+ nonce := box[:24]
+ ciphertextAndMac := box[24:]
+
+ decrypted, err := sodium.Bytes(ciphertextAndMac).SecretBoxOpen(
+ sodium.SecretBoxNonce{Bytes: nonce},
+ sodium.SecretBoxKey{Bytes: key[:]},
+ )
+ if err != nil {
+ return nil, false
+ }
+
+ return decrypted, true
+}
+
+func TestCrossLibraryCompatibility(t *testing.T) {
+ t.Run("KateEncrypt_LibsodiumDecrypt", func(t *testing.T) {
+ f := func(data []byte) bool {
+ if len(data) == 0 {
+ return true
+ }
+
+ var key [32]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ return false
+ }
+
+ kateBox := kateSeal(key, data)
+ decrypted, ok := libsodiumOpen(key, kateBox)
+
+ return ok && string(decrypted) == string(data)
+ }
+
+ if err := quick.Check(f, nil); err != nil {
+ t.Errorf("Property test failed: %v", err)
+ }
+ })
+
+ t.Run("LibsodiumEncrypt_KateDecrypt", func(t *testing.T) {
+ f := func(data []byte) bool {
+ if len(data) == 0 {
+ return true
+ }
+
+ var key [32]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ return false
+ }
+
+ libsodiumBox := libsodiumSeal(key, data)
+ decrypted, ok := kateOpen(key, libsodiumBox)
+
+ return ok && string(decrypted) == string(data)
+ }
+
+ if err := quick.Check(f, nil); err != nil {
+ t.Errorf("Property test failed: %v", err)
+ }
+ })
+}
diff --git a/compat-test/go.mod b/compat-test/go.mod
new file mode 100644
index 0000000..c7711c9
--- /dev/null
+++ b/compat-test/go.mod
@@ -0,0 +1,17 @@
+module git.tjp.lol/authentic_kate/compat-test
+
+go 1.24.4
+
+require (
+ git.tjp.lol/authentic_kate v0.0.0-00010101000000-000000000000
+ github.com/jamesruan/sodium v1.0.14
+)
+
+require (
+ cloud.google.com/go/compute/metadata v0.3.0 // indirect
+ golang.org/x/crypto v0.39.0 // indirect
+ golang.org/x/oauth2 v0.30.0 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+)
+
+replace git.tjp.lol/authentic_kate => ../
diff --git a/compat-test/go.sum b/compat-test/go.sum
new file mode 100644
index 0000000..c96829d
--- /dev/null
+++ b/compat-test/go.sum
@@ -0,0 +1,10 @@
+cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
+cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+github.com/jamesruan/sodium v1.0.14 h1:JfOHobip/lUWouxHV3PwYwu3gsLewPrDrZXO3HuBzUU=
+github.com/jamesruan/sodium v1.0.14/go.mod h1:GK2+LACf7kuVQ9k7Irk0MB2B65j5rVqkz+9ylGIggZk=
+golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
+golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
diff --git a/doc.go b/doc.go
new file mode 100644
index 0000000..8d60f2a
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,150 @@
+// Package kate provides secure cookie-based authentication for HTTP applications.
+//
+// This package implements encrypted cookies with authenticated encryption.
+// It supports generic types to allow storage of any serializable data in encrypted HTTP cookies.
+//
+// Key features:
+// - Type-safe authentication data storage using Go generics
+// - Authenticated encryption for secure cookie storage
+// - HTTP middleware for required and optional authentication
+// - Ready-to-use login handlers for password, oauth, and magic link authentication
+// - Configurable cookie properties (domain, path, HTTPS-only, max age)
+// - Pluggable serialization interface for custom data formats
+// - Secure password hashing with industry-standard algorithms
+//
+// Example usage:
+//
+// type UserData struct {
+// ID int
+// Name string
+// }
+//
+// // Implement SerDes interface for UserData
+// type UserSerDes struct{}
+// func (s UserSerDes) Serialize(w io.Writer, data UserData) error { /* ... */ }
+// func (s UserSerDes) Deserialize(r io.Reader, data *UserData) error { /* ... */ }
+//
+// // Create authentication instance
+// auth := kate.New("hex-encoded-32-byte-key", kate.AuthConfig[UserData]{
+// SerDes: UserSerDes{},
+// CookieName: "session",
+// HTTPSOnly: true,
+// MaxAge: 24 * time.Hour,
+// })
+//
+// // Use as middleware
+// http.Handle("GET /protected", auth.Required(protectedHandler))
+//
+// // Set authentication cookie
+// user := UserData{ID: 123, Name: "John"}
+// auth.Set(w, user)
+//
+// // Get authenticated data
+// user, ok := auth.Get(r.Context())
+//
+// Login handlers:
+//
+// type User struct {
+// Username string
+// Hash string
+// ID int
+// }
+//
+// // Implement PasswordUserDataStore interface
+// type UserStore struct{}
+// func (us UserStore) Fetch(username string) (User, bool, error) {
+// user, exists := database.FindUser(username)
+// return user, exists, nil
+// }
+// func (us UserStore) GetPassHash(user User) string {
+// return user.Hash
+// }
+//
+// // Password-based login
+// passwordConfig := kate.PasswordLoginConfig[User]{
+// UserData: UserStore{},
+// Redirects: kate.Redirects{
+// Default: "/dashboard",
+// AllowedPrefixes: []string{"/app/", "/admin/"},
+// FieldName: "redirect",
+// },
+// }
+// http.Handle("POST /login", auth.PasswordLoginHandler(passwordConfig))
+//
+// // Implement MagicLinkMailer interface
+// type UserMailer struct{}
+// func (um UserMailer) Fetch(email string) (User, bool, error) {
+// user, exists := database.FindUserByEmail(email)
+// return user, exists, nil
+// }
+// func (um UserMailer) SendEmail(user User, token string) error {
+// return emailService.SendMagicLink(user.Email, token)
+// }
+//
+// // Magic link authentication
+// magicConfig := kate.MagicLinkConfig[User]{
+// Mailer: UserMailer{},
+// Redirects: kate.Redirects{ Default: "/dashboard" },
+// TokenExpiry: 15 * time.Minute,
+// TokenLocation: kate.TokenLocationQuery,
+// }
+// http.Handle("POST /magic-link", auth.MagicLinkLoginHandler(magicConfig))
+// http.Handle("GET /verify", auth.MagicLinkVerifyHandler(magicConfig))
+//
+// // OAuth2 authentication (Google example)
+// type UserDataStore struct{}
+// func (ds UserDataStore) GetOrCreateUser(email string) (User, error) {
+// user, exists := database.FindUserByEmail(email)
+// if exists {
+// return user, nil
+// }
+// return database.CreateUser(email)
+// }
+// func (ds UserDataStore) StoreState(state kate.OAuthState) (string, error) {
+// id := generateUniqueID()
+// return id, cache.Set(id, state, 10*time.Minute)
+// }
+// func (ds UserDataStore) GetAndClearState(id string) (*kate.OAuthState, error) {
+// var state kate.OAuthState
+// exists, err := cache.GetAndDelete(id, &state)
+// if !exists {
+// return nil, err
+// }
+// return &state, err
+// }
+//
+// oauthConfig := kate.GoogleOAuthConfig[User](
+// "your-client-id",
+// "your-client-secret",
+// "https://yourapp.com/auth/callback",
+// UserDataStore{},
+// )
+// // Customize redirects if needed
+// oauthConfig.Redirects = kate.Redirects{
+// Default: "/dashboard",
+// AllowedPrefixes: []string{"/app/", "/admin/"},
+// FieldName: "redirect",
+// }
+// http.Handle("GET /auth/login", auth.OAuthLoginHandler(oauthConfig))
+// http.Handle("GET /auth/callback", auth.OAuthCallbackHandler(oauthConfig))
+//
+// Password hashing:
+//
+// // Hash a password with defaults
+// hash, err := kate.HashPassword("user-password", nil)
+//
+// // Verify a password
+// match, err := kate.ComparePassword("user-password", hash)
+// if err != nil {
+// // Handle malformed hash error
+// }
+// if match {
+// // Password is correct
+// }
+//
+// Cryptographic algorithms:
+//
+// This package uses the following cryptographic algorithms:
+// - Token encryption: XSalsa20 stream cipher with Poly1305 MAC for authenticated encryption (compatible with libsodium secretbox)
+// - Password hashing: Argon2id with secure default parameters and PHC format storage
+package kate
diff --git a/encryption.go b/encryption.go
new file mode 100644
index 0000000..a445668
--- /dev/null
+++ b/encryption.go
@@ -0,0 +1,136 @@
+package kate
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "errors"
+ "fmt"
+
+ "golang.org/x/crypto/poly1305" //nolint:staticcheck
+ "golang.org/x/crypto/salsa20"
+)
+
+type encryption struct {
+ Key [naclKeyLen]byte
+}
+
+var ErrInvalidToken = errors.New("invalid token")
+
+func encryptionFromHexKey(key string) (encryption, error) {
+ e := encryption{}
+ l := hex.EncodedLen(naclKeyLen)
+ keybytes := []byte(key)
+ if l != len(keybytes) {
+ return e, fmt.Errorf("expected %d length encryption key, got %d", l, len(keybytes))
+ }
+ if _, err := hex.Decode(e.Key[:], keybytes); err != nil {
+ return e, fmt.Errorf("encryptionFromHexKey: %w", err)
+ }
+ return e, nil
+}
+
+func (e encryption) Encrypt(data []byte) string {
+ buf := make([]byte, naclNonceLen+len(data)+naclMacLen)
+ nonce := [naclNonceLen]byte{}
+
+ _, _ = rand.Read(nonce[:])
+
+ seal(buf[naclNonceLen:], e.Key, nonce, data)
+
+ copy(buf[:naclNonceLen], nonce[:])
+ return hex.EncodeToString(buf)
+}
+
+func (e encryption) Decrypt(token string) ([]byte, bool) {
+ tokenBytes := []byte(token)
+ ciphertext := make([]byte, hex.DecodedLen(len(tokenBytes)))
+ if _, err := hex.Decode(ciphertext, tokenBytes); err != nil {
+ return nil, false
+ }
+
+ nonce := ciphertext[:naclNonceLen]
+ noncebuf := [naclNonceLen]byte{}
+ copy(noncebuf[:], nonce)
+ ciphertext = ciphertext[naclNonceLen:]
+
+ cleartext := make([]byte, len(ciphertext)-naclMacLen)
+ if _, ok := open(cleartext, e.Key, noncebuf, ciphertext); !ok {
+ return nil, false
+ }
+ return cleartext, true
+}
+
+const (
+ naclKeyLen = 32
+ naclNonceLen = 24
+ naclMacLen = 16
+)
+
+//
+// Replicating the libsodium secretbox API below.
+//
+// This was a last resort. I really didn't want to introduce CGO,
+// and the only pure-go implementation I could find (libgodium)
+// is unmaintained and riddled with bugs.
+//
+// With solid go implementations of the core primitives already in
+// golang.org/x/crypto, it made the most sense to just wrap those.
+//
+
+func seal(dst []byte, key [naclKeyLen]byte, nonce [naclNonceLen]byte, plaintext []byte) []byte {
+ if len(dst) < len(plaintext)+naclMacLen {
+ dst = make([]byte, len(plaintext)+naclMacLen)
+ }
+
+ keystream := make([]byte, naclKeyLen+len(plaintext))
+ salsa20.XORKeyStream(keystream, keystream, nonce[:], &key)
+
+ var poly1305Key [naclKeyLen]byte
+ copy(poly1305Key[:], keystream[:naclKeyLen])
+
+ ciphertext := make([]byte, len(plaintext))
+ for i := range plaintext {
+ ciphertext[i] = plaintext[i] ^ keystream[naclKeyLen+i]
+ }
+
+ var mac [naclMacLen]byte
+ poly1305.Sum(&mac, ciphertext, &poly1305Key)
+
+ copy(dst[:naclMacLen], mac[:])
+ copy(dst[naclMacLen:], ciphertext)
+
+ return dst
+}
+
+func open(dst []byte, key [naclKeyLen]byte, nonce [naclNonceLen]byte, box []byte) ([]byte, bool) {
+ if len(box) < naclMacLen {
+ return nil, false
+ }
+
+ if len(dst) < len(box)-naclMacLen {
+ dst = make([]byte, len(box)-naclMacLen)
+ }
+
+ var receivedMAC [naclMacLen]byte
+ copy(receivedMAC[:], box[:naclMacLen])
+ ciphertext := box[naclMacLen:]
+
+ keystream := make([]byte, naclKeyLen+len(ciphertext))
+ salsa20.XORKeyStream(keystream, keystream, nonce[:], &key)
+
+ var poly1305Key [naclKeyLen]byte
+ copy(poly1305Key[:], keystream[:naclKeyLen])
+
+ var expectedMAC [naclMacLen]byte
+ poly1305.Sum(&expectedMAC, ciphertext, &poly1305Key)
+
+ if !poly1305.Verify(&receivedMAC, ciphertext, &poly1305Key) {
+ return nil, false
+ }
+
+ for i := range ciphertext {
+ dst[i] = ciphertext[i] ^ keystream[naclKeyLen+i]
+ }
+
+ return dst, true
+}
diff --git a/encryption_test.go b/encryption_test.go
new file mode 100644
index 0000000..fcec058
--- /dev/null
+++ b/encryption_test.go
@@ -0,0 +1,151 @@
+package kate
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "testing"
+ "testing/quick"
+)
+
+func TestEncryptDecryptRoundtrip(t *testing.T) {
+ property := func(data []byte) bool {
+ var key [naclKeyLen]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ return false
+ }
+
+ encryption := encryption{Key: key}
+ token := encryption.Encrypt(data)
+ decryptedData, success := encryption.Decrypt(token)
+
+ if !bytes.Equal(decryptedData, data) {
+ fmt.Printf("expected %v, got %v\n", data, decryptedData)
+ }
+ return success && bytes.Equal(decryptedData, data)
+ }
+
+ if err := quick.Check(property, nil); err != nil {
+ t.Errorf("Encrypt/Decrypt roundtrip property failed: %v", err)
+ }
+}
+
+func TestEncryptionFromHexKey(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ wantErr bool
+ }{
+ {
+ name: "valid hex key",
+ key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ wantErr: false,
+ },
+ {
+ name: "key too short",
+ key: "0123456789abcdef",
+ wantErr: true,
+ },
+ {
+ name: "key too long",
+ key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00",
+ wantErr: true,
+ },
+ {
+ name: "invalid hex characters",
+ key: "0123456789abcdefghijklmnopqrstuvwxyz0123456789abcdef0123456789abcdef",
+ wantErr: true,
+ },
+ {
+ name: "empty key",
+ key: "",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if _, err := encryptionFromHexKey(tt.key); (err != nil) != tt.wantErr {
+ t.Errorf("encryptionFromHexKey() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestDecryptInvalidTokens(t *testing.T) {
+ var key [naclKeyLen]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ t.Fatal(err)
+ }
+ enc := encryption{Key: key}
+
+ tests := []struct {
+ name string
+ token string
+ }{
+ {
+ name: "invalid hex token",
+ token: "invalid_hex_string",
+ },
+ {
+ name: "token with wrong MAC",
+ token: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ data, ok := enc.Decrypt(tt.token)
+ if ok {
+ t.Errorf("Decrypt() should have failed for %s, but returned data: %v", tt.name, data)
+ }
+ })
+ }
+}
+
+func TestEncryptEmptyData(t *testing.T) {
+ var key [naclKeyLen]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ t.Fatal(err)
+ }
+ enc := encryption{Key: key}
+
+ token := enc.Encrypt([]byte{})
+ if token == "" {
+ t.Error("Encrypt() should return non-empty token even for empty data")
+ }
+
+ data, ok := enc.Decrypt(token)
+ if !ok {
+ t.Error("Decrypt() should succeed for encrypted empty data")
+ }
+ if len(data) != 0 {
+ t.Errorf("Decrypt() should return empty data, got %v", data)
+ }
+}
+
+func TestEncryptLargeData(t *testing.T) {
+ var key [naclKeyLen]byte
+ if _, err := rand.Read(key[:]); err != nil {
+ t.Fatal(err)
+ }
+ enc := encryption{Key: key}
+
+ largeData := make([]byte, 10000)
+ if _, err := rand.Read(largeData); err != nil {
+ t.Fatal(err)
+ }
+
+ token := enc.Encrypt(largeData)
+ if token == "" {
+ t.Error("Encrypt() should return non-empty token for large data")
+ }
+
+ decryptedData, ok := enc.Decrypt(token)
+ if !ok {
+ t.Error("Decrypt() should succeed for encrypted large data")
+ }
+ if !bytes.Equal(decryptedData, largeData) {
+ t.Error("Decrypt() should return original large data")
+ }
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..9bcb898
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,13 @@
+module git.tjp.lol/authentic_kate
+
+go 1.24.4
+
+require (
+ golang.org/x/crypto v0.39.0
+ golang.org/x/oauth2 v0.30.0
+)
+
+require (
+ cloud.google.com/go/compute/metadata v0.3.0 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..a2cbcfb
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,8 @@
+cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
+cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
+golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
diff --git a/login_helpers.go b/login_helpers.go
new file mode 100644
index 0000000..d7cac41
--- /dev/null
+++ b/login_helpers.go
@@ -0,0 +1,64 @@
+package kate
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+// Redirects configures redirect behavior for authentication handlers.
+type Redirects struct {
+ // Default is the default URL to redirect to after successful authentication
+ Default string
+
+ // AllowedPrefixes is a list of allowed redirect URL prefixes for security.
+ //
+ // If empty, any redirect target is allowed (not recommended for production)
+ AllowedPrefixes []string
+
+ // FieldName is the form/query field name for the redirect target
+ //
+ // If empty, Default will always be used as the target
+ FieldName string
+}
+
+func (r Redirects) isValid(target string) bool {
+ targetURL, err := url.Parse(target)
+ if err != nil {
+ return false
+ }
+ if targetURL.IsAbs() && targetURL.Host != "" {
+ return false
+ }
+
+ if len(r.AllowedPrefixes) == 0 {
+ return true
+ }
+
+ for _, prefix := range r.AllowedPrefixes {
+ if strings.HasPrefix(target, prefix) {
+ return true
+ }
+ }
+ return false
+}
+
+func (r Redirects) target(req *http.Request) string {
+ d := r.Default
+ if d == "" {
+ d = "/"
+ }
+
+ if r.FieldName == "" {
+ return d
+ }
+ if err := req.ParseForm(); err != nil {
+ return d
+ }
+
+ target := req.Form.Get(r.FieldName)
+ if target != "" && r.isValid(target) {
+ return target
+ }
+ return d
+}
diff --git a/magic_link_login.go b/magic_link_login.go
new file mode 100644
index 0000000..9228a97
--- /dev/null
+++ b/magic_link_login.go
@@ -0,0 +1,233 @@
+package kate
+
+import (
+ "errors"
+ "net/http"
+ "strings"
+ "time"
+)
+
+// MagicLinkConfig configures the magic link authentication handler behavior.
+type MagicLinkConfig[T any] struct {
+ // Mailer provides user data lookup and email sending
+ Mailer MagicLinkMailer[T]
+
+ // Redirects configures post-authentication redirect behavior
+ Redirects Redirects
+
+ // UsernameField is the form field name for username
+ UsernameField string
+
+ // TokenField is the URL parameter name for the magic link token
+ TokenField string
+
+ // TokenLocation specifies where to retrieve the token from
+ TokenLocation TokenLocation
+
+ // TokenExpiry is how long the magic link token is valid
+ TokenExpiry time.Duration
+
+ // LogError is an optional function to log errors
+ LogError func(error)
+}
+
+// MagicLinkMailer provides user data lookup and email sending for magic link authentication.
+type MagicLinkMailer[T any] interface {
+ // Fetch retrieves user data by username.
+ //
+ // Returns the user data, whether the user was found, and any error.
+ // If the user is not found, should return (zero value, false, nil).
+ Fetch(username string) (T, bool, error)
+
+ // SendEmail sends a magic link email to the user.
+ //
+ // The token parameter contains the encrypted magic link token that should
+ // be included in the email URL for authentication.
+ SendEmail(userData T, token string) error
+}
+
+// TokenLocation specifies where the magic link token should be retrieved from
+type TokenLocation struct{ location int }
+
+var (
+ // TokenLocationQuery retrieves the token from URL query parameters
+ TokenLocationQuery = TokenLocation{0}
+ // TokenLocationPath retrieves the token from URL path parameters using Request.PathValue()
+ TokenLocationPath = TokenLocation{1}
+)
+
+func (mlc *MagicLinkConfig[T]) setDefaults() {
+ if mlc.UsernameField == "" {
+ mlc.UsernameField = "email"
+ }
+ if mlc.TokenField == "" {
+ mlc.TokenField = "token"
+ }
+ // TokenLocation defaults to TokenLocationQuery (zero value)
+ if mlc.TokenExpiry == 0 {
+ mlc.TokenExpiry = 15 * time.Minute
+ }
+}
+
+func (mlc MagicLinkConfig[T]) logError(err error) {
+ if mlc.LogError != nil {
+ mlc.LogError(err)
+ }
+}
+
+type magicLinkToken struct {
+ Username string
+ Redirect string
+ ExpiresAt time.Time
+}
+
+func (t magicLinkToken) serialize() string {
+ return t.Redirect + "\x00" + t.ExpiresAt.Format(time.RFC3339) + "\x00" + t.Username
+}
+
+func parseMagicLinkToken(data string) (magicLinkToken, error) {
+ parts := strings.SplitN(data, "\x00", 3)
+ if len(parts) != 3 {
+ return magicLinkToken{}, errors.New("invalid token format")
+ }
+
+ redirect := parts[0]
+ expiresStr := parts[1]
+ username := parts[2]
+
+ expiresAt, err := time.Parse(time.RFC3339, expiresStr)
+ if err != nil {
+ return magicLinkToken{}, err
+ }
+
+ return magicLinkToken{
+ Redirect: redirect,
+ ExpiresAt: expiresAt,
+ Username: username,
+ }, nil
+}
+
+// MagicLinkLoginHandler returns an HTTP handler that processes magic link requests.
+//
+// It looks up the user, generates a token, and sends an email with the magic link.
+func (a Auth[T]) MagicLinkLoginHandler(config MagicLinkConfig[T]) http.Handler {
+ config.setDefaults()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid form data", http.StatusBadRequest)
+ return
+ }
+
+ username := r.PostForm.Get(config.UsernameField)
+
+ if username == "" {
+ http.Error(w, config.UsernameField+" required", http.StatusBadRequest)
+ return
+ }
+
+ userData, ok, err := config.Mailer.Fetch(username)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Error finding user", http.StatusInternalServerError)
+ return
+ }
+ if !ok {
+ // Don't reveal whether user exists
+ w.WriteHeader(http.StatusOK)
+ if _, err := w.Write([]byte("Magic link sent")); err != nil {
+ config.logError(err)
+ }
+ return
+ }
+
+ tokenData := []byte(magicLinkToken{
+ Username: username,
+ Redirect: config.Redirects.target(r),
+ ExpiresAt: time.Now().Add(config.TokenExpiry),
+ }.serialize())
+
+ encryptedToken := a.enc.Encrypt(tokenData)
+
+ if err := config.Mailer.SendEmail(userData, encryptedToken); err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to send email", http.StatusInternalServerError)
+ return
+ }
+
+ w.WriteHeader(http.StatusOK)
+ if _, err := w.Write([]byte("Magic link sent")); err != nil {
+ config.logError(err)
+ }
+ })
+}
+
+// MagicLinkVerifyHandler returns an HTTP handler that verifies magic link tokens.
+//
+// It decrypts and validates the token, sets the authentication cookie, and redirects.
+func (a Auth[T]) MagicLinkVerifyHandler(config MagicLinkConfig[T]) http.Handler {
+ config.setDefaults()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var encryptedToken string
+ switch config.TokenLocation {
+ case TokenLocationPath:
+ encryptedToken = r.PathValue(config.TokenField)
+ case TokenLocationQuery:
+ encryptedToken = r.URL.Query().Get(config.TokenField)
+ default:
+ encryptedToken = r.URL.Query().Get(config.TokenField)
+ }
+
+ if encryptedToken == "" {
+ http.Error(w, "Missing token", http.StatusBadRequest)
+ return
+ }
+
+ tokenData, ok := a.enc.Decrypt(encryptedToken)
+ if !ok {
+ http.Error(w, "Invalid token", http.StatusUnauthorized)
+ return
+ }
+
+ token, err := parseMagicLinkToken(string(tokenData))
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Invalid token", http.StatusUnauthorized)
+ return
+ }
+
+ if time.Now().After(token.ExpiresAt) {
+ http.Error(w, "Token expired", http.StatusUnauthorized)
+ return
+ }
+
+ userData, ok, err := config.Mailer.Fetch(token.Username)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+ if !ok {
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+
+ if err := a.Set(w, userData); err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to set authentication", http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, token.Redirect, http.StatusSeeOther)
+ })
+}
+
diff --git a/magic_link_login_test.go b/magic_link_login_test.go
new file mode 100644
index 0000000..da4bfa6
--- /dev/null
+++ b/magic_link_login_test.go
@@ -0,0 +1,473 @@
+package kate
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Mock implementation of MagicLinkMailer for testing
+type mockMagicLinkMailer[T any] struct {
+ users map[string]T
+ sentEmails []struct {
+ user T
+ token string
+ }
+ sendEmailFunc func(T, string) error
+}
+
+func newMockMagicLinkMailer[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkMailer[T] {
+ return &mockMagicLinkMailer[T]{
+ users: users,
+ sendEmailFunc: sendEmailFunc,
+ }
+}
+
+func (m *mockMagicLinkMailer[T]) Fetch(username string) (T, bool, error) {
+ user, exists := m.users[username]
+ return user, exists, nil
+}
+
+func (m *mockMagicLinkMailer[T]) SendEmail(userData T, token string) error {
+ m.sentEmails = append(m.sentEmails, struct {
+ user T
+ token string
+ }{userData, token})
+ if m.sendEmailFunc != nil {
+ return m.sendEmailFunc(userData, token)
+ }
+ return nil
+}
+
+type testUser struct {
+ Username string
+ Hash string
+ ID int
+}
+
+type testUserSerDes struct{}
+
+func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error {
+ _, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID)
+ return err
+}
+
+func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ parts := strings.Split(string(buf), "|")
+ if len(parts) != 3 {
+ return errors.New("invalid format")
+ }
+ data.Username = parts[0]
+ data.Hash = parts[1]
+ id, err := strconv.Atoi(parts[2])
+ if err != nil {
+ return err
+ }
+ data.ID = id
+ return nil
+}
+
+func TestMagicLinkHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
+ "jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{
+ Default: "/dashboard",
+ AllowedPrefixes: []string{"/app/", "/admin/"},
+ },
+ }
+
+ handler := auth.MagicLinkLoginHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ formData url.Values
+ expectedStatus int
+ expectedBody string
+ checkEmail bool
+ }{
+ {
+ name: "successful magic link request",
+ method: "POST",
+ formData: url.Values{"email": {"john@example.com"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: true,
+ },
+ {
+ name: "magic link with redirect",
+ method: "POST",
+ formData: url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: true,
+ },
+ {
+ name: "nonexistent user returns success (no user enumeration)",
+ method: "POST",
+ formData: url.Values{"email": {"nobody@example.com"}},
+ expectedStatus: http.StatusOK,
+ expectedBody: "Magic link sent",
+ checkEmail: false,
+ },
+ {
+ name: "missing email",
+ method: "POST",
+ formData: url.Values{},
+ expectedStatus: http.StatusBadRequest,
+ checkEmail: false,
+ },
+ {
+ name: "GET method not allowed",
+ method: "GET",
+ formData: url.Values{},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkEmail: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockMailer.sentEmails = nil // Reset
+
+ var req *http.Request
+ if tt.method == "POST" {
+ req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ } else {
+ req = httptest.NewRequest(tt.method, "/magic-link", nil)
+ }
+
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.expectedBody != "" {
+ body := strings.TrimSpace(rr.Body.String())
+ if body != tt.expectedBody {
+ t.Errorf("expected body %q, got %q", tt.expectedBody, body)
+ }
+ }
+
+ if tt.checkEmail {
+ if len(mockMailer.sentEmails) != 1 {
+ t.Errorf("expected 1 email sent, got %d", len(mockMailer.sentEmails))
+ } else {
+ email := mockMailer.sentEmails[0]
+ if email.token == "" {
+ t.Error("expected non-empty token")
+ }
+ }
+ } else {
+ if len(mockMailer.sentEmails) != 0 {
+ t.Errorf("expected no emails sent, got %d", len(mockMailer.sentEmails))
+ }
+ }
+ })
+ }
+}
+
+func TestMagicLinkVerifyHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{Default: "/dashbaord"},
+ TokenExpiry: time.Minute,
+ }
+
+ token := magicLinkToken{
+ Username: "john@example.com",
+ Redirect: "/app/settings",
+ ExpiresAt: time.Now().Add(time.Minute),
+ }
+ tokenData := []byte(token.serialize())
+ validToken := auth.enc.Encrypt(tokenData)
+
+ expiredToken := magicLinkToken{
+ Username: "john@example.com",
+ Redirect: "/app/settings",
+ ExpiresAt: time.Now().Add(-time.Minute),
+ }
+ expiredTokenData := []byte(expiredToken.serialize())
+ expiredTokenStr := auth.enc.Encrypt(expiredTokenData)
+
+ handler := auth.MagicLinkVerifyHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ token string
+ expectedStatus int
+ expectedRedirect string
+ checkCookie bool
+ }{
+ {
+ name: "valid token with redirect",
+ method: "GET",
+ token: validToken,
+ expectedStatus: http.StatusSeeOther,
+ expectedRedirect: "/app/settings",
+ checkCookie: true,
+ },
+ {
+ name: "missing token",
+ method: "GET",
+ token: "",
+ expectedStatus: http.StatusBadRequest,
+ checkCookie: false,
+ },
+ {
+ name: "invalid token",
+ method: "GET",
+ token: "invalid-token",
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "expired token",
+ method: "GET",
+ token: expiredTokenStr,
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ token: validToken,
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkCookie: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ url := "/verify"
+ if tt.token != "" {
+ url += "?token=" + tt.token
+ }
+
+ req := httptest.NewRequest(tt.method, url, nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.expectedRedirect != "" {
+ location := rr.Header().Get("Location")
+ if location != tt.expectedRedirect {
+ t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location)
+ }
+ }
+
+ if tt.checkCookie {
+ cookies := rr.Result().Cookies()
+ found := false
+ for _, cookie := range cookies {
+ if cookie.Name == "test_session" && cookie.Value != "" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected authentication cookie to be set")
+ }
+ }
+ })
+ }
+}
+
+func TestMagicLinkConfigDefaults(t *testing.T) {
+ config := MagicLinkConfig[testUser]{}
+ config.setDefaults()
+
+ if config.UsernameField != "email" {
+ t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField)
+ }
+ if config.TokenField != "token" {
+ t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
+ }
+ if config.TokenExpiry != 15*time.Minute {
+ t.Errorf("expected TokenExpiry to be 15 minutes, got %v", config.TokenExpiry)
+ }
+ if config.TokenLocation != TokenLocationQuery {
+ t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation)
+ }
+}
+
+func TestMagicLinkTokenSerialization(t *testing.T) {
+ now := time.Now().Truncate(time.Second) // Remove nanoseconds for consistent comparison
+
+ tests := []struct {
+ name string
+ token magicLinkToken
+ expected string
+ }{
+ {
+ name: "simple token",
+ token: magicLinkToken{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "redirect with pipes",
+ token: magicLinkToken{
+ Redirect: "/app|section|page",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "/app|section|page\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "empty redirect",
+ token: magicLinkToken{
+ Redirect: "",
+ ExpiresAt: now,
+ Username: "user@example.com",
+ },
+ expected: "\x00" + now.Format(time.RFC3339) + "\x00user@example.com",
+ },
+ {
+ name: "username with pipes",
+ token: magicLinkToken{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ Username: "user|with|pipes@example.com",
+ },
+ expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user|with|pipes@example.com",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ serialized := tt.token.serialize()
+ if serialized != tt.expected {
+ t.Errorf("serialize() = %q, want %q", serialized, tt.expected)
+ }
+
+ parsed, err := parseMagicLinkToken(serialized)
+ if err != nil {
+ t.Errorf("parseMagicLinkToken() error = %v", err)
+ return
+ }
+
+ if parsed.Redirect != tt.token.Redirect {
+ t.Errorf("parsed Redirect = %q, want %q", parsed.Redirect, tt.token.Redirect)
+ }
+ if parsed.Username != tt.token.Username {
+ t.Errorf("parsed Username = %q, want %q", parsed.Username, tt.token.Username)
+ }
+ if !parsed.ExpiresAt.Equal(tt.token.ExpiresAt) {
+ t.Errorf("parsed ExpiresAt = %v, want %v", parsed.ExpiresAt, tt.token.ExpiresAt)
+ }
+ })
+ }
+}
+
+func TestParseMagicLinkTokenErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"empty string", ""},
+ {"single part", "onlyonepart"},
+ {"two parts", "two\x00parts"},
+ {"invalid time", "redirect\x00invalid-time\x00username"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := parseMagicLinkToken(tt.input)
+ if err == nil {
+ t.Errorf("parseMagicLinkToken(%q) expected error, got nil", tt.input)
+ }
+ })
+ }
+}
+
+func TestMagicLinkVerifyHandlerPathToken(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "test": {Username: "test", ID: 1},
+ }
+
+ mockMailer := newMockMagicLinkMailer(users, nil)
+
+ config := MagicLinkConfig[testUser]{
+ Mailer: mockMailer,
+ Redirects: Redirects{Default: "/dashboard"},
+ TokenExpiry: time.Minute,
+ TokenLocation: TokenLocationPath,
+ TokenField: "token",
+ }
+
+ handler := auth.MagicLinkVerifyHandler(config)
+
+ if config.TokenLocation != TokenLocationPath {
+ t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation)
+ }
+ if config.TokenField != "token" {
+ t.Errorf("expected TokenField to be 'token', got %s", config.TokenField)
+ }
+
+ req := httptest.NewRequest("GET", "/verify", nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code)
+ }
+}
+
+func TestTokenLocationEnum(t *testing.T) {
+ if TokenLocationQuery.location != 0 {
+ t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location)
+ }
+ if TokenLocationPath.location != 1 {
+ t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location)
+ }
+}
diff --git a/oauth_login.go b/oauth_login.go
new file mode 100644
index 0000000..6c16275
--- /dev/null
+++ b/oauth_login.go
@@ -0,0 +1,257 @@
+package kate
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/github"
+ "golang.org/x/oauth2/google"
+)
+
+// OAuthConfig configures the OAuth login handler behavior.
+type OAuthConfig[T any] struct {
+ // OAuth2Config is the standard OAuth2 configuration
+ OAuth2Config *oauth2.Config
+
+ // UserInfoURL is the provider endpoint to get user information
+ UserInfoURL string
+
+ // DataStore provides user data lookup and temporary state storage
+ DataStore OAuthDataStore[T]
+
+ // Redirects configures post-authentication redirect behavior
+ Redirects Redirects
+
+ // StateExpiry is how long the OAuth state is valid
+ StateExpiry time.Duration
+
+ // LogError is an optional function to log errors
+ LogError func(error)
+}
+
+// OAuthDataStore provides user data lookup and temporary storage for OAuth state data.
+type OAuthDataStore[T any] interface {
+ // GetOrCreateUser retrieves or creates user data by email address.
+ //
+ // If the user exists, returns their data. If not, creates a new user
+ // and returns the new user data. This handles both OAuth login and
+ // registration in a single operation.
+ GetOrCreateUser(email string) (T, error)
+
+ // StoreState stores OAuth state data temporarily and returns a unique state ID.
+ //
+ // The returned state ID will be used as the OAuth state parameter.
+ // Implementations should generate a unique, unguessable ID for security.
+ StoreState(state OAuthState) (string, error)
+
+ // GetAndClearState retrieves and deletes OAuth state data by ID.
+ //
+ // Returns the state data if found, otherwise nil.
+ // The state should be deleted after retrieval to ensure one-time use.
+ GetAndClearState(id string) (*OAuthState, error)
+}
+
+// GoogleOAuthConfig creates an OAuthConfig for Google OAuth with sensible defaults.
+func GoogleOAuthConfig[T any](
+ clientID,
+ clientSecret,
+ callbackURI string,
+ dataStore OAuthDataStore[T],
+) OAuthConfig[T] {
+ return OAuthConfig[T]{
+ OAuth2Config: &oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURL: callbackURI,
+ Scopes: []string{"email"},
+ Endpoint: google.Endpoint,
+ },
+ UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
+ DataStore: dataStore,
+ Redirects: Redirects{Default: "/", FieldName: "redirect"},
+ StateExpiry: 10 * time.Minute,
+ }
+}
+
+// GitHubOAuthConfig creates an OAuthConfig for GitHub OAuth with sensible defaults.
+func GitHubOAuthConfig[T any](
+ clientID,
+ clientSecret,
+ callbackURI string,
+ dataStore OAuthDataStore[T],
+) OAuthConfig[T] {
+ return OAuthConfig[T]{
+ OAuth2Config: &oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURL: callbackURI,
+ Scopes: []string{"user:email"},
+ Endpoint: github.Endpoint,
+ },
+ UserInfoURL: "https://api.github.com/user",
+ DataStore: dataStore,
+ Redirects: Redirects{Default: "/", FieldName: "redirect"},
+ StateExpiry: 10 * time.Minute,
+ }
+}
+
+func (oc *OAuthConfig[T]) setDefaults() {
+ if oc.StateExpiry == 0 {
+ oc.StateExpiry = 10 * time.Minute
+ }
+}
+
+func (oc OAuthConfig[T]) logError(err error) {
+ if oc.LogError != nil {
+ oc.LogError(err)
+ }
+}
+
+// OAuthState represents the OAuth state data that needs to be temporarily stored.
+type OAuthState struct {
+ Redirect string
+ ExpiresAt time.Time
+}
+
+// OAuthLoginHandler returns an HTTP handler that initiates OAuth authentication.
+func (a Auth[T]) OAuthLoginHandler(config OAuthConfig[T]) http.Handler {
+ config.setDefaults()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ state := OAuthState{
+ Redirect: config.Redirects.target(r),
+ ExpiresAt: time.Now().Add(config.StateExpiry),
+ }
+
+ stateID, err := config.DataStore.StoreState(state)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to store state", http.StatusInternalServerError)
+ return
+ }
+
+ authURL := config.OAuth2Config.AuthCodeURL(stateID)
+ http.Redirect(w, r, authURL, http.StatusSeeOther)
+ })
+}
+
+// OAuthCallbackHandler returns an HTTP handler that handles the OAuth callback.
+func (a Auth[T]) OAuthCallbackHandler(config OAuthConfig[T]) http.Handler {
+ config.setDefaults()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ stateID := r.URL.Query().Get("state")
+ code := r.URL.Query().Get("code")
+ errorParam := r.URL.Query().Get("error")
+
+ if errorParam != "" {
+ config.logError(fmt.Errorf("OAuth error: %s", errorParam))
+ http.Error(w, "OAuth authorization failed", http.StatusBadRequest)
+ return
+ }
+
+ if code == "" {
+ http.Error(w, "Missing authorization code", http.StatusBadRequest)
+ return
+ }
+
+ if stateID == "" {
+ http.Error(w, "Missing state parameter", http.StatusBadRequest)
+ return
+ }
+
+ state, err := config.DataStore.GetAndClearState(stateID)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Error finding login state", http.StatusInternalServerError)
+ return
+ }
+ if state == nil {
+ http.Error(w, "Invalid login state", http.StatusBadRequest)
+ return
+ }
+
+ if time.Now().After(state.ExpiresAt) {
+ http.Error(w, "Login expired", http.StatusBadRequest)
+ return
+ }
+
+ ctx := context.Background()
+ token, err := config.OAuth2Config.Exchange(ctx, code)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError)
+ return
+ }
+
+ client := config.OAuth2Config.Client(ctx, token)
+ resp, err := client.Get(config.UserInfoURL)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to get user info", http.StatusInternalServerError)
+ return
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ config.logError(err)
+ }
+ }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to read user info response", http.StatusInternalServerError)
+ return
+ }
+
+ var userInfo struct {
+ Email string `json:"email"`
+ }
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to parse user info", http.StatusInternalServerError)
+ return
+ }
+ email := userInfo.Email
+ if email == "" {
+ http.Error(w, "No email address found", http.StatusBadRequest)
+ return
+ }
+
+ userData, err := config.DataStore.GetOrCreateUser(email)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to get or create user", http.StatusInternalServerError)
+ return
+ }
+
+ if err := a.Set(w, userData); err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to set authentication cookie", http.StatusInternalServerError)
+ return
+ }
+
+ redirectTarget := state.Redirect
+ if redirectTarget == "" {
+ redirectTarget = config.Redirects.Default
+ if redirectTarget == "" {
+ redirectTarget = "/"
+ }
+ }
+
+ http.Redirect(w, r, redirectTarget, http.StatusSeeOther)
+ })
+}
diff --git a/oauth_login_test.go b/oauth_login_test.go
new file mode 100644
index 0000000..caec89b
--- /dev/null
+++ b/oauth_login_test.go
@@ -0,0 +1,369 @@
+package kate
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// Mock implementation of OAuthDataStore for testing
+type mockOAuthDataStore[T any] struct {
+ userData map[string]T
+ states map[string]OAuthState
+ mu sync.RWMutex
+ getOrCreateUser func(email string) (T, error)
+ stateIDCounter int
+}
+
+func newMockOAuthDataStore[T any](getOrCreateUser func(email string) (T, error)) *mockOAuthDataStore[T] {
+ return &mockOAuthDataStore[T]{
+ userData: make(map[string]T),
+ states: make(map[string]OAuthState),
+ getOrCreateUser: getOrCreateUser,
+ }
+}
+
+func (m *mockOAuthDataStore[T]) GetOrCreateUser(email string) (T, error) {
+ return m.getOrCreateUser(email)
+}
+
+func (m *mockOAuthDataStore[T]) StoreState(state OAuthState) (string, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.stateIDCounter++
+ stateID := fmt.Sprintf("state_%d", m.stateIDCounter)
+ m.states[stateID] = state
+ return stateID, nil
+}
+
+func (m *mockOAuthDataStore[T]) GetAndClearState(id string) (*OAuthState, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ state, exists := m.states[id]
+ if !exists {
+ return nil, nil
+ }
+
+ delete(m.states, id)
+ return &state, nil
+}
+
+func TestOAuthLoginHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+ config.Redirects.Default = "/dashboard"
+ config.Redirects.AllowedPrefixes = []string{"/app/", "/admin/"}
+
+ handler := auth.OAuthLoginHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ queryParams url.Values
+ formData url.Values
+ expectedStatus int
+ checkRedirect bool
+ expectedContains []string
+ }{
+ {
+ name: "GET request initiates OAuth flow",
+ method: "GET",
+ queryParams: url.Values{},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback",
+ "response_type=code",
+ "state=",
+ "scope=email",
+ },
+ },
+ {
+ name: "GET request with redirect parameter",
+ method: "GET",
+ queryParams: url.Values{"redirect": {"/app/settings"}},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ },
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ formData: url.Values{"redirect": {"/app/profile"}},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkRedirect: false,
+ },
+ {
+ name: "invalid redirect falls back to default",
+ method: "GET",
+ queryParams: url.Values{"redirect": {"/evil/"}},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ },
+ },
+ {
+ name: "PUT method not allowed",
+ method: "PUT",
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkRedirect: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var req *http.Request
+ reqURL := "/oauth/login"
+
+ if len(tt.queryParams) > 0 {
+ reqURL += "?" + tt.queryParams.Encode()
+ }
+
+ if tt.method == "POST" && tt.formData != nil {
+ req = httptest.NewRequest(tt.method, reqURL, strings.NewReader(tt.formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ } else {
+ req = httptest.NewRequest(tt.method, reqURL, nil)
+ }
+
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.checkRedirect {
+ location := rr.Header().Get("Location")
+ if location == "" {
+ t.Error("expected redirect location to be set")
+ }
+
+ for _, expected := range tt.expectedContains {
+ if !strings.Contains(location, expected) {
+ t.Errorf("expected redirect URL to contain %q, got %q", expected, location)
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestOAuthCallbackHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+ config.Redirects.Default = "/dashboard"
+
+ handler := auth.OAuthCallbackHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ queryParams url.Values
+ expectedStatus int
+ }{
+ {
+ name: "GET callback with invalid state",
+ method: "GET",
+ queryParams: url.Values{"code": {"test-code"}, "state": {"test-state"}},
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ expectedStatus: http.StatusMethodNotAllowed,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ reqURL := "/oauth/callback"
+ if len(tt.queryParams) > 0 {
+ reqURL += "?" + tt.queryParams.Encode()
+ }
+
+ req := httptest.NewRequest(tt.method, reqURL, nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+ })
+ }
+}
+
+func TestGoogleOAuthConfig(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+
+ if config.OAuth2Config.ClientID != "test-client-id" {
+ t.Errorf("expected ClientID 'test-client-id', got %q", config.OAuth2Config.ClientID)
+ }
+ if config.OAuth2Config.ClientSecret != "test-client-secret" {
+ t.Errorf("expected ClientSecret 'test-client-secret', got %q", config.OAuth2Config.ClientSecret)
+ }
+ if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/callback" {
+ t.Errorf("expected RedirectURL 'http://localhost:8080/auth/callback', got %q", config.OAuth2Config.RedirectURL)
+ }
+ if config.OAuth2Config.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/auth" {
+ t.Errorf("expected AuthURL 'https://accounts.google.com/o/oauth2/auth', got %q", config.OAuth2Config.Endpoint.AuthURL)
+ }
+ if config.OAuth2Config.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" {
+ t.Errorf("expected TokenURL 'https://oauth2.googleapis.com/token', got %q", config.OAuth2Config.Endpoint.TokenURL)
+ }
+ if config.UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
+ t.Errorf("expected UserInfoURL 'https://www.googleapis.com/oauth2/v2/userinfo', got %q", config.UserInfoURL)
+ }
+ expectedScopes := []string{"email"}
+ if len(config.OAuth2Config.Scopes) != len(expectedScopes) {
+ t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes))
+ } else {
+ for i, scope := range expectedScopes {
+ if config.OAuth2Config.Scopes[i] != scope {
+ t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i])
+ }
+ }
+ }
+}
+
+func TestGitHubOAuthConfig(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GitHubOAuthConfig(
+ "test-github-client-id",
+ "test-github-client-secret",
+ "http://localhost:8080/auth/github/callback",
+ mockStore,
+ )
+
+ if config.OAuth2Config.ClientID != "test-github-client-id" {
+ t.Errorf("expected ClientID 'test-github-client-id', got %q", config.OAuth2Config.ClientID)
+ }
+ if config.OAuth2Config.ClientSecret != "test-github-client-secret" {
+ t.Errorf("expected ClientSecret 'test-github-client-secret', got %q", config.OAuth2Config.ClientSecret)
+ }
+ if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/github/callback" {
+ t.Errorf("expected RedirectURL 'http://localhost:8080/auth/github/callback', got %q", config.OAuth2Config.RedirectURL)
+ }
+ if config.OAuth2Config.Endpoint.AuthURL != "https://github.com/login/oauth/authorize" {
+ t.Errorf("expected AuthURL 'https://github.com/login/oauth/authorize', got %q", config.OAuth2Config.Endpoint.AuthURL)
+ }
+ if config.OAuth2Config.Endpoint.TokenURL != "https://github.com/login/oauth/access_token" {
+ t.Errorf("expected TokenURL 'https://github.com/login/oauth/access_token', got %q", config.OAuth2Config.Endpoint.TokenURL)
+ }
+ if config.UserInfoURL != "https://api.github.com/user" {
+ t.Errorf("expected UserInfoURL 'https://api.github.com/user', got %q", config.UserInfoURL)
+ }
+ expectedScopes := []string{"user:email"}
+ if len(config.OAuth2Config.Scopes) != len(expectedScopes) {
+ t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes))
+ } else {
+ for i, scope := range expectedScopes {
+ if config.OAuth2Config.Scopes[i] != scope {
+ t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i])
+ }
+ }
+ }
+}
+
+func TestOAuthStateStore(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ now := time.Now().Truncate(time.Second)
+ state := OAuthState{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ }
+
+ // Test storing state
+ stateID, err := mockStore.StoreState(state)
+ if err != nil {
+ t.Errorf("StoreState() error = %v", err)
+ return
+ }
+
+ if stateID == "" {
+ t.Error("StoreState() should return non-empty state ID")
+ }
+
+ // Test retrieving state
+ retrievedState, err := mockStore.GetAndClearState(stateID)
+ if err != nil {
+ t.Errorf("GetAndClearState() error = %v", err)
+ return
+ }
+
+ if retrievedState == nil {
+ t.Error("GetAndClearState() should return state data")
+ return
+ }
+
+ if retrievedState.Redirect != state.Redirect {
+ t.Errorf("retrieved Redirect = %q, want %q", retrievedState.Redirect, state.Redirect)
+ }
+ if !retrievedState.ExpiresAt.Equal(state.ExpiresAt) {
+ t.Errorf("retrieved ExpiresAt = %v, want %v", retrievedState.ExpiresAt, state.ExpiresAt)
+ }
+
+ // Test that state is cleared (one-time use)
+ clearedState, err := mockStore.GetAndClearState(stateID)
+ if err != nil {
+ t.Errorf("GetAndClearState() on cleared state error = %v", err)
+ }
+ if clearedState != nil {
+ t.Error("GetAndClearState() should return nil for already used state")
+ }
+}
diff --git a/password.go b/password.go
new file mode 100644
index 0000000..d1a5c54
--- /dev/null
+++ b/password.go
@@ -0,0 +1,131 @@
+package kate
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "strings"
+
+ "golang.org/x/crypto/argon2"
+)
+
+// Argon2Config configures Argon2id password hashing parameters.
+//
+// Leave any fields at their zero value to use the default.
+type Argon2Config struct {
+ // Time is the number of iterations (default: 1)
+ Time uint32
+ // Memory is the memory usage in KiB (default: 64*1024 = 64MB)
+ Memory uint32
+ // Threads is the number of parallel threads (default: 4)
+ Threads uint8
+ // KeyLen is the length of the derived key in bytes (default: 32)
+ KeyLen uint32
+}
+
+var defaultArgon2Config = Argon2Config{
+ Time: 1,
+ Memory: 64 * 1024,
+ Threads: 4,
+ KeyLen: 32,
+}
+
+// HashPassword hashes a password using Argon2id.
+//
+// Returns a PHC-formatted string containing all parameters needed for verification.
+// If config is nil, secure defaults are used.
+func HashPassword(password string, config *Argon2Config) (string, error) {
+ // Use defaults if no config provided
+ if config == nil {
+ config = &defaultArgon2Config
+ }
+
+ salt := make([]byte, 16)
+ if _, err := rand.Read(salt); err != nil {
+ return "", fmt.Errorf("HashPassword: %w", err)
+ }
+
+ // Use defaults for zero values in config
+ time := config.Time
+ if time == 0 {
+ time = defaultArgon2Config.Time
+ }
+ memory := config.Memory
+ if memory == 0 {
+ memory = defaultArgon2Config.Memory
+ }
+ threads := config.Threads
+ if threads == 0 {
+ threads = defaultArgon2Config.Threads
+ }
+ keyLen := config.KeyLen
+ if keyLen == 0 {
+ keyLen = defaultArgon2Config.KeyLen
+ }
+
+ hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen)
+
+ return buildPHC(salt, hash, memory, time, threads), nil
+}
+
+// ComparePassword verifies a password against a stored hash in PHC format.
+//
+// Returns (false, nil) for incorrect passwords and (false, error) for malformed hashes.
+// Uses constant-time comparison to prevent timing attacks.
+func ComparePassword(loginpass, storedhash string) (bool, error) {
+ salt, hash, memory, time, threads, err := parsePHC(storedhash)
+ if err != nil {
+ return false, fmt.Errorf("ComparePassword: %w", err)
+ }
+
+ loginHash := argon2.IDKey([]byte(loginpass), salt, time, memory, threads, uint32(len(hash)))
+
+ return subtle.ConstantTimeCompare(loginHash, hash) == 1, nil
+}
+
+func buildPHC(salt, hash []byte, memory, time uint32, threads uint8) string {
+ b64Salt := base64.RawStdEncoding.EncodeToString(salt)
+ b64Hash := base64.RawStdEncoding.EncodeToString(hash)
+ return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, memory, time, threads, b64Salt, b64Hash)
+}
+
+func parsePHC(phc string) (salt, hash []byte, memory, time uint32, threads uint8, err error) {
+ parts := strings.Split(phc, "$")
+ if len(parts) != 6 {
+ err = errors.New("invalid PHC format")
+ return
+ }
+
+ if parts[1] != "argon2id" {
+ err = errors.New("unsupported algorithm")
+ return
+ }
+
+ var version int
+ if _, err = fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
+ return
+ }
+
+ if version > argon2.Version {
+ err = errors.New("unsupported argon2 version")
+ return
+ }
+
+ var p uint32
+ if _, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &p); err != nil {
+ return
+ }
+ threads = uint8(p)
+
+ if salt, err = base64.RawStdEncoding.DecodeString(parts[4]); err != nil {
+ return
+ }
+
+ if hash, err = base64.RawStdEncoding.DecodeString(parts[5]); err != nil {
+ return
+ }
+
+ return
+}
diff --git a/password_login.go b/password_login.go
new file mode 100644
index 0000000..837d9c9
--- /dev/null
+++ b/password_login.go
@@ -0,0 +1,105 @@
+package kate
+
+import "net/http"
+
+// PasswordLoginConfig configures the password login handler behavior.
+type PasswordLoginConfig[T any] struct {
+ // UserData provides user data lookup and password hash extraction
+ UserData PasswordUserDataStore[T]
+
+ // Redirects configures post-authentication redirect behavior
+ Redirects Redirects
+
+ // UsernameField is the form field name for username
+ UsernameField string
+
+ // PasswordField is the form field name for password
+ PasswordField string
+
+ // LogError is called when the login handler encounters unexpected errors
+ LogError func(error)
+}
+
+// PasswordUserDataStore provides user data lookup for password authentication.
+type PasswordUserDataStore[T any] interface {
+ // Fetch retrieves user data by username.
+ //
+ // Returns the user data, whether the user was found, and any error.
+ // If the user is not found, should return (zero value, false, nil).
+ Fetch(username string) (T, bool, error)
+
+ // GetPassHash extracts the password hash from user data.
+ //
+ // Returns the stored password hash for comparison with the provided password.
+ GetPassHash(userData T) string
+}
+
+func (lc *PasswordLoginConfig[T]) setDefaults() {
+ if lc.UsernameField == "" {
+ lc.UsernameField = "username"
+ }
+ if lc.PasswordField == "" {
+ lc.PasswordField = "password"
+ }
+}
+
+func (lc PasswordLoginConfig[T]) logError(err error) {
+ if lc.LogError != nil {
+ lc.LogError(err)
+ }
+}
+
+// PasswordLoginHandler returns an HTTP handler that processes password login form submissions.
+func (a Auth[T]) PasswordLoginHandler(config PasswordLoginConfig[T]) http.Handler {
+ config.setDefaults()
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid form data", http.StatusBadRequest)
+ return
+ }
+
+ username := r.PostForm.Get(config.UsernameField)
+ password := r.PostForm.Get(config.PasswordField)
+
+ if username == "" || password == "" {
+ http.Error(w, "Username and password required", http.StatusBadRequest)
+ return
+ }
+
+ userData, ok, err := config.UserData.Fetch(username)
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Error fetching user", http.StatusInternalServerError)
+ return
+ }
+ if !ok {
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+
+ match, err := ComparePassword(password, config.UserData.GetPassHash(userData))
+ if err != nil {
+ config.logError(err)
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+
+ if !match {
+ http.Error(w, "Authentication failed", http.StatusUnauthorized)
+ return
+ }
+
+ if err := a.Set(w, userData); err != nil {
+ config.logError(err)
+ http.Error(w, "Failed to set authentication", http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, config.Redirects.target(r), http.StatusSeeOther)
+ })
+}
diff --git a/password_login_test.go b/password_login_test.go
new file mode 100644
index 0000000..eec1ba7
--- /dev/null
+++ b/password_login_test.go
@@ -0,0 +1,385 @@
+package kate
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+// Mock implementation of PasswordUserDataStore for testing
+type mockPasswordUserDataStore[T any] struct {
+ users map[string]T
+ getHash func(T) string
+}
+
+func newMockPasswordUserDataStore[T any](users map[string]T, getHash func(T) string) *mockPasswordUserDataStore[T] {
+ return &mockPasswordUserDataStore[T]{
+ users: users,
+ getHash: getHash,
+ }
+}
+
+func (m *mockPasswordUserDataStore[T]) Fetch(username string) (T, bool, error) {
+ user, exists := m.users[username]
+ return user, exists, nil
+}
+
+func (m *mockPasswordUserDataStore[T]) GetPassHash(userData T) string {
+ return m.getHash(userData)
+}
+
+func TestPasswordLoginHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ testHash, err := HashPassword("password123", nil)
+ if err != nil {
+ t.Fatalf("Failed to hash password: %v", err)
+ }
+
+ users := map[string]testUser{
+ "john": {Username: "john", Hash: testHash, ID: 1},
+ "jane": {Username: "jane", Hash: testHash, ID: 2},
+ }
+
+ mockStore := newMockPasswordUserDataStore(users, func(user testUser) string {
+ return user.Hash
+ })
+
+ config := PasswordLoginConfig[testUser]{
+ UserData: mockStore,
+ Redirects: Redirects{
+ Default: "/dashboard",
+ AllowedPrefixes: []string{"/app/", "/admin/"},
+ FieldName: "redirect",
+ },
+ }
+
+ handler := auth.PasswordLoginHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ formData url.Values
+ expectedStatus int
+ expectedRedirect string
+ checkCookie bool
+ }{
+ {
+ name: "successful login",
+ method: "POST",
+ formData: url.Values{"username": {"john"}, "password": {"password123"}},
+ expectedStatus: http.StatusSeeOther,
+ expectedRedirect: "/dashboard",
+ checkCookie: true,
+ },
+ {
+ name: "successful login with redirect",
+ method: "POST",
+ formData: url.Values{"username": {"john"}, "password": {"password123"}, "redirect": {"/app/settings"}},
+ expectedStatus: http.StatusSeeOther,
+ expectedRedirect: "/app/settings",
+ checkCookie: true,
+ },
+ {
+ name: "invalid redirect falls back to default",
+ method: "POST",
+ formData: url.Values{"username": {"john"}, "password": {"password123"}, "redirect": {"/evil/"}},
+ expectedStatus: http.StatusSeeOther,
+ expectedRedirect: "/dashboard",
+ checkCookie: true,
+ },
+ {
+ name: "wrong password",
+ method: "POST",
+ formData: url.Values{"username": {"john"}, "password": {"wrongpass"}},
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "nonexistent user",
+ method: "POST",
+ formData: url.Values{"username": {"nobody"}, "password": {"password123"}},
+ expectedStatus: http.StatusUnauthorized,
+ checkCookie: false,
+ },
+ {
+ name: "missing username",
+ method: "POST",
+ formData: url.Values{"password": {"password123"}},
+ expectedStatus: http.StatusBadRequest,
+ checkCookie: false,
+ },
+ {
+ name: "missing password",
+ method: "POST",
+ formData: url.Values{"username": {"john"}},
+ expectedStatus: http.StatusBadRequest,
+ checkCookie: false,
+ },
+ {
+ name: "GET method not allowed",
+ method: "GET",
+ formData: url.Values{},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkCookie: false,
+ },
+ {
+ name: "PUT method not allowed",
+ method: "PUT",
+ formData: url.Values{},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkCookie: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var req *http.Request
+
+ if tt.method == "POST" {
+ req = httptest.NewRequest(tt.method, "/login", strings.NewReader(tt.formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ } else {
+ req = httptest.NewRequest(tt.method, "/login", nil)
+ }
+
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.expectedRedirect != "" {
+ location := rr.Header().Get("Location")
+ if location != tt.expectedRedirect {
+ t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location)
+ }
+ }
+
+ if tt.checkCookie {
+ cookies := rr.Result().Cookies()
+ found := false
+ for _, cookie := range cookies {
+ if cookie.Name == "test_session" && cookie.Value != "" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected authentication cookie to be set")
+ }
+ }
+ })
+ }
+}
+
+func TestPasswordLoginHandlerCustomFields(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ testHash, err := HashPassword("secret", nil)
+ if err != nil {
+ t.Fatalf("Failed to hash password: %v", err)
+ }
+
+ users := map[string]testUser{
+ "admin": {Username: "admin", Hash: testHash, ID: 99},
+ }
+
+ mockStore := newMockPasswordUserDataStore(users, func(user testUser) string {
+ return user.Hash
+ })
+
+ config := PasswordLoginConfig[testUser]{
+ UserData: mockStore,
+ UsernameField: "email",
+ PasswordField: "pass",
+ Redirects: Redirects{
+ FieldName: "next",
+ Default: "/home",
+ },
+ }
+
+ handler := auth.PasswordLoginHandler(config)
+
+ formData := url.Values{
+ "email": {"admin"},
+ "pass": {"secret"},
+ "next": {"/custom"},
+ }
+
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rr := httptest.NewRecorder()
+
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusSeeOther {
+ t.Errorf("expected status %d, got %d", http.StatusSeeOther, rr.Code)
+ }
+
+ location := rr.Header().Get("Location")
+ if location != "/custom" {
+ t.Errorf("expected redirect to /custom, got %s", location)
+ }
+}
+
+func TestPasswordLoginHandlerParseFormError(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "test": {Username: "test", Hash: "", ID: 1},
+ }
+
+ mockStore := newMockPasswordUserDataStore(users, func(user testUser) string {
+ return user.Hash
+ })
+
+ config := PasswordLoginConfig[testUser]{
+ UserData: mockStore,
+ }
+
+ handler := auth.PasswordLoginHandler(config)
+
+ req := httptest.NewRequest("POST", "/login", strings.NewReader("username=%"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rr := httptest.NewRecorder()
+
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code)
+ }
+}
+
+func TestPasswordLoginHandlerGetPassHashError(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "user": {Username: "user", Hash: "", ID: 1},
+ }
+
+ mockStore := newMockPasswordUserDataStore(users, func(user testUser) string {
+ return user.Hash // Empty hash will cause password comparison to fail
+ })
+
+ config := PasswordLoginConfig[testUser]{
+ UserData: mockStore,
+ }
+
+ handler := auth.PasswordLoginHandler(config)
+
+ formData := url.Values{"username": {"user"}, "password": {"pass"}}
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rr := httptest.NewRecorder()
+
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusUnauthorized {
+ t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
+ }
+}
+
+func TestPasswordLoginHandlerMalformedHash(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ users := map[string]testUser{
+ "user": {Username: "user", Hash: "invalid-hash", ID: 1},
+ }
+
+ mockStore := newMockPasswordUserDataStore(users, func(user testUser) string {
+ return user.Hash
+ })
+
+ config := PasswordLoginConfig[testUser]{
+ UserData: mockStore,
+ }
+
+ handler := auth.PasswordLoginHandler(config)
+
+ formData := url.Values{"username": {"user"}, "password": {"pass"}}
+ req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rr := httptest.NewRecorder()
+
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusUnauthorized {
+ t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
+ }
+}
+
+func TestRedirectsIsValid(t *testing.T) {
+ tests := []struct {
+ name string
+ target string
+ allowedPrefixes []string
+ expected bool
+ }{
+ {
+ name: "no restrictions allows anything",
+ target: "/anything",
+ allowedPrefixes: []string{},
+ expected: true,
+ },
+ {
+ name: "valid prefix match",
+ target: "/app/dashboard",
+ allowedPrefixes: []string{"/app/", "/admin/"},
+ expected: true,
+ },
+ {
+ name: "invalid prefix",
+ target: "/evil/",
+ allowedPrefixes: []string{"/app/", "/admin/"},
+ expected: false,
+ },
+ {
+ name: "absolute URL rejected",
+ target: "https://evil.com/",
+ allowedPrefixes: []string{"/app/"},
+ expected: false,
+ },
+ {
+ name: "protocol relative URL rejected",
+ target: "//evil.com/",
+ allowedPrefixes: []string{"/app/"},
+ expected: false,
+ },
+ {
+ name: "invalid URL rejected",
+ target: ":/invalid",
+ allowedPrefixes: []string{"/app/"},
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ redirects := Redirects{AllowedPrefixes: tt.allowedPrefixes}
+ result := redirects.isValid(tt.target)
+ if result != tt.expected {
+ t.Errorf("Redirects{AllowedPrefixes: %v}.isValid(%q) = %v, want %v", tt.allowedPrefixes, tt.target, result, tt.expected)
+ }
+ })
+ }
+}
+
diff --git a/password_test.go b/password_test.go
new file mode 100644
index 0000000..cfe90e0
--- /dev/null
+++ b/password_test.go
@@ -0,0 +1,230 @@
+package kate
+
+import (
+ "strings"
+ "testing"
+ "testing/quick"
+)
+
+func TestHashPasswordComparePasswordRoundtrip(t *testing.T) {
+ property := func(password string) bool {
+ hash, err := HashPassword(password, nil)
+ if err != nil {
+ return false
+ }
+
+ match, err := ComparePassword(password, hash)
+ if err != nil {
+ return false
+ }
+
+ return match
+ }
+
+ if err := quick.Check(property, nil); err != nil {
+ t.Errorf("roundtrip property failed: %v", err)
+ }
+}
+
+func TestComparePasswordWrongPassword(t *testing.T) {
+ property := func(password, wrongPassword string) bool {
+ if password == wrongPassword {
+ return true
+ }
+
+ hash, err := HashPassword(password, nil)
+ if err != nil {
+ return false
+ }
+
+ match, err := ComparePassword(wrongPassword, hash)
+ if err != nil {
+ return false
+ }
+
+ return !match
+ }
+
+ if err := quick.Check(property, nil); err != nil {
+ t.Errorf("wrong password property failed: %v", err)
+ }
+}
+
+func TestHashPasswordUniqueness(t *testing.T) {
+ property := func(password string) bool {
+ hash1, err1 := HashPassword(password, nil)
+ hash2, err2 := HashPassword(password, nil)
+
+ if err1 != nil || err2 != nil {
+ return false
+ }
+
+ return hash1 != hash2
+ }
+
+ if err := quick.Check(property, nil); err != nil {
+ t.Errorf("uniqueness property failed: %v", err)
+ }
+}
+
+func TestHashPasswordFormat(t *testing.T) {
+ property := func(password string) bool {
+ hash, err := HashPassword(password, nil)
+ if err != nil {
+ return false
+ }
+
+ parts := strings.Split(hash, "$")
+ if len(parts) != 6 {
+ return false
+ }
+
+ return parts[0] == "" && parts[1] == "argon2id"
+ }
+
+ if err := quick.Check(property, nil); err != nil {
+ t.Errorf("format property failed: %v", err)
+ }
+}
+
+func TestComparePasswordMalformedHash(t *testing.T) {
+ tests := []string{
+ "",
+ "invalid",
+ "$argon2id",
+ "$argon2id$v=19$m=65536,t=1,p=4",
+ "$argon2id$v=19$m=65536,t=1,p=4$salt",
+ "$wrong$v=19$m=65536,t=1,p=4$salt$hash",
+ "$argon2id$v=999$m=65536,t=1,p=4$salt$hash",
+ "$argon2id$v=19$invalid$salt$hash",
+ "$argon2id$v=19$m=65536,t=1,p=4$!!!$hash",
+ "$argon2id$v=19$m=65536,t=1,p=4$salt$!!!",
+ }
+
+ for _, malformedHash := range tests {
+ t.Run("malformed_"+malformedHash, func(t *testing.T) {
+ match, err := ComparePassword("password", malformedHash)
+ if err == nil {
+ t.Error("expected error for malformed hash")
+ }
+ if match {
+ t.Error("should not match with malformed hash")
+ }
+ })
+ }
+}
+
+func TestEmptyPassword(t *testing.T) {
+ hash, err := HashPassword("", nil)
+ if err != nil {
+ t.Fatalf("HashPassword failed for empty password: %v", err)
+ }
+
+ match, err := ComparePassword("", hash)
+ if err != nil {
+ t.Fatalf("ComparePassword failed: %v", err)
+ }
+ if !match {
+ t.Error("empty password should match its hash")
+ }
+
+ match, err = ComparePassword("nonempty", hash)
+ if err != nil {
+ t.Fatalf("ComparePassword failed: %v", err)
+ }
+ if match {
+ t.Error("non-empty password should not match empty password hash")
+ }
+}
+
+func TestHashPasswordWithConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *Argon2Config
+ want string // Expected parameters in PHC format
+ }{
+ {
+ name: "nil config uses defaults",
+ config: nil,
+ want: "$argon2id$v=19$m=65536,t=1,p=4$",
+ },
+ {
+ name: "custom time parameter",
+ config: &Argon2Config{
+ Time: 3,
+ },
+ want: "$argon2id$v=19$m=65536,t=3,p=4$",
+ },
+ {
+ name: "custom memory parameter",
+ config: &Argon2Config{
+ Memory: 32 * 1024, // 32MB
+ },
+ want: "$argon2id$v=19$m=32768,t=1,p=4$",
+ },
+ {
+ name: "custom threads parameter",
+ config: &Argon2Config{
+ Threads: 2,
+ },
+ want: "$argon2id$v=19$m=65536,t=1,p=2$",
+ },
+ {
+ name: "all custom parameters",
+ config: &Argon2Config{
+ Time: 2,
+ Memory: 128 * 1024,
+ Threads: 8,
+ KeyLen: 64,
+ },
+ want: "$argon2id$v=19$m=131072,t=2,p=8$",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hash, err := HashPassword("test", tt.config)
+ if err != nil {
+ t.Fatalf("HashPassword() error = %v", err)
+ }
+
+ // Check that the hash starts with expected PHC format parameters
+ if !strings.HasPrefix(hash, tt.want) {
+ t.Errorf("HashPassword() = %v, want prefix %v", hash, tt.want)
+ }
+
+ // Verify password still works
+ match, err := ComparePassword("test", hash)
+ if err != nil {
+ t.Fatalf("ComparePassword() error = %v", err)
+ }
+ if !match {
+ t.Error("Password should match its hash")
+ }
+
+ // Verify wrong password doesn't match
+ match, err = ComparePassword("wrong", hash)
+ if err != nil {
+ t.Fatalf("ComparePassword() error = %v", err)
+ }
+ if match {
+ t.Error("Wrong password should not match hash")
+ }
+ })
+ }
+}
+
+func TestDefaultArgon2Config(t *testing.T) {
+ if defaultArgon2Config.Time != 1 {
+ t.Errorf("Expected defaultArgon2Config.Time=1, got %d", defaultArgon2Config.Time)
+ }
+ if defaultArgon2Config.Memory != 64*1024 {
+ t.Errorf("Expected defaultArgon2Config.Memory=65536, got %d", defaultArgon2Config.Memory)
+ }
+ if defaultArgon2Config.Threads != 4 {
+ t.Errorf("Expected defaultArgon2Config.Threads=4, got %d", defaultArgon2Config.Threads)
+ }
+ if defaultArgon2Config.KeyLen != 32 {
+ t.Errorf("Expected defaultArgon2Config.KeyLen=32, got %d", defaultArgon2Config.KeyLen)
+ }
+}