From 639ad6a02cbb4b713434671ec09f309aa5410921 Mon Sep 17 00:00:00 2001 From: T Date: Thu, 26 Jun 2025 11:42:17 -0600 Subject: Create authentic_kate: user authentication for go HTTP applications --- README.md | 203 +++++++++++++++++++ auth.go | 169 ++++++++++++++++ auth_test.go | 391 +++++++++++++++++++++++++++++++++++++ compat-test/README.md | 44 +++++ compat-test/compat_test.go | 167 ++++++++++++++++ compat-test/go.mod | 17 ++ compat-test/go.sum | 10 + doc.go | 150 ++++++++++++++ encryption.go | 136 +++++++++++++ encryption_test.go | 151 +++++++++++++++ go.mod | 13 ++ go.sum | 8 + login_helpers.go | 64 ++++++ magic_link_login.go | 233 ++++++++++++++++++++++ magic_link_login_test.go | 473 +++++++++++++++++++++++++++++++++++++++++++++ oauth_login.go | 257 ++++++++++++++++++++++++ oauth_login_test.go | 369 +++++++++++++++++++++++++++++++++++ password.go | 131 +++++++++++++ password_login.go | 105 ++++++++++ password_login_test.go | 385 ++++++++++++++++++++++++++++++++++++ password_test.go | 230 ++++++++++++++++++++++ 21 files changed, 3706 insertions(+) create mode 100644 README.md create mode 100644 auth.go create mode 100644 auth_test.go create mode 100644 compat-test/README.md create mode 100644 compat-test/compat_test.go create mode 100644 compat-test/go.mod create mode 100644 compat-test/go.sum create mode 100644 doc.go create mode 100644 encryption.go create mode 100644 encryption_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 login_helpers.go create mode 100644 magic_link_login.go create mode 100644 magic_link_login_test.go create mode 100644 oauth_login.go create mode 100644 oauth_login_test.go create mode 100644 password.go create mode 100644 password_login.go create mode 100644 password_login_test.go create mode 100644 password_test.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..d5d7882 --- /dev/null +++ b/README.md @@ -0,0 +1,203 @@ +# authentic_kate + +Package kate provides secure cookie-based authentication for HTTP applications. + +This package implements encrypted cookies with authenticated encryption. +It supports generic types to allow storage of any serializable data in encrypted HTTP cookies. + +## Key features + +- Type-safe authentication data storage using Go generics +- Authenticated encryption for secure cookie storage +- HTTP middlewares for required and optional authentication +- Ready-to-use login handlers for password, oauth, and magic link authentication +- Configurable cookie properties (domain, path, HTTPS-only, max age) +- Secure password hashing with industry-standard algorithms + +## Example usage + +```go +type UserData struct { + ID int + Name string +} + +// Implement SerDes interface for UserData +type UserSerDes struct{} +func (s UserSerDes) Serialize(w io.Writer, data UserData) error { /* ... */ } +func (s UserSerDes) Deserialize(r io.Reader, data *UserData) error { /* ... */ } + +// Create authentication instance +auth := kate.New("hex-encoded-32-byte-key", kate.AuthConfig[UserData]{ + SerDes: UserSerDes{}, + CookieName: "session", + HTTPSOnly: true, + MaxAge: 24 * time.Hour, +}) + +// Use as middleware +http.Handle("GET /protected", auth.Required(protectedHandler)) + +// Set authentication cookie +user := UserData{ID: 123, Name: "John"} +auth.Set(w, user) + +// Get authenticated data +user, ok := auth.Get(r.Context()) +``` + +## Login handlers + +The package provides ready-to-use HTTP handlers for several authentication methods: + +### Password-based login + +```go +type User struct { + Username string + Hash string + ID int +} + +// Implement PasswordUserDataStore interface +type UserStore struct{} + +func (us UserStore) Fetch(username string) (User, bool, error) { + // Look up user in your database + user, exists := database.FindUser(username) + return user, exists, nil +} + +func (us UserStore) GetPassHash(user User) string { + return user.Hash +} + +// Configure password login +passwordConfig := kate.PasswordLoginConfig[User]{ + UserData: UserStore{}, + Redirects: kate.Redirects{ + Default: "/dashboard", + AllowedPrefixes: []string{"/app/", "/admin/"}, + FieldName: "redirect", + }, +} + +// Create handler +http.Handle("POST /login", auth.PasswordLoginHandler(passwordConfig)) +``` + +### Magic link authentication + +```go +// Implement MagicLinkMailer interface +type UserMailer struct{} + +func (um UserMailer) Fetch(email string) (User, bool, error) { + // Look up user by email + user, exists := database.FindUserByEmail(email) + return user, exists, nil +} + +func (um UserMailer) SendEmail(user User, token string) error { + // Send magic link email with the token + return emailService.SendMagicLink(user.Email, token) +} + +// Configure magic link login +magicConfig := kate.MagicLinkConfig[User]{ + Mailer: UserMailer{}, + Redirects: kate.Redirects{ Default: "/dashboard" }, + TokenExpiry: 15 * time.Minute, + TokenLocation: kate.TokenLocationQuery, // or TokenLocationPath +} + +// Create handlers +http.Handle("POST /magic-link", auth.MagicLinkLoginHandler(magicConfig)) +http.Handle("GET /verify", auth.MagicLinkVerifyHandler(magicConfig)) +``` + +### OAuth2 authentication + +```go +// Implement the OAuthDataStore interface +type UserDataStore struct{} + +func (ds UserDataStore) GetOrCreateUser(email string) (User, error) { + // Look up user by email, create if doesn't exist + if user, exists := database.FindUserByEmail(email); exists { + return user, nil + } + return database.CreateUser(email) +} + +func (ds UserDataStore) StoreState(state kate.OAuthState) (string, error) { + // Generate unique state ID and store temporarily + id := generateUniqueID() + return id, cache.Set(id, state, 10*time.Minute) +} + +func (ds UserDataStore) GetAndClearState(id string) (*kate.OAuthState, error) { + // Retrieve and delete state data + var state kate.OAuthState + exists, err := cache.GetAndDelete(id, &state) + if !exists { + return nil, err + } + return &state, err +} + +// Configure Google OAuth2 login +oauthConfig := kate.GoogleOAuthConfig[User]( + "your-google-client-id", + "your-google-client-secret", + "https://yourapp.com/auth/callback", + UserDataStore{}, +) + +// Customize redirects and security settings +oauthConfig.Redirects = kate.Redirects{ + Default: "/dashboard", + AllowedPrefixes: []string{"/app/", "/admin/"}, + FieldName: "redirect", +} + +// Create handlers +http.Handle("GET /auth/login", auth.OAuthLoginHandler(oauthConfig)) +http.Handle("GET /auth/callback", auth.OAuthCallbackHandler(oauthConfig)) + +// GitHub OAuth2 is also supported +githubConfig := kate.GitHubOAuthConfig[User]( + "your-github-client-id", + "your-github-client-secret", + "https://yourapp.com/auth/github/callback", + UserDataStore{}, +) +``` + +## Password hashing + +The package provides secure password hashing functions using Argon2id: + +```go +// Hash a password with secure defaults +hash, err := kate.HashPassword("user-password", nil) +if err != nil { + // Handle error +} + +// Verify a password +match, err := kate.ComparePassword("user-password", hash) +if err != nil { + // Handle malformed hash error +} +if match { + // Password is correct +} +``` + +## Cryptographic algorithms + +This package uses the following cryptographic algorithms: + +- **Token encryption**: XSalsa20 stream cipher with Poly1305 MAC for authenticated encryption (compatible with libsodium secretbox) +- **Password hashing**: Argon2id with secure default parameters and PHC format storage diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..1066d9f --- /dev/null +++ b/auth.go @@ -0,0 +1,169 @@ +package kate + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "time" +) + +// Auth provides secure cookie-based authentication for HTTP applications. +// +// It uses generic type T to allow storage of any serializable data in encrypted cookies. +type Auth[T any] struct { + enc encryption + config AuthConfig[T] +} + +// AuthConfig holds configuration settings for the Auth instance. +// +// It specifies how data is serialized, cookie properties, and security settings. +type AuthConfig[T any] struct { + // SerDes handles serialization and deserialization of authentication data + SerDes SerDes[T] + + // CookieName is the name of the HTTP cookie used for authentication + CookieName string + + // URLPath restricts the cookie to a specific path on the server (optional) + URLPath string + + // URLDomain restricts the cookie to a specific domain (optional) + URLDomain string + + // HTTPSOnly when true, requires cookies to be sent only over HTTPS connections + HTTPSOnly bool + + // MaxAge determines how long in seconds the authentication cookie remains valid + MaxAge time.Duration +} + +// SerDes defines the interface for serializing and deserializing authentication data. +// +// Implementations must handle conversion between type T and byte streams. +type SerDes[T any] interface { + // Serialize writes the data of type T to the provided writer + Serialize(io.Writer, T) error + + // Deserialize reads data from the reader and populates the provided pointer + Deserialize(io.Reader, *T) error +} + +// New creates a new Auth instance with the given private key and configuration. +// +// The private key must be a hex-encoded string used for cookie encryption. +// Panics if the private key is invalid. +func New[T any](privkey string, config AuthConfig[T]) Auth[T] { + enc, err := encryptionFromHexKey(privkey) + if err != nil { + panic(err.Error()) + } + return Auth[T]{enc: enc, config: config} +} + +// Required is an HTTP middleware that enforces authentication. +// +// It checks for a valid authentication cookie and makes the authenticated data +// available in the request context. Returns 401 Unauthorized if authentication fails. +func (a Auth[T]) Required(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(a.config.CookieName) + if errors.Is(err, http.ErrNoCookie) { + http.Error(w, "Authentication missing", http.StatusUnauthorized) + return + } + + cleartext, ok := a.enc.Decrypt(cookie.Value) + if !ok { + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + + var data T + if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil { + http.Error(w, "Server error", http.StatusInternalServerError) + return + } + + handler.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key, data))) + }) +} + +// Optional returns an HTTP middleware that allows optional authentication. +// +// It checks for a valid authentication cookie and makes the authenticated data +// available in the request context if present. Unlike Required, this middleware +// allows requests to proceed even when authentication is missing or invalid. +// Returns 500 Internal Server Error only if deserialization fails on valid authentication data. +func (a Auth[T]) Optional(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(a.config.CookieName) + if errors.Is(err, http.ErrNoCookie) { + handler.ServeHTTP(w, r) + return + } + + cleartext, ok := a.enc.Decrypt(cookie.Value) + if !ok { + handler.ServeHTTP(w, r) + return + } + + var data T + if err := a.config.SerDes.Deserialize(bytes.NewBuffer(cleartext), &data); err != nil { + http.Error(w, "Server error", http.StatusInternalServerError) + return + } + + handler.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), key, data))) + }) +} + +// Set creates and sets an authentication cookie containing the provided data. +// +// The data is serialized, encrypted, and stored in an HTTP cookie. +// Returns an error if serialization fails. +func (a Auth[T]) Set(w http.ResponseWriter, data T) error { + buf := &bytes.Buffer{} + if err := a.config.SerDes.Serialize(buf, data); err != nil { + return fmt.Errorf("Auth.Set: %w", err) + } + cookie := &http.Cookie{ + Name: a.config.CookieName, + Value: a.enc.Encrypt(buf.Bytes()), + Path: a.config.URLPath, + Domain: a.config.URLDomain, + MaxAge: int(a.config.MaxAge / time.Second), + Secure: a.config.HTTPSOnly, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + w.Header().Add("Set-Cookie", cookie.String()) + return nil +} + +// Get retrieves authentication data from the request context. +// +// Returns the data and true if authentication data is present and valid, +// otherwise returns the zero value and false. Should be called within handlers +// protected by the Required middleware. +func (a Auth[T]) Get(ctx context.Context) (T, bool) { + var zero T + val := ctx.Value(key) + if val == nil { + return zero, false + } + switch v := val.(type) { + case T: + return v, true + default: + return zero, false + } +} + +type keyt struct{} + +var key = keyt{} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..131e132 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,391 @@ +package kate + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +type testData struct { + UserID int + Name string +} + +type testSerDes struct{} + +func (ts testSerDes) Serialize(w io.Writer, data testData) error { + _, err := w.Write([]byte(strings.Join([]string{string(rune(data.UserID)), data.Name}, "|"))) + return err +} + +func (ts testSerDes) Deserialize(r io.Reader, data *testData) error { + buf, err := io.ReadAll(r) + if err != nil { + return err + } + parts := strings.Split(string(buf), "|") + if len(parts) != 2 { + return io.ErrUnexpectedEOF + } + data.UserID = int(rune(parts[0][0])) + data.Name = parts[1] + return nil +} + +func createTestAuth() Auth[testData] { + config := AuthConfig[testData]{ + SerDes: testSerDes{}, + CookieName: "test_auth", + URLPath: "/", + URLDomain: "example.com", + HTTPSOnly: false, + MaxAge: time.Hour, + } + return New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + privkey string + wantErr bool + }{ + { + name: "valid hex key", + privkey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantErr: false, + }, + { + name: "invalid hex key - too short", + privkey: "0123456789abcdef", + wantErr: true, + }, + { + name: "invalid hex key - not hex", + privkey: "invalid_hex_key_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: testSerDes{}, + CookieName: "test", + } + + defer func() { + if r := recover(); r != nil { + if !tt.wantErr { + t.Errorf("New() panicked unexpectedly: %v", r) + } + } else if tt.wantErr { + t.Errorf("New() should have panicked but didn't") + } + }() + + auth := New(tt.privkey, config) + if !tt.wantErr && auth.config.CookieName != "test" { + t.Errorf("New() failed to create auth instance properly") + } + }) + } +} + +func TestAuth_Set(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + w := httptest.NewRecorder() + err := auth.Set(w, data) + if err != nil { + t.Fatalf("Set() error = %v", err) + } + + cookies := w.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("Expected 1 cookie, got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != "test_auth" { + t.Errorf("Expected cookie name 'test_auth', got %s", cookie.Name) + } + if cookie.Path != "/" { + t.Errorf("Expected cookie path '/', got %s", cookie.Path) + } + if cookie.Domain != "example.com" { + t.Errorf("Expected cookie domain 'example.com', got %s", cookie.Domain) + } + if !cookie.HttpOnly { + t.Error("Expected cookie to be HttpOnly") + } + if cookie.SameSite != http.SameSiteLaxMode { + t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite) + } +} + +func TestAuth_Required(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if !ok { + t.Error("Expected to find auth data in context") + return + } + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := auth.Required(handler) + + t.Run("missing cookie", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + }) + + t.Run("invalid cookie", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_token", + }) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", w.Code) + } + }) + + t.Run("valid cookie", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) +} + +func TestAuth_Optional(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authData, ok := auth.Get(r.Context()) + if ok { + if authData.UserID != 123 || authData.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", authData.UserID, authData.Name) + } + } + w.WriteHeader(http.StatusOK) + }) + + optionalHandler := auth.Optional(handler) + + t.Run("missing cookie - should allow", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("invalid cookie - should allow", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{ + Name: "test_auth", + Value: "invalid_token", + }) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("valid cookie", func(t *testing.T) { + setW := httptest.NewRecorder() + if err := auth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) +} + +func TestAuth_Get(t *testing.T) { + auth := createTestAuth() + data := testData{UserID: 123, Name: "Alice"} + + t.Run("no data in context", func(t *testing.T) { + ctx := context.Background() + result, ok := auth.Get(ctx) + if ok { + t.Error("Expected Get to return false for empty context") + } + if result.UserID != 0 || result.Name != "" { + t.Error("Expected zero value for missing data") + } + }) + + t.Run("data in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), key, data) + result, ok := auth.Get(ctx) + if !ok { + t.Error("Expected Get to return true for context with data") + } + if result.UserID != 123 || result.Name != "Alice" { + t.Errorf("Expected UserID=123, Name=Alice, got UserID=%d, Name=%s", result.UserID, result.Name) + } + }) + + t.Run("wrong type in context", func(t *testing.T) { + ctx := context.WithValue(context.Background(), key, "wrong_type") + result, ok := auth.Get(ctx) + if ok { + t.Error("Expected Get to return false for wrong type in context") + } + if result.UserID != 0 || result.Name != "" { + t.Error("Expected zero value for wrong type") + } + }) +} + +type failingSerDes struct { + failSerialize bool + failDeserialize bool +} + +func (f failingSerDes) Serialize(w io.Writer, data testData) error { + if f.failSerialize { + return io.ErrClosedPipe + } + return testSerDes{}.Serialize(w, data) +} + +func (f failingSerDes) Deserialize(r io.Reader, data *testData) error { + if f.failDeserialize { + return io.ErrUnexpectedEOF + } + return testSerDes{}.Deserialize(r, data) +} + +func TestAuth_Set_SerializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failSerialize: true}, + CookieName: "test_auth", + } + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + data := testData{UserID: 123, Name: "Alice"} + w := httptest.NewRecorder() + + err := auth.Set(w, data) + if err == nil { + t.Error("Expected Set to return error for serialization failure") + } +} + +func TestAuth_Required_DeserializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + protectedHandler := badAuth.Required(handler) + + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + protectedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) + } +} + +func TestAuth_Optional_DeserializationError(t *testing.T) { + config := AuthConfig[testData]{ + SerDes: failingSerDes{failDeserialize: true}, + CookieName: "test_auth", + } + goodAuth := createTestAuth() + badAuth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", config) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + optionalHandler := badAuth.Optional(handler) + + data := testData{UserID: 123, Name: "Alice"} + setW := httptest.NewRecorder() + if err := goodAuth.Set(setW, data); err != nil { + t.Fatalf("Set() error = %v", err) + } + cookie := setW.Result().Cookies()[0] + + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(cookie) + w := httptest.NewRecorder() + + optionalHandler.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500 for deserialization error, got %d", w.Code) + } +} diff --git a/compat-test/README.md b/compat-test/README.md new file mode 100644 index 0000000..c1e91b8 --- /dev/null +++ b/compat-test/README.md @@ -0,0 +1,44 @@ +# Kate libsodium Compatibility Tests + +This directory contains tests to verify that the `kate` package's `seal` and `open` functions are compatible with the real libsodium implementation. + +## Why? + +The `kate` library is designed to be CGO-free to avoid the complexity and deployment challenges that come with C dependencies. Rather than requiring applications to link against libsodium, `kate` implements a pure Go version of the NaCl secretbox encryption scheme. + +This separate compatibility test project verifies that `kate`'s implementation produces ciphertext that is fully interoperable with libsodium while keeping the CGO dependency isolated from the main package. This way, applications get the benefits of a pure Go library while maintaining confidence in libsodium compatibility. + +## Setup + +The tests require libsodium to be installed on your system. The project uses the `github.com/jamesruan/sodium` Go bindings for libsodium. + +### Installing libsodium + +On macOS: +```bash +brew install libsodium +``` + +On Ubuntu/Debian: +```bash +sudo apt-get install libsodium-dev +``` + +On CentOS/RHEL: +```bash +sudo yum install libsodium-devel +``` + +## Running the Tests + +```bash +cd compat-test +go test +``` + +## Test Cases + +The test suite includes: + +1. **KateEncrypt_LibsodiumDecrypt**: Encrypts data with kate's library and decrypts with libsodium +2. **LibsodiumEncrypt_KateDecrypt**: Encrypts data with libsodium and decrypts with kate's library diff --git a/compat-test/compat_test.go b/compat-test/compat_test.go new file mode 100644 index 0000000..2b0bfd7 --- /dev/null +++ b/compat-test/compat_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "crypto/rand" + "encoding/hex" + "io" + "net/http" + "net/http/httptest" + "testing" + "testing/quick" + + "git.tjp.lol/authentic_kate" + "github.com/jamesruan/sodium" +) + +type ByteSerDes struct{} + +func (ByteSerDes) Serialize(w io.Writer, data []byte) error { + _, err := w.Write(data) + return err +} + +func (ByteSerDes) Deserialize(r io.Reader, data *[]byte) error { + buf, err := io.ReadAll(r) + if err != nil { + return err + } + *data = buf + return nil +} + +func kateSeal(key [32]byte, plaintext []byte) []byte { + keyHex := hex.EncodeToString(key[:]) + + auth := kate.New(keyHex, kate.AuthConfig[[]byte]{ + SerDes: ByteSerDes{}, + CookieName: "test", + }) + + w := httptest.NewRecorder() + if err := auth.Set(w, plaintext); err != nil { + panic(err) + } + + cookies := w.Result().Cookies() + if len(cookies) == 0 { + panic("No cookie set") + } + + encryptedBytes, err := hex.DecodeString(cookies[0].Value) + if err != nil { + panic(err) + } + + return encryptedBytes +} + +func libsodiumSeal(key [32]byte, plaintext []byte) []byte { + var nonce [24]byte + if _, err := rand.Read(nonce[:]); err != nil { + panic(err) + } + + ciphertextAndMac := sodium.Bytes(plaintext).SecretBox( + sodium.SecretBoxNonce{Bytes: nonce[:]}, + sodium.SecretBoxKey{Bytes: key[:]}, + ) + + result := make([]byte, 24+len(ciphertextAndMac)) + copy(result[:24], nonce[:]) + copy(result[24:], ciphertextAndMac) + + return result +} + +func kateOpen(key [32]byte, box []byte) ([]byte, bool) { + keyHex := hex.EncodeToString(key[:]) + + auth := kate.New(keyHex, kate.AuthConfig[[]byte]{ + SerDes: ByteSerDes{}, + CookieName: "test", + }) + + cookieValue := hex.EncodeToString(box) + req := httptest.NewRequest("GET", "/", nil) + req.AddCookie(&http.Cookie{Name: "test", Value: cookieValue}) + + var decryptedData []byte + var success bool + handler := auth.Optional(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, ok := auth.Get(r.Context()) + if ok { + decryptedData = data + success = true + } + })) + + testW := httptest.NewRecorder() + handler.ServeHTTP(testW, req) + + return decryptedData, success +} + +func libsodiumOpen(key [32]byte, box []byte) ([]byte, bool) { + if len(box) < 24 { + return nil, false + } + + nonce := box[:24] + ciphertextAndMac := box[24:] + + decrypted, err := sodium.Bytes(ciphertextAndMac).SecretBoxOpen( + sodium.SecretBoxNonce{Bytes: nonce}, + sodium.SecretBoxKey{Bytes: key[:]}, + ) + if err != nil { + return nil, false + } + + return decrypted, true +} + +func TestCrossLibraryCompatibility(t *testing.T) { + t.Run("KateEncrypt_LibsodiumDecrypt", func(t *testing.T) { + f := func(data []byte) bool { + if len(data) == 0 { + return true + } + + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + return false + } + + kateBox := kateSeal(key, data) + decrypted, ok := libsodiumOpen(key, kateBox) + + return ok && string(decrypted) == string(data) + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Property test failed: %v", err) + } + }) + + t.Run("LibsodiumEncrypt_KateDecrypt", func(t *testing.T) { + f := func(data []byte) bool { + if len(data) == 0 { + return true + } + + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + return false + } + + libsodiumBox := libsodiumSeal(key, data) + decrypted, ok := kateOpen(key, libsodiumBox) + + return ok && string(decrypted) == string(data) + } + + if err := quick.Check(f, nil); err != nil { + t.Errorf("Property test failed: %v", err) + } + }) +} diff --git a/compat-test/go.mod b/compat-test/go.mod new file mode 100644 index 0000000..c7711c9 --- /dev/null +++ b/compat-test/go.mod @@ -0,0 +1,17 @@ +module git.tjp.lol/authentic_kate/compat-test + +go 1.24.4 + +require ( + git.tjp.lol/authentic_kate v0.0.0-00010101000000-000000000000 + github.com/jamesruan/sodium v1.0.14 +) + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.33.0 // indirect +) + +replace git.tjp.lol/authentic_kate => ../ diff --git a/compat-test/go.sum b/compat-test/go.sum new file mode 100644 index 0000000..c96829d --- /dev/null +++ b/compat-test/go.sum @@ -0,0 +1,10 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/jamesruan/sodium v1.0.14 h1:JfOHobip/lUWouxHV3PwYwu3gsLewPrDrZXO3HuBzUU= +github.com/jamesruan/sodium v1.0.14/go.mod h1:GK2+LACf7kuVQ9k7Irk0MB2B65j5rVqkz+9ylGIggZk= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..8d60f2a --- /dev/null +++ b/doc.go @@ -0,0 +1,150 @@ +// Package kate provides secure cookie-based authentication for HTTP applications. +// +// This package implements encrypted cookies with authenticated encryption. +// It supports generic types to allow storage of any serializable data in encrypted HTTP cookies. +// +// Key features: +// - Type-safe authentication data storage using Go generics +// - Authenticated encryption for secure cookie storage +// - HTTP middleware for required and optional authentication +// - Ready-to-use login handlers for password, oauth, and magic link authentication +// - Configurable cookie properties (domain, path, HTTPS-only, max age) +// - Pluggable serialization interface for custom data formats +// - Secure password hashing with industry-standard algorithms +// +// Example usage: +// +// type UserData struct { +// ID int +// Name string +// } +// +// // Implement SerDes interface for UserData +// type UserSerDes struct{} +// func (s UserSerDes) Serialize(w io.Writer, data UserData) error { /* ... */ } +// func (s UserSerDes) Deserialize(r io.Reader, data *UserData) error { /* ... */ } +// +// // Create authentication instance +// auth := kate.New("hex-encoded-32-byte-key", kate.AuthConfig[UserData]{ +// SerDes: UserSerDes{}, +// CookieName: "session", +// HTTPSOnly: true, +// MaxAge: 24 * time.Hour, +// }) +// +// // Use as middleware +// http.Handle("GET /protected", auth.Required(protectedHandler)) +// +// // Set authentication cookie +// user := UserData{ID: 123, Name: "John"} +// auth.Set(w, user) +// +// // Get authenticated data +// user, ok := auth.Get(r.Context()) +// +// Login handlers: +// +// type User struct { +// Username string +// Hash string +// ID int +// } +// +// // Implement PasswordUserDataStore interface +// type UserStore struct{} +// func (us UserStore) Fetch(username string) (User, bool, error) { +// user, exists := database.FindUser(username) +// return user, exists, nil +// } +// func (us UserStore) GetPassHash(user User) string { +// return user.Hash +// } +// +// // Password-based login +// passwordConfig := kate.PasswordLoginConfig[User]{ +// UserData: UserStore{}, +// Redirects: kate.Redirects{ +// Default: "/dashboard", +// AllowedPrefixes: []string{"/app/", "/admin/"}, +// FieldName: "redirect", +// }, +// } +// http.Handle("POST /login", auth.PasswordLoginHandler(passwordConfig)) +// +// // Implement MagicLinkMailer interface +// type UserMailer struct{} +// func (um UserMailer) Fetch(email string) (User, bool, error) { +// user, exists := database.FindUserByEmail(email) +// return user, exists, nil +// } +// func (um UserMailer) SendEmail(user User, token string) error { +// return emailService.SendMagicLink(user.Email, token) +// } +// +// // Magic link authentication +// magicConfig := kate.MagicLinkConfig[User]{ +// Mailer: UserMailer{}, +// Redirects: kate.Redirects{ Default: "/dashboard" }, +// TokenExpiry: 15 * time.Minute, +// TokenLocation: kate.TokenLocationQuery, +// } +// http.Handle("POST /magic-link", auth.MagicLinkLoginHandler(magicConfig)) +// http.Handle("GET /verify", auth.MagicLinkVerifyHandler(magicConfig)) +// +// // OAuth2 authentication (Google example) +// type UserDataStore struct{} +// func (ds UserDataStore) GetOrCreateUser(email string) (User, error) { +// user, exists := database.FindUserByEmail(email) +// if exists { +// return user, nil +// } +// return database.CreateUser(email) +// } +// func (ds UserDataStore) StoreState(state kate.OAuthState) (string, error) { +// id := generateUniqueID() +// return id, cache.Set(id, state, 10*time.Minute) +// } +// func (ds UserDataStore) GetAndClearState(id string) (*kate.OAuthState, error) { +// var state kate.OAuthState +// exists, err := cache.GetAndDelete(id, &state) +// if !exists { +// return nil, err +// } +// return &state, err +// } +// +// oauthConfig := kate.GoogleOAuthConfig[User]( +// "your-client-id", +// "your-client-secret", +// "https://yourapp.com/auth/callback", +// UserDataStore{}, +// ) +// // Customize redirects if needed +// oauthConfig.Redirects = kate.Redirects{ +// Default: "/dashboard", +// AllowedPrefixes: []string{"/app/", "/admin/"}, +// FieldName: "redirect", +// } +// http.Handle("GET /auth/login", auth.OAuthLoginHandler(oauthConfig)) +// http.Handle("GET /auth/callback", auth.OAuthCallbackHandler(oauthConfig)) +// +// Password hashing: +// +// // Hash a password with defaults +// hash, err := kate.HashPassword("user-password", nil) +// +// // Verify a password +// match, err := kate.ComparePassword("user-password", hash) +// if err != nil { +// // Handle malformed hash error +// } +// if match { +// // Password is correct +// } +// +// Cryptographic algorithms: +// +// This package uses the following cryptographic algorithms: +// - Token encryption: XSalsa20 stream cipher with Poly1305 MAC for authenticated encryption (compatible with libsodium secretbox) +// - Password hashing: Argon2id with secure default parameters and PHC format storage +package kate diff --git a/encryption.go b/encryption.go new file mode 100644 index 0000000..a445668 --- /dev/null +++ b/encryption.go @@ -0,0 +1,136 @@ +package kate + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + + "golang.org/x/crypto/poly1305" //nolint:staticcheck + "golang.org/x/crypto/salsa20" +) + +type encryption struct { + Key [naclKeyLen]byte +} + +var ErrInvalidToken = errors.New("invalid token") + +func encryptionFromHexKey(key string) (encryption, error) { + e := encryption{} + l := hex.EncodedLen(naclKeyLen) + keybytes := []byte(key) + if l != len(keybytes) { + return e, fmt.Errorf("expected %d length encryption key, got %d", l, len(keybytes)) + } + if _, err := hex.Decode(e.Key[:], keybytes); err != nil { + return e, fmt.Errorf("encryptionFromHexKey: %w", err) + } + return e, nil +} + +func (e encryption) Encrypt(data []byte) string { + buf := make([]byte, naclNonceLen+len(data)+naclMacLen) + nonce := [naclNonceLen]byte{} + + _, _ = rand.Read(nonce[:]) + + seal(buf[naclNonceLen:], e.Key, nonce, data) + + copy(buf[:naclNonceLen], nonce[:]) + return hex.EncodeToString(buf) +} + +func (e encryption) Decrypt(token string) ([]byte, bool) { + tokenBytes := []byte(token) + ciphertext := make([]byte, hex.DecodedLen(len(tokenBytes))) + if _, err := hex.Decode(ciphertext, tokenBytes); err != nil { + return nil, false + } + + nonce := ciphertext[:naclNonceLen] + noncebuf := [naclNonceLen]byte{} + copy(noncebuf[:], nonce) + ciphertext = ciphertext[naclNonceLen:] + + cleartext := make([]byte, len(ciphertext)-naclMacLen) + if _, ok := open(cleartext, e.Key, noncebuf, ciphertext); !ok { + return nil, false + } + return cleartext, true +} + +const ( + naclKeyLen = 32 + naclNonceLen = 24 + naclMacLen = 16 +) + +// +// Replicating the libsodium secretbox API below. +// +// This was a last resort. I really didn't want to introduce CGO, +// and the only pure-go implementation I could find (libgodium) +// is unmaintained and riddled with bugs. +// +// With solid go implementations of the core primitives already in +// golang.org/x/crypto, it made the most sense to just wrap those. +// + +func seal(dst []byte, key [naclKeyLen]byte, nonce [naclNonceLen]byte, plaintext []byte) []byte { + if len(dst) < len(plaintext)+naclMacLen { + dst = make([]byte, len(plaintext)+naclMacLen) + } + + keystream := make([]byte, naclKeyLen+len(plaintext)) + salsa20.XORKeyStream(keystream, keystream, nonce[:], &key) + + var poly1305Key [naclKeyLen]byte + copy(poly1305Key[:], keystream[:naclKeyLen]) + + ciphertext := make([]byte, len(plaintext)) + for i := range plaintext { + ciphertext[i] = plaintext[i] ^ keystream[naclKeyLen+i] + } + + var mac [naclMacLen]byte + poly1305.Sum(&mac, ciphertext, &poly1305Key) + + copy(dst[:naclMacLen], mac[:]) + copy(dst[naclMacLen:], ciphertext) + + return dst +} + +func open(dst []byte, key [naclKeyLen]byte, nonce [naclNonceLen]byte, box []byte) ([]byte, bool) { + if len(box) < naclMacLen { + return nil, false + } + + if len(dst) < len(box)-naclMacLen { + dst = make([]byte, len(box)-naclMacLen) + } + + var receivedMAC [naclMacLen]byte + copy(receivedMAC[:], box[:naclMacLen]) + ciphertext := box[naclMacLen:] + + keystream := make([]byte, naclKeyLen+len(ciphertext)) + salsa20.XORKeyStream(keystream, keystream, nonce[:], &key) + + var poly1305Key [naclKeyLen]byte + copy(poly1305Key[:], keystream[:naclKeyLen]) + + var expectedMAC [naclMacLen]byte + poly1305.Sum(&expectedMAC, ciphertext, &poly1305Key) + + if !poly1305.Verify(&receivedMAC, ciphertext, &poly1305Key) { + return nil, false + } + + for i := range ciphertext { + dst[i] = ciphertext[i] ^ keystream[naclKeyLen+i] + } + + return dst, true +} diff --git a/encryption_test.go b/encryption_test.go new file mode 100644 index 0000000..fcec058 --- /dev/null +++ b/encryption_test.go @@ -0,0 +1,151 @@ +package kate + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + "testing/quick" +) + +func TestEncryptDecryptRoundtrip(t *testing.T) { + property := func(data []byte) bool { + var key [naclKeyLen]byte + if _, err := rand.Read(key[:]); err != nil { + return false + } + + encryption := encryption{Key: key} + token := encryption.Encrypt(data) + decryptedData, success := encryption.Decrypt(token) + + if !bytes.Equal(decryptedData, data) { + fmt.Printf("expected %v, got %v\n", data, decryptedData) + } + return success && bytes.Equal(decryptedData, data) + } + + if err := quick.Check(property, nil); err != nil { + t.Errorf("Encrypt/Decrypt roundtrip property failed: %v", err) + } +} + +func TestEncryptionFromHexKey(t *testing.T) { + tests := []struct { + name string + key string + wantErr bool + }{ + { + name: "valid hex key", + key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantErr: false, + }, + { + name: "key too short", + key: "0123456789abcdef", + wantErr: true, + }, + { + name: "key too long", + key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00", + wantErr: true, + }, + { + name: "invalid hex characters", + key: "0123456789abcdefghijklmnopqrstuvwxyz0123456789abcdef0123456789abcdef", + wantErr: true, + }, + { + name: "empty key", + key: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := encryptionFromHexKey(tt.key); (err != nil) != tt.wantErr { + t.Errorf("encryptionFromHexKey() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDecryptInvalidTokens(t *testing.T) { + var key [naclKeyLen]byte + if _, err := rand.Read(key[:]); err != nil { + t.Fatal(err) + } + enc := encryption{Key: key} + + tests := []struct { + name string + token string + }{ + { + name: "invalid hex token", + token: "invalid_hex_string", + }, + { + name: "token with wrong MAC", + token: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, ok := enc.Decrypt(tt.token) + if ok { + t.Errorf("Decrypt() should have failed for %s, but returned data: %v", tt.name, data) + } + }) + } +} + +func TestEncryptEmptyData(t *testing.T) { + var key [naclKeyLen]byte + if _, err := rand.Read(key[:]); err != nil { + t.Fatal(err) + } + enc := encryption{Key: key} + + token := enc.Encrypt([]byte{}) + if token == "" { + t.Error("Encrypt() should return non-empty token even for empty data") + } + + data, ok := enc.Decrypt(token) + if !ok { + t.Error("Decrypt() should succeed for encrypted empty data") + } + if len(data) != 0 { + t.Errorf("Decrypt() should return empty data, got %v", data) + } +} + +func TestEncryptLargeData(t *testing.T) { + var key [naclKeyLen]byte + if _, err := rand.Read(key[:]); err != nil { + t.Fatal(err) + } + enc := encryption{Key: key} + + largeData := make([]byte, 10000) + if _, err := rand.Read(largeData); err != nil { + t.Fatal(err) + } + + token := enc.Encrypt(largeData) + if token == "" { + t.Error("Encrypt() should return non-empty token for large data") + } + + decryptedData, ok := enc.Decrypt(token) + if !ok { + t.Error("Decrypt() should succeed for encrypted large data") + } + if !bytes.Equal(decryptedData, largeData) { + t.Error("Decrypt() should return original large data") + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9bcb898 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module git.tjp.lol/authentic_kate + +go 1.24.4 + +require ( + golang.org/x/crypto v0.39.0 + golang.org/x/oauth2 v0.30.0 +) + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + golang.org/x/sys v0.33.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a2cbcfb --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/login_helpers.go b/login_helpers.go new file mode 100644 index 0000000..d7cac41 --- /dev/null +++ b/login_helpers.go @@ -0,0 +1,64 @@ +package kate + +import ( + "net/http" + "net/url" + "strings" +) + +// Redirects configures redirect behavior for authentication handlers. +type Redirects struct { + // Default is the default URL to redirect to after successful authentication + Default string + + // AllowedPrefixes is a list of allowed redirect URL prefixes for security. + // + // If empty, any redirect target is allowed (not recommended for production) + AllowedPrefixes []string + + // FieldName is the form/query field name for the redirect target + // + // If empty, Default will always be used as the target + FieldName string +} + +func (r Redirects) isValid(target string) bool { + targetURL, err := url.Parse(target) + if err != nil { + return false + } + if targetURL.IsAbs() && targetURL.Host != "" { + return false + } + + if len(r.AllowedPrefixes) == 0 { + return true + } + + for _, prefix := range r.AllowedPrefixes { + if strings.HasPrefix(target, prefix) { + return true + } + } + return false +} + +func (r Redirects) target(req *http.Request) string { + d := r.Default + if d == "" { + d = "/" + } + + if r.FieldName == "" { + return d + } + if err := req.ParseForm(); err != nil { + return d + } + + target := req.Form.Get(r.FieldName) + if target != "" && r.isValid(target) { + return target + } + return d +} diff --git a/magic_link_login.go b/magic_link_login.go new file mode 100644 index 0000000..9228a97 --- /dev/null +++ b/magic_link_login.go @@ -0,0 +1,233 @@ +package kate + +import ( + "errors" + "net/http" + "strings" + "time" +) + +// MagicLinkConfig configures the magic link authentication handler behavior. +type MagicLinkConfig[T any] struct { + // Mailer provides user data lookup and email sending + Mailer MagicLinkMailer[T] + + // Redirects configures post-authentication redirect behavior + Redirects Redirects + + // UsernameField is the form field name for username + UsernameField string + + // TokenField is the URL parameter name for the magic link token + TokenField string + + // TokenLocation specifies where to retrieve the token from + TokenLocation TokenLocation + + // TokenExpiry is how long the magic link token is valid + TokenExpiry time.Duration + + // LogError is an optional function to log errors + LogError func(error) +} + +// MagicLinkMailer provides user data lookup and email sending for magic link authentication. +type MagicLinkMailer[T any] interface { + // Fetch retrieves user data by username. + // + // Returns the user data, whether the user was found, and any error. + // If the user is not found, should return (zero value, false, nil). + Fetch(username string) (T, bool, error) + + // SendEmail sends a magic link email to the user. + // + // The token parameter contains the encrypted magic link token that should + // be included in the email URL for authentication. + SendEmail(userData T, token string) error +} + +// TokenLocation specifies where the magic link token should be retrieved from +type TokenLocation struct{ location int } + +var ( + // TokenLocationQuery retrieves the token from URL query parameters + TokenLocationQuery = TokenLocation{0} + // TokenLocationPath retrieves the token from URL path parameters using Request.PathValue() + TokenLocationPath = TokenLocation{1} +) + +func (mlc *MagicLinkConfig[T]) setDefaults() { + if mlc.UsernameField == "" { + mlc.UsernameField = "email" + } + if mlc.TokenField == "" { + mlc.TokenField = "token" + } + // TokenLocation defaults to TokenLocationQuery (zero value) + if mlc.TokenExpiry == 0 { + mlc.TokenExpiry = 15 * time.Minute + } +} + +func (mlc MagicLinkConfig[T]) logError(err error) { + if mlc.LogError != nil { + mlc.LogError(err) + } +} + +type magicLinkToken struct { + Username string + Redirect string + ExpiresAt time.Time +} + +func (t magicLinkToken) serialize() string { + return t.Redirect + "\x00" + t.ExpiresAt.Format(time.RFC3339) + "\x00" + t.Username +} + +func parseMagicLinkToken(data string) (magicLinkToken, error) { + parts := strings.SplitN(data, "\x00", 3) + if len(parts) != 3 { + return magicLinkToken{}, errors.New("invalid token format") + } + + redirect := parts[0] + expiresStr := parts[1] + username := parts[2] + + expiresAt, err := time.Parse(time.RFC3339, expiresStr) + if err != nil { + return magicLinkToken{}, err + } + + return magicLinkToken{ + Redirect: redirect, + ExpiresAt: expiresAt, + Username: username, + }, nil +} + +// MagicLinkLoginHandler returns an HTTP handler that processes magic link requests. +// +// It looks up the user, generates a token, and sends an email with the magic link. +func (a Auth[T]) MagicLinkLoginHandler(config MagicLinkConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid form data", http.StatusBadRequest) + return + } + + username := r.PostForm.Get(config.UsernameField) + + if username == "" { + http.Error(w, config.UsernameField+" required", http.StatusBadRequest) + return + } + + userData, ok, err := config.Mailer.Fetch(username) + if err != nil { + config.logError(err) + http.Error(w, "Error finding user", http.StatusInternalServerError) + return + } + if !ok { + // Don't reveal whether user exists + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("Magic link sent")); err != nil { + config.logError(err) + } + return + } + + tokenData := []byte(magicLinkToken{ + Username: username, + Redirect: config.Redirects.target(r), + ExpiresAt: time.Now().Add(config.TokenExpiry), + }.serialize()) + + encryptedToken := a.enc.Encrypt(tokenData) + + if err := config.Mailer.SendEmail(userData, encryptedToken); err != nil { + config.logError(err) + http.Error(w, "Failed to send email", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("Magic link sent")); err != nil { + config.logError(err) + } + }) +} + +// MagicLinkVerifyHandler returns an HTTP handler that verifies magic link tokens. +// +// It decrypts and validates the token, sets the authentication cookie, and redirects. +func (a Auth[T]) MagicLinkVerifyHandler(config MagicLinkConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var encryptedToken string + switch config.TokenLocation { + case TokenLocationPath: + encryptedToken = r.PathValue(config.TokenField) + case TokenLocationQuery: + encryptedToken = r.URL.Query().Get(config.TokenField) + default: + encryptedToken = r.URL.Query().Get(config.TokenField) + } + + if encryptedToken == "" { + http.Error(w, "Missing token", http.StatusBadRequest) + return + } + + tokenData, ok := a.enc.Decrypt(encryptedToken) + if !ok { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + token, err := parseMagicLinkToken(string(tokenData)) + if err != nil { + config.logError(err) + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + if time.Now().After(token.ExpiresAt) { + http.Error(w, "Token expired", http.StatusUnauthorized) + return + } + + userData, ok, err := config.Mailer.Fetch(token.Username) + if err != nil { + config.logError(err) + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + if !ok { + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + + if err := a.Set(w, userData); err != nil { + config.logError(err) + http.Error(w, "Failed to set authentication", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, token.Redirect, http.StatusSeeOther) + }) +} + diff --git a/magic_link_login_test.go b/magic_link_login_test.go new file mode 100644 index 0000000..da4bfa6 --- /dev/null +++ b/magic_link_login_test.go @@ -0,0 +1,473 @@ +package kate + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" +) + +// Mock implementation of MagicLinkMailer for testing +type mockMagicLinkMailer[T any] struct { + users map[string]T + sentEmails []struct { + user T + token string + } + sendEmailFunc func(T, string) error +} + +func newMockMagicLinkMailer[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkMailer[T] { + return &mockMagicLinkMailer[T]{ + users: users, + sendEmailFunc: sendEmailFunc, + } +} + +func (m *mockMagicLinkMailer[T]) Fetch(username string) (T, bool, error) { + user, exists := m.users[username] + return user, exists, nil +} + +func (m *mockMagicLinkMailer[T]) SendEmail(userData T, token string) error { + m.sentEmails = append(m.sentEmails, struct { + user T + token string + }{userData, token}) + if m.sendEmailFunc != nil { + return m.sendEmailFunc(userData, token) + } + return nil +} + +type testUser struct { + Username string + Hash string + ID int +} + +type testUserSerDes struct{} + +func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error { + _, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID) + return err +} + +func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error { + buf, err := io.ReadAll(r) + if err != nil { + return err + } + parts := strings.Split(string(buf), "|") + if len(parts) != 3 { + return errors.New("invalid format") + } + data.Username = parts[0] + data.Hash = parts[1] + id, err := strconv.Atoi(parts[2]) + if err != nil { + return err + } + data.ID = id + return nil +} + +func TestMagicLinkHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, + "jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{ + Default: "/dashboard", + AllowedPrefixes: []string{"/app/", "/admin/"}, + }, + } + + handler := auth.MagicLinkLoginHandler(config) + + tests := []struct { + name string + method string + formData url.Values + expectedStatus int + expectedBody string + checkEmail bool + }{ + { + name: "successful magic link request", + method: "POST", + formData: url.Values{"email": {"john@example.com"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: true, + }, + { + name: "magic link with redirect", + method: "POST", + formData: url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: true, + }, + { + name: "nonexistent user returns success (no user enumeration)", + method: "POST", + formData: url.Values{"email": {"nobody@example.com"}}, + expectedStatus: http.StatusOK, + expectedBody: "Magic link sent", + checkEmail: false, + }, + { + name: "missing email", + method: "POST", + formData: url.Values{}, + expectedStatus: http.StatusBadRequest, + checkEmail: false, + }, + { + name: "GET method not allowed", + method: "GET", + formData: url.Values{}, + expectedStatus: http.StatusMethodNotAllowed, + checkEmail: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMailer.sentEmails = nil // Reset + + var req *http.Request + if tt.method == "POST" { + req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } else { + req = httptest.NewRequest(tt.method, "/magic-link", nil) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedBody != "" { + body := strings.TrimSpace(rr.Body.String()) + if body != tt.expectedBody { + t.Errorf("expected body %q, got %q", tt.expectedBody, body) + } + } + + if tt.checkEmail { + if len(mockMailer.sentEmails) != 1 { + t.Errorf("expected 1 email sent, got %d", len(mockMailer.sentEmails)) + } else { + email := mockMailer.sentEmails[0] + if email.token == "" { + t.Error("expected non-empty token") + } + } + } else { + if len(mockMailer.sentEmails) != 0 { + t.Errorf("expected no emails sent, got %d", len(mockMailer.sentEmails)) + } + } + }) + } +} + +func TestMagicLinkVerifyHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{Default: "/dashbaord"}, + TokenExpiry: time.Minute, + } + + token := magicLinkToken{ + Username: "john@example.com", + Redirect: "/app/settings", + ExpiresAt: time.Now().Add(time.Minute), + } + tokenData := []byte(token.serialize()) + validToken := auth.enc.Encrypt(tokenData) + + expiredToken := magicLinkToken{ + Username: "john@example.com", + Redirect: "/app/settings", + ExpiresAt: time.Now().Add(-time.Minute), + } + expiredTokenData := []byte(expiredToken.serialize()) + expiredTokenStr := auth.enc.Encrypt(expiredTokenData) + + handler := auth.MagicLinkVerifyHandler(config) + + tests := []struct { + name string + method string + token string + expectedStatus int + expectedRedirect string + checkCookie bool + }{ + { + name: "valid token with redirect", + method: "GET", + token: validToken, + expectedStatus: http.StatusSeeOther, + expectedRedirect: "/app/settings", + checkCookie: true, + }, + { + name: "missing token", + method: "GET", + token: "", + expectedStatus: http.StatusBadRequest, + checkCookie: false, + }, + { + name: "invalid token", + method: "GET", + token: "invalid-token", + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "expired token", + method: "GET", + token: expiredTokenStr, + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "POST method not allowed", + method: "POST", + token: validToken, + expectedStatus: http.StatusMethodNotAllowed, + checkCookie: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/verify" + if tt.token != "" { + url += "?token=" + tt.token + } + + req := httptest.NewRequest(tt.method, url, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedRedirect != "" { + location := rr.Header().Get("Location") + if location != tt.expectedRedirect { + t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location) + } + } + + if tt.checkCookie { + cookies := rr.Result().Cookies() + found := false + for _, cookie := range cookies { + if cookie.Name == "test_session" && cookie.Value != "" { + found = true + break + } + } + if !found { + t.Error("expected authentication cookie to be set") + } + } + }) + } +} + +func TestMagicLinkConfigDefaults(t *testing.T) { + config := MagicLinkConfig[testUser]{} + config.setDefaults() + + if config.UsernameField != "email" { + t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField) + } + if config.TokenField != "token" { + t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) + } + if config.TokenExpiry != 15*time.Minute { + t.Errorf("expected TokenExpiry to be 15 minutes, got %v", config.TokenExpiry) + } + if config.TokenLocation != TokenLocationQuery { + t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation) + } +} + +func TestMagicLinkTokenSerialization(t *testing.T) { + now := time.Now().Truncate(time.Second) // Remove nanoseconds for consistent comparison + + tests := []struct { + name string + token magicLinkToken + expected string + }{ + { + name: "simple token", + token: magicLinkToken{ + Redirect: "/dashboard", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "redirect with pipes", + token: magicLinkToken{ + Redirect: "/app|section|page", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "/app|section|page\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "empty redirect", + token: magicLinkToken{ + Redirect: "", + ExpiresAt: now, + Username: "user@example.com", + }, + expected: "\x00" + now.Format(time.RFC3339) + "\x00user@example.com", + }, + { + name: "username with pipes", + token: magicLinkToken{ + Redirect: "/dashboard", + ExpiresAt: now, + Username: "user|with|pipes@example.com", + }, + expected: "/dashboard\x00" + now.Format(time.RFC3339) + "\x00user|with|pipes@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serialized := tt.token.serialize() + if serialized != tt.expected { + t.Errorf("serialize() = %q, want %q", serialized, tt.expected) + } + + parsed, err := parseMagicLinkToken(serialized) + if err != nil { + t.Errorf("parseMagicLinkToken() error = %v", err) + return + } + + if parsed.Redirect != tt.token.Redirect { + t.Errorf("parsed Redirect = %q, want %q", parsed.Redirect, tt.token.Redirect) + } + if parsed.Username != tt.token.Username { + t.Errorf("parsed Username = %q, want %q", parsed.Username, tt.token.Username) + } + if !parsed.ExpiresAt.Equal(tt.token.ExpiresAt) { + t.Errorf("parsed ExpiresAt = %v, want %v", parsed.ExpiresAt, tt.token.ExpiresAt) + } + }) + } +} + +func TestParseMagicLinkTokenErrors(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"empty string", ""}, + {"single part", "onlyonepart"}, + {"two parts", "two\x00parts"}, + {"invalid time", "redirect\x00invalid-time\x00username"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseMagicLinkToken(tt.input) + if err == nil { + t.Errorf("parseMagicLinkToken(%q) expected error, got nil", tt.input) + } + }) + } +} + +func TestMagicLinkVerifyHandlerPathToken(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "test": {Username: "test", ID: 1}, + } + + mockMailer := newMockMagicLinkMailer(users, nil) + + config := MagicLinkConfig[testUser]{ + Mailer: mockMailer, + Redirects: Redirects{Default: "/dashboard"}, + TokenExpiry: time.Minute, + TokenLocation: TokenLocationPath, + TokenField: "token", + } + + handler := auth.MagicLinkVerifyHandler(config) + + if config.TokenLocation != TokenLocationPath { + t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation) + } + if config.TokenField != "token" { + t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) + } + + req := httptest.NewRequest("GET", "/verify", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code) + } +} + +func TestTokenLocationEnum(t *testing.T) { + if TokenLocationQuery.location != 0 { + t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location) + } + if TokenLocationPath.location != 1 { + t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location) + } +} diff --git a/oauth_login.go b/oauth_login.go new file mode 100644 index 0000000..6c16275 --- /dev/null +++ b/oauth_login.go @@ -0,0 +1,257 @@ +package kate + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/google" +) + +// OAuthConfig configures the OAuth login handler behavior. +type OAuthConfig[T any] struct { + // OAuth2Config is the standard OAuth2 configuration + OAuth2Config *oauth2.Config + + // UserInfoURL is the provider endpoint to get user information + UserInfoURL string + + // DataStore provides user data lookup and temporary state storage + DataStore OAuthDataStore[T] + + // Redirects configures post-authentication redirect behavior + Redirects Redirects + + // StateExpiry is how long the OAuth state is valid + StateExpiry time.Duration + + // LogError is an optional function to log errors + LogError func(error) +} + +// OAuthDataStore provides user data lookup and temporary storage for OAuth state data. +type OAuthDataStore[T any] interface { + // GetOrCreateUser retrieves or creates user data by email address. + // + // If the user exists, returns their data. If not, creates a new user + // and returns the new user data. This handles both OAuth login and + // registration in a single operation. + GetOrCreateUser(email string) (T, error) + + // StoreState stores OAuth state data temporarily and returns a unique state ID. + // + // The returned state ID will be used as the OAuth state parameter. + // Implementations should generate a unique, unguessable ID for security. + StoreState(state OAuthState) (string, error) + + // GetAndClearState retrieves and deletes OAuth state data by ID. + // + // Returns the state data if found, otherwise nil. + // The state should be deleted after retrieval to ensure one-time use. + GetAndClearState(id string) (*OAuthState, error) +} + +// GoogleOAuthConfig creates an OAuthConfig for Google OAuth with sensible defaults. +func GoogleOAuthConfig[T any]( + clientID, + clientSecret, + callbackURI string, + dataStore OAuthDataStore[T], +) OAuthConfig[T] { + return OAuthConfig[T]{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: callbackURI, + Scopes: []string{"email"}, + Endpoint: google.Endpoint, + }, + UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + DataStore: dataStore, + Redirects: Redirects{Default: "/", FieldName: "redirect"}, + StateExpiry: 10 * time.Minute, + } +} + +// GitHubOAuthConfig creates an OAuthConfig for GitHub OAuth with sensible defaults. +func GitHubOAuthConfig[T any]( + clientID, + clientSecret, + callbackURI string, + dataStore OAuthDataStore[T], +) OAuthConfig[T] { + return OAuthConfig[T]{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: callbackURI, + Scopes: []string{"user:email"}, + Endpoint: github.Endpoint, + }, + UserInfoURL: "https://api.github.com/user", + DataStore: dataStore, + Redirects: Redirects{Default: "/", FieldName: "redirect"}, + StateExpiry: 10 * time.Minute, + } +} + +func (oc *OAuthConfig[T]) setDefaults() { + if oc.StateExpiry == 0 { + oc.StateExpiry = 10 * time.Minute + } +} + +func (oc OAuthConfig[T]) logError(err error) { + if oc.LogError != nil { + oc.LogError(err) + } +} + +// OAuthState represents the OAuth state data that needs to be temporarily stored. +type OAuthState struct { + Redirect string + ExpiresAt time.Time +} + +// OAuthLoginHandler returns an HTTP handler that initiates OAuth authentication. +func (a Auth[T]) OAuthLoginHandler(config OAuthConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + state := OAuthState{ + Redirect: config.Redirects.target(r), + ExpiresAt: time.Now().Add(config.StateExpiry), + } + + stateID, err := config.DataStore.StoreState(state) + if err != nil { + config.logError(err) + http.Error(w, "Failed to store state", http.StatusInternalServerError) + return + } + + authURL := config.OAuth2Config.AuthCodeURL(stateID) + http.Redirect(w, r, authURL, http.StatusSeeOther) + }) +} + +// OAuthCallbackHandler returns an HTTP handler that handles the OAuth callback. +func (a Auth[T]) OAuthCallbackHandler(config OAuthConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + stateID := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + errorParam := r.URL.Query().Get("error") + + if errorParam != "" { + config.logError(fmt.Errorf("OAuth error: %s", errorParam)) + http.Error(w, "OAuth authorization failed", http.StatusBadRequest) + return + } + + if code == "" { + http.Error(w, "Missing authorization code", http.StatusBadRequest) + return + } + + if stateID == "" { + http.Error(w, "Missing state parameter", http.StatusBadRequest) + return + } + + state, err := config.DataStore.GetAndClearState(stateID) + if err != nil { + config.logError(err) + http.Error(w, "Error finding login state", http.StatusInternalServerError) + return + } + if state == nil { + http.Error(w, "Invalid login state", http.StatusBadRequest) + return + } + + if time.Now().After(state.ExpiresAt) { + http.Error(w, "Login expired", http.StatusBadRequest) + return + } + + ctx := context.Background() + token, err := config.OAuth2Config.Exchange(ctx, code) + if err != nil { + config.logError(err) + http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError) + return + } + + client := config.OAuth2Config.Client(ctx, token) + resp, err := client.Get(config.UserInfoURL) + if err != nil { + config.logError(err) + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + config.logError(err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + config.logError(err) + http.Error(w, "Failed to read user info response", http.StatusInternalServerError) + return + } + + var userInfo struct { + Email string `json:"email"` + } + if err := json.Unmarshal(body, &userInfo); err != nil { + config.logError(err) + http.Error(w, "Failed to parse user info", http.StatusInternalServerError) + return + } + email := userInfo.Email + if email == "" { + http.Error(w, "No email address found", http.StatusBadRequest) + return + } + + userData, err := config.DataStore.GetOrCreateUser(email) + if err != nil { + config.logError(err) + http.Error(w, "Failed to get or create user", http.StatusInternalServerError) + return + } + + if err := a.Set(w, userData); err != nil { + config.logError(err) + http.Error(w, "Failed to set authentication cookie", http.StatusInternalServerError) + return + } + + redirectTarget := state.Redirect + if redirectTarget == "" { + redirectTarget = config.Redirects.Default + if redirectTarget == "" { + redirectTarget = "/" + } + } + + http.Redirect(w, r, redirectTarget, http.StatusSeeOther) + }) +} diff --git a/oauth_login_test.go b/oauth_login_test.go new file mode 100644 index 0000000..caec89b --- /dev/null +++ b/oauth_login_test.go @@ -0,0 +1,369 @@ +package kate + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" +) + +// Mock implementation of OAuthDataStore for testing +type mockOAuthDataStore[T any] struct { + userData map[string]T + states map[string]OAuthState + mu sync.RWMutex + getOrCreateUser func(email string) (T, error) + stateIDCounter int +} + +func newMockOAuthDataStore[T any](getOrCreateUser func(email string) (T, error)) *mockOAuthDataStore[T] { + return &mockOAuthDataStore[T]{ + userData: make(map[string]T), + states: make(map[string]OAuthState), + getOrCreateUser: getOrCreateUser, + } +} + +func (m *mockOAuthDataStore[T]) GetOrCreateUser(email string) (T, error) { + return m.getOrCreateUser(email) +} + +func (m *mockOAuthDataStore[T]) StoreState(state OAuthState) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.stateIDCounter++ + stateID := fmt.Sprintf("state_%d", m.stateIDCounter) + m.states[stateID] = state + return stateID, nil +} + +func (m *mockOAuthDataStore[T]) GetAndClearState(id string) (*OAuthState, error) { + m.mu.Lock() + defer m.mu.Unlock() + + state, exists := m.states[id] + if !exists { + return nil, nil + } + + delete(m.states, id) + return &state, nil +} + +func TestOAuthLoginHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + config.Redirects.Default = "/dashboard" + config.Redirects.AllowedPrefixes = []string{"/app/", "/admin/"} + + handler := auth.OAuthLoginHandler(config) + + tests := []struct { + name string + method string + queryParams url.Values + formData url.Values + expectedStatus int + checkRedirect bool + expectedContains []string + }{ + { + name: "GET request initiates OAuth flow", + method: "GET", + queryParams: url.Values{}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback", + "response_type=code", + "state=", + "scope=email", + }, + }, + { + name: "GET request with redirect parameter", + method: "GET", + queryParams: url.Values{"redirect": {"/app/settings"}}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + }, + }, + { + name: "POST method not allowed", + method: "POST", + formData: url.Values{"redirect": {"/app/profile"}}, + expectedStatus: http.StatusMethodNotAllowed, + checkRedirect: false, + }, + { + name: "invalid redirect falls back to default", + method: "GET", + queryParams: url.Values{"redirect": {"/evil/"}}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + }, + }, + { + name: "PUT method not allowed", + method: "PUT", + expectedStatus: http.StatusMethodNotAllowed, + checkRedirect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + reqURL := "/oauth/login" + + if len(tt.queryParams) > 0 { + reqURL += "?" + tt.queryParams.Encode() + } + + if tt.method == "POST" && tt.formData != nil { + req = httptest.NewRequest(tt.method, reqURL, strings.NewReader(tt.formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } else { + req = httptest.NewRequest(tt.method, reqURL, nil) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.checkRedirect { + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected redirect location to be set") + } + + for _, expected := range tt.expectedContains { + if !strings.Contains(location, expected) { + t.Errorf("expected redirect URL to contain %q, got %q", expected, location) + } + } + } + }) + } +} + +func TestOAuthCallbackHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + config.Redirects.Default = "/dashboard" + + handler := auth.OAuthCallbackHandler(config) + + tests := []struct { + name string + method string + queryParams url.Values + expectedStatus int + }{ + { + name: "GET callback with invalid state", + method: "GET", + queryParams: url.Values{"code": {"test-code"}, "state": {"test-state"}}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "POST method not allowed", + method: "POST", + expectedStatus: http.StatusMethodNotAllowed, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reqURL := "/oauth/callback" + if len(tt.queryParams) > 0 { + reqURL += "?" + tt.queryParams.Encode() + } + + req := httptest.NewRequest(tt.method, reqURL, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + }) + } +} + +func TestGoogleOAuthConfig(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + + if config.OAuth2Config.ClientID != "test-client-id" { + t.Errorf("expected ClientID 'test-client-id', got %q", config.OAuth2Config.ClientID) + } + if config.OAuth2Config.ClientSecret != "test-client-secret" { + t.Errorf("expected ClientSecret 'test-client-secret', got %q", config.OAuth2Config.ClientSecret) + } + if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/callback" { + t.Errorf("expected RedirectURL 'http://localhost:8080/auth/callback', got %q", config.OAuth2Config.RedirectURL) + } + if config.OAuth2Config.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/auth" { + t.Errorf("expected AuthURL 'https://accounts.google.com/o/oauth2/auth', got %q", config.OAuth2Config.Endpoint.AuthURL) + } + if config.OAuth2Config.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("expected TokenURL 'https://oauth2.googleapis.com/token', got %q", config.OAuth2Config.Endpoint.TokenURL) + } + if config.UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("expected UserInfoURL 'https://www.googleapis.com/oauth2/v2/userinfo', got %q", config.UserInfoURL) + } + expectedScopes := []string{"email"} + if len(config.OAuth2Config.Scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes)) + } else { + for i, scope := range expectedScopes { + if config.OAuth2Config.Scopes[i] != scope { + t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i]) + } + } + } +} + +func TestGitHubOAuthConfig(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GitHubOAuthConfig( + "test-github-client-id", + "test-github-client-secret", + "http://localhost:8080/auth/github/callback", + mockStore, + ) + + if config.OAuth2Config.ClientID != "test-github-client-id" { + t.Errorf("expected ClientID 'test-github-client-id', got %q", config.OAuth2Config.ClientID) + } + if config.OAuth2Config.ClientSecret != "test-github-client-secret" { + t.Errorf("expected ClientSecret 'test-github-client-secret', got %q", config.OAuth2Config.ClientSecret) + } + if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/github/callback" { + t.Errorf("expected RedirectURL 'http://localhost:8080/auth/github/callback', got %q", config.OAuth2Config.RedirectURL) + } + if config.OAuth2Config.Endpoint.AuthURL != "https://github.com/login/oauth/authorize" { + t.Errorf("expected AuthURL 'https://github.com/login/oauth/authorize', got %q", config.OAuth2Config.Endpoint.AuthURL) + } + if config.OAuth2Config.Endpoint.TokenURL != "https://github.com/login/oauth/access_token" { + t.Errorf("expected TokenURL 'https://github.com/login/oauth/access_token', got %q", config.OAuth2Config.Endpoint.TokenURL) + } + if config.UserInfoURL != "https://api.github.com/user" { + t.Errorf("expected UserInfoURL 'https://api.github.com/user', got %q", config.UserInfoURL) + } + expectedScopes := []string{"user:email"} + if len(config.OAuth2Config.Scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes)) + } else { + for i, scope := range expectedScopes { + if config.OAuth2Config.Scopes[i] != scope { + t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i]) + } + } + } +} + +func TestOAuthStateStore(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + now := time.Now().Truncate(time.Second) + state := OAuthState{ + Redirect: "/dashboard", + ExpiresAt: now, + } + + // Test storing state + stateID, err := mockStore.StoreState(state) + if err != nil { + t.Errorf("StoreState() error = %v", err) + return + } + + if stateID == "" { + t.Error("StoreState() should return non-empty state ID") + } + + // Test retrieving state + retrievedState, err := mockStore.GetAndClearState(stateID) + if err != nil { + t.Errorf("GetAndClearState() error = %v", err) + return + } + + if retrievedState == nil { + t.Error("GetAndClearState() should return state data") + return + } + + if retrievedState.Redirect != state.Redirect { + t.Errorf("retrieved Redirect = %q, want %q", retrievedState.Redirect, state.Redirect) + } + if !retrievedState.ExpiresAt.Equal(state.ExpiresAt) { + t.Errorf("retrieved ExpiresAt = %v, want %v", retrievedState.ExpiresAt, state.ExpiresAt) + } + + // Test that state is cleared (one-time use) + clearedState, err := mockStore.GetAndClearState(stateID) + if err != nil { + t.Errorf("GetAndClearState() on cleared state error = %v", err) + } + if clearedState != nil { + t.Error("GetAndClearState() should return nil for already used state") + } +} diff --git a/password.go b/password.go new file mode 100644 index 0000000..d1a5c54 --- /dev/null +++ b/password.go @@ -0,0 +1,131 @@ +package kate + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +// Argon2Config configures Argon2id password hashing parameters. +// +// Leave any fields at their zero value to use the default. +type Argon2Config struct { + // Time is the number of iterations (default: 1) + Time uint32 + // Memory is the memory usage in KiB (default: 64*1024 = 64MB) + Memory uint32 + // Threads is the number of parallel threads (default: 4) + Threads uint8 + // KeyLen is the length of the derived key in bytes (default: 32) + KeyLen uint32 +} + +var defaultArgon2Config = Argon2Config{ + Time: 1, + Memory: 64 * 1024, + Threads: 4, + KeyLen: 32, +} + +// HashPassword hashes a password using Argon2id. +// +// Returns a PHC-formatted string containing all parameters needed for verification. +// If config is nil, secure defaults are used. +func HashPassword(password string, config *Argon2Config) (string, error) { + // Use defaults if no config provided + if config == nil { + config = &defaultArgon2Config + } + + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("HashPassword: %w", err) + } + + // Use defaults for zero values in config + time := config.Time + if time == 0 { + time = defaultArgon2Config.Time + } + memory := config.Memory + if memory == 0 { + memory = defaultArgon2Config.Memory + } + threads := config.Threads + if threads == 0 { + threads = defaultArgon2Config.Threads + } + keyLen := config.KeyLen + if keyLen == 0 { + keyLen = defaultArgon2Config.KeyLen + } + + hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen) + + return buildPHC(salt, hash, memory, time, threads), nil +} + +// ComparePassword verifies a password against a stored hash in PHC format. +// +// Returns (false, nil) for incorrect passwords and (false, error) for malformed hashes. +// Uses constant-time comparison to prevent timing attacks. +func ComparePassword(loginpass, storedhash string) (bool, error) { + salt, hash, memory, time, threads, err := parsePHC(storedhash) + if err != nil { + return false, fmt.Errorf("ComparePassword: %w", err) + } + + loginHash := argon2.IDKey([]byte(loginpass), salt, time, memory, threads, uint32(len(hash))) + + return subtle.ConstantTimeCompare(loginHash, hash) == 1, nil +} + +func buildPHC(salt, hash []byte, memory, time uint32, threads uint8) string { + b64Salt := base64.RawStdEncoding.EncodeToString(salt) + b64Hash := base64.RawStdEncoding.EncodeToString(hash) + return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, memory, time, threads, b64Salt, b64Hash) +} + +func parsePHC(phc string) (salt, hash []byte, memory, time uint32, threads uint8, err error) { + parts := strings.Split(phc, "$") + if len(parts) != 6 { + err = errors.New("invalid PHC format") + return + } + + if parts[1] != "argon2id" { + err = errors.New("unsupported algorithm") + return + } + + var version int + if _, err = fmt.Sscanf(parts[2], "v=%d", &version); err != nil { + return + } + + if version > argon2.Version { + err = errors.New("unsupported argon2 version") + return + } + + var p uint32 + if _, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &p); err != nil { + return + } + threads = uint8(p) + + if salt, err = base64.RawStdEncoding.DecodeString(parts[4]); err != nil { + return + } + + if hash, err = base64.RawStdEncoding.DecodeString(parts[5]); err != nil { + return + } + + return +} diff --git a/password_login.go b/password_login.go new file mode 100644 index 0000000..837d9c9 --- /dev/null +++ b/password_login.go @@ -0,0 +1,105 @@ +package kate + +import "net/http" + +// PasswordLoginConfig configures the password login handler behavior. +type PasswordLoginConfig[T any] struct { + // UserData provides user data lookup and password hash extraction + UserData PasswordUserDataStore[T] + + // Redirects configures post-authentication redirect behavior + Redirects Redirects + + // UsernameField is the form field name for username + UsernameField string + + // PasswordField is the form field name for password + PasswordField string + + // LogError is called when the login handler encounters unexpected errors + LogError func(error) +} + +// PasswordUserDataStore provides user data lookup for password authentication. +type PasswordUserDataStore[T any] interface { + // Fetch retrieves user data by username. + // + // Returns the user data, whether the user was found, and any error. + // If the user is not found, should return (zero value, false, nil). + Fetch(username string) (T, bool, error) + + // GetPassHash extracts the password hash from user data. + // + // Returns the stored password hash for comparison with the provided password. + GetPassHash(userData T) string +} + +func (lc *PasswordLoginConfig[T]) setDefaults() { + if lc.UsernameField == "" { + lc.UsernameField = "username" + } + if lc.PasswordField == "" { + lc.PasswordField = "password" + } +} + +func (lc PasswordLoginConfig[T]) logError(err error) { + if lc.LogError != nil { + lc.LogError(err) + } +} + +// PasswordLoginHandler returns an HTTP handler that processes password login form submissions. +func (a Auth[T]) PasswordLoginHandler(config PasswordLoginConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid form data", http.StatusBadRequest) + return + } + + username := r.PostForm.Get(config.UsernameField) + password := r.PostForm.Get(config.PasswordField) + + if username == "" || password == "" { + http.Error(w, "Username and password required", http.StatusBadRequest) + return + } + + userData, ok, err := config.UserData.Fetch(username) + if err != nil { + config.logError(err) + http.Error(w, "Error fetching user", http.StatusInternalServerError) + return + } + if !ok { + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + + match, err := ComparePassword(password, config.UserData.GetPassHash(userData)) + if err != nil { + config.logError(err) + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + + if !match { + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } + + if err := a.Set(w, userData); err != nil { + config.logError(err) + http.Error(w, "Failed to set authentication", http.StatusInternalServerError) + return + } + + http.Redirect(w, r, config.Redirects.target(r), http.StatusSeeOther) + }) +} diff --git a/password_login_test.go b/password_login_test.go new file mode 100644 index 0000000..eec1ba7 --- /dev/null +++ b/password_login_test.go @@ -0,0 +1,385 @@ +package kate + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +// Mock implementation of PasswordUserDataStore for testing +type mockPasswordUserDataStore[T any] struct { + users map[string]T + getHash func(T) string +} + +func newMockPasswordUserDataStore[T any](users map[string]T, getHash func(T) string) *mockPasswordUserDataStore[T] { + return &mockPasswordUserDataStore[T]{ + users: users, + getHash: getHash, + } +} + +func (m *mockPasswordUserDataStore[T]) Fetch(username string) (T, bool, error) { + user, exists := m.users[username] + return user, exists, nil +} + +func (m *mockPasswordUserDataStore[T]) GetPassHash(userData T) string { + return m.getHash(userData) +} + +func TestPasswordLoginHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + testHash, err := HashPassword("password123", nil) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + users := map[string]testUser{ + "john": {Username: "john", Hash: testHash, ID: 1}, + "jane": {Username: "jane", Hash: testHash, ID: 2}, + } + + mockStore := newMockPasswordUserDataStore(users, func(user testUser) string { + return user.Hash + }) + + config := PasswordLoginConfig[testUser]{ + UserData: mockStore, + Redirects: Redirects{ + Default: "/dashboard", + AllowedPrefixes: []string{"/app/", "/admin/"}, + FieldName: "redirect", + }, + } + + handler := auth.PasswordLoginHandler(config) + + tests := []struct { + name string + method string + formData url.Values + expectedStatus int + expectedRedirect string + checkCookie bool + }{ + { + name: "successful login", + method: "POST", + formData: url.Values{"username": {"john"}, "password": {"password123"}}, + expectedStatus: http.StatusSeeOther, + expectedRedirect: "/dashboard", + checkCookie: true, + }, + { + name: "successful login with redirect", + method: "POST", + formData: url.Values{"username": {"john"}, "password": {"password123"}, "redirect": {"/app/settings"}}, + expectedStatus: http.StatusSeeOther, + expectedRedirect: "/app/settings", + checkCookie: true, + }, + { + name: "invalid redirect falls back to default", + method: "POST", + formData: url.Values{"username": {"john"}, "password": {"password123"}, "redirect": {"/evil/"}}, + expectedStatus: http.StatusSeeOther, + expectedRedirect: "/dashboard", + checkCookie: true, + }, + { + name: "wrong password", + method: "POST", + formData: url.Values{"username": {"john"}, "password": {"wrongpass"}}, + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "nonexistent user", + method: "POST", + formData: url.Values{"username": {"nobody"}, "password": {"password123"}}, + expectedStatus: http.StatusUnauthorized, + checkCookie: false, + }, + { + name: "missing username", + method: "POST", + formData: url.Values{"password": {"password123"}}, + expectedStatus: http.StatusBadRequest, + checkCookie: false, + }, + { + name: "missing password", + method: "POST", + formData: url.Values{"username": {"john"}}, + expectedStatus: http.StatusBadRequest, + checkCookie: false, + }, + { + name: "GET method not allowed", + method: "GET", + formData: url.Values{}, + expectedStatus: http.StatusMethodNotAllowed, + checkCookie: false, + }, + { + name: "PUT method not allowed", + method: "PUT", + formData: url.Values{}, + expectedStatus: http.StatusMethodNotAllowed, + checkCookie: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + + if tt.method == "POST" { + req = httptest.NewRequest(tt.method, "/login", strings.NewReader(tt.formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } else { + req = httptest.NewRequest(tt.method, "/login", nil) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedRedirect != "" { + location := rr.Header().Get("Location") + if location != tt.expectedRedirect { + t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location) + } + } + + if tt.checkCookie { + cookies := rr.Result().Cookies() + found := false + for _, cookie := range cookies { + if cookie.Name == "test_session" && cookie.Value != "" { + found = true + break + } + } + if !found { + t.Error("expected authentication cookie to be set") + } + } + }) + } +} + +func TestPasswordLoginHandlerCustomFields(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + testHash, err := HashPassword("secret", nil) + if err != nil { + t.Fatalf("Failed to hash password: %v", err) + } + + users := map[string]testUser{ + "admin": {Username: "admin", Hash: testHash, ID: 99}, + } + + mockStore := newMockPasswordUserDataStore(users, func(user testUser) string { + return user.Hash + }) + + config := PasswordLoginConfig[testUser]{ + UserData: mockStore, + UsernameField: "email", + PasswordField: "pass", + Redirects: Redirects{ + FieldName: "next", + Default: "/home", + }, + } + + handler := auth.PasswordLoginHandler(config) + + formData := url.Values{ + "email": {"admin"}, + "pass": {"secret"}, + "next": {"/custom"}, + } + + req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusSeeOther { + t.Errorf("expected status %d, got %d", http.StatusSeeOther, rr.Code) + } + + location := rr.Header().Get("Location") + if location != "/custom" { + t.Errorf("expected redirect to /custom, got %s", location) + } +} + +func TestPasswordLoginHandlerParseFormError(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "test": {Username: "test", Hash: "", ID: 1}, + } + + mockStore := newMockPasswordUserDataStore(users, func(user testUser) string { + return user.Hash + }) + + config := PasswordLoginConfig[testUser]{ + UserData: mockStore, + } + + handler := auth.PasswordLoginHandler(config) + + req := httptest.NewRequest("POST", "/login", strings.NewReader("username=%")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code) + } +} + +func TestPasswordLoginHandlerGetPassHashError(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "user": {Username: "user", Hash: "", ID: 1}, + } + + mockStore := newMockPasswordUserDataStore(users, func(user testUser) string { + return user.Hash // Empty hash will cause password comparison to fail + }) + + config := PasswordLoginConfig[testUser]{ + UserData: mockStore, + } + + handler := auth.PasswordLoginHandler(config) + + formData := url.Values{"username": {"user"}, "password": {"pass"}} + req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestPasswordLoginHandlerMalformedHash(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + users := map[string]testUser{ + "user": {Username: "user", Hash: "invalid-hash", ID: 1}, + } + + mockStore := newMockPasswordUserDataStore(users, func(user testUser) string { + return user.Hash + }) + + config := PasswordLoginConfig[testUser]{ + UserData: mockStore, + } + + handler := auth.PasswordLoginHandler(config) + + formData := url.Values{"username": {"user"}, "password": {"pass"}} + req := httptest.NewRequest("POST", "/login", strings.NewReader(formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + +func TestRedirectsIsValid(t *testing.T) { + tests := []struct { + name string + target string + allowedPrefixes []string + expected bool + }{ + { + name: "no restrictions allows anything", + target: "/anything", + allowedPrefixes: []string{}, + expected: true, + }, + { + name: "valid prefix match", + target: "/app/dashboard", + allowedPrefixes: []string{"/app/", "/admin/"}, + expected: true, + }, + { + name: "invalid prefix", + target: "/evil/", + allowedPrefixes: []string{"/app/", "/admin/"}, + expected: false, + }, + { + name: "absolute URL rejected", + target: "https://evil.com/", + allowedPrefixes: []string{"/app/"}, + expected: false, + }, + { + name: "protocol relative URL rejected", + target: "//evil.com/", + allowedPrefixes: []string{"/app/"}, + expected: false, + }, + { + name: "invalid URL rejected", + target: ":/invalid", + allowedPrefixes: []string{"/app/"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirects := Redirects{AllowedPrefixes: tt.allowedPrefixes} + result := redirects.isValid(tt.target) + if result != tt.expected { + t.Errorf("Redirects{AllowedPrefixes: %v}.isValid(%q) = %v, want %v", tt.allowedPrefixes, tt.target, result, tt.expected) + } + }) + } +} + diff --git a/password_test.go b/password_test.go new file mode 100644 index 0000000..cfe90e0 --- /dev/null +++ b/password_test.go @@ -0,0 +1,230 @@ +package kate + +import ( + "strings" + "testing" + "testing/quick" +) + +func TestHashPasswordComparePasswordRoundtrip(t *testing.T) { + property := func(password string) bool { + hash, err := HashPassword(password, nil) + if err != nil { + return false + } + + match, err := ComparePassword(password, hash) + if err != nil { + return false + } + + return match + } + + if err := quick.Check(property, nil); err != nil { + t.Errorf("roundtrip property failed: %v", err) + } +} + +func TestComparePasswordWrongPassword(t *testing.T) { + property := func(password, wrongPassword string) bool { + if password == wrongPassword { + return true + } + + hash, err := HashPassword(password, nil) + if err != nil { + return false + } + + match, err := ComparePassword(wrongPassword, hash) + if err != nil { + return false + } + + return !match + } + + if err := quick.Check(property, nil); err != nil { + t.Errorf("wrong password property failed: %v", err) + } +} + +func TestHashPasswordUniqueness(t *testing.T) { + property := func(password string) bool { + hash1, err1 := HashPassword(password, nil) + hash2, err2 := HashPassword(password, nil) + + if err1 != nil || err2 != nil { + return false + } + + return hash1 != hash2 + } + + if err := quick.Check(property, nil); err != nil { + t.Errorf("uniqueness property failed: %v", err) + } +} + +func TestHashPasswordFormat(t *testing.T) { + property := func(password string) bool { + hash, err := HashPassword(password, nil) + if err != nil { + return false + } + + parts := strings.Split(hash, "$") + if len(parts) != 6 { + return false + } + + return parts[0] == "" && parts[1] == "argon2id" + } + + if err := quick.Check(property, nil); err != nil { + t.Errorf("format property failed: %v", err) + } +} + +func TestComparePasswordMalformedHash(t *testing.T) { + tests := []string{ + "", + "invalid", + "$argon2id", + "$argon2id$v=19$m=65536,t=1,p=4", + "$argon2id$v=19$m=65536,t=1,p=4$salt", + "$wrong$v=19$m=65536,t=1,p=4$salt$hash", + "$argon2id$v=999$m=65536,t=1,p=4$salt$hash", + "$argon2id$v=19$invalid$salt$hash", + "$argon2id$v=19$m=65536,t=1,p=4$!!!$hash", + "$argon2id$v=19$m=65536,t=1,p=4$salt$!!!", + } + + for _, malformedHash := range tests { + t.Run("malformed_"+malformedHash, func(t *testing.T) { + match, err := ComparePassword("password", malformedHash) + if err == nil { + t.Error("expected error for malformed hash") + } + if match { + t.Error("should not match with malformed hash") + } + }) + } +} + +func TestEmptyPassword(t *testing.T) { + hash, err := HashPassword("", nil) + if err != nil { + t.Fatalf("HashPassword failed for empty password: %v", err) + } + + match, err := ComparePassword("", hash) + if err != nil { + t.Fatalf("ComparePassword failed: %v", err) + } + if !match { + t.Error("empty password should match its hash") + } + + match, err = ComparePassword("nonempty", hash) + if err != nil { + t.Fatalf("ComparePassword failed: %v", err) + } + if match { + t.Error("non-empty password should not match empty password hash") + } +} + +func TestHashPasswordWithConfig(t *testing.T) { + tests := []struct { + name string + config *Argon2Config + want string // Expected parameters in PHC format + }{ + { + name: "nil config uses defaults", + config: nil, + want: "$argon2id$v=19$m=65536,t=1,p=4$", + }, + { + name: "custom time parameter", + config: &Argon2Config{ + Time: 3, + }, + want: "$argon2id$v=19$m=65536,t=3,p=4$", + }, + { + name: "custom memory parameter", + config: &Argon2Config{ + Memory: 32 * 1024, // 32MB + }, + want: "$argon2id$v=19$m=32768,t=1,p=4$", + }, + { + name: "custom threads parameter", + config: &Argon2Config{ + Threads: 2, + }, + want: "$argon2id$v=19$m=65536,t=1,p=2$", + }, + { + name: "all custom parameters", + config: &Argon2Config{ + Time: 2, + Memory: 128 * 1024, + Threads: 8, + KeyLen: 64, + }, + want: "$argon2id$v=19$m=131072,t=2,p=8$", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hash, err := HashPassword("test", tt.config) + if err != nil { + t.Fatalf("HashPassword() error = %v", err) + } + + // Check that the hash starts with expected PHC format parameters + if !strings.HasPrefix(hash, tt.want) { + t.Errorf("HashPassword() = %v, want prefix %v", hash, tt.want) + } + + // Verify password still works + match, err := ComparePassword("test", hash) + if err != nil { + t.Fatalf("ComparePassword() error = %v", err) + } + if !match { + t.Error("Password should match its hash") + } + + // Verify wrong password doesn't match + match, err = ComparePassword("wrong", hash) + if err != nil { + t.Fatalf("ComparePassword() error = %v", err) + } + if match { + t.Error("Wrong password should not match hash") + } + }) + } +} + +func TestDefaultArgon2Config(t *testing.T) { + if defaultArgon2Config.Time != 1 { + t.Errorf("Expected defaultArgon2Config.Time=1, got %d", defaultArgon2Config.Time) + } + if defaultArgon2Config.Memory != 64*1024 { + t.Errorf("Expected defaultArgon2Config.Memory=65536, got %d", defaultArgon2Config.Memory) + } + if defaultArgon2Config.Threads != 4 { + t.Errorf("Expected defaultArgon2Config.Threads=4, got %d", defaultArgon2Config.Threads) + } + if defaultArgon2Config.KeyLen != 32 { + t.Errorf("Expected defaultArgon2Config.KeyLen=32, got %d", defaultArgon2Config.KeyLen) + } +} -- cgit v1.2.3