summaryrefslogtreecommitdiff
path: root/oauth_login_test.go
diff options
context:
space:
mode:
authorT <t@tjp.lol>2025-06-26 11:42:17 -0600
committerT <t@tjp.lol>2025-07-01 17:50:49 -0600
commit639ad6a02cbb4b713434671ec09f309aa5410921 (patch)
tree7dde9cce8136636d11f2f7c961072984cfc705e7 /oauth_login_test.go
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'oauth_login_test.go')
-rw-r--r--oauth_login_test.go369
1 files changed, 369 insertions, 0 deletions
diff --git a/oauth_login_test.go b/oauth_login_test.go
new file mode 100644
index 0000000..caec89b
--- /dev/null
+++ b/oauth_login_test.go
@@ -0,0 +1,369 @@
+package kate
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// Mock implementation of OAuthDataStore for testing
+type mockOAuthDataStore[T any] struct {
+ userData map[string]T
+ states map[string]OAuthState
+ mu sync.RWMutex
+ getOrCreateUser func(email string) (T, error)
+ stateIDCounter int
+}
+
+func newMockOAuthDataStore[T any](getOrCreateUser func(email string) (T, error)) *mockOAuthDataStore[T] {
+ return &mockOAuthDataStore[T]{
+ userData: make(map[string]T),
+ states: make(map[string]OAuthState),
+ getOrCreateUser: getOrCreateUser,
+ }
+}
+
+func (m *mockOAuthDataStore[T]) GetOrCreateUser(email string) (T, error) {
+ return m.getOrCreateUser(email)
+}
+
+func (m *mockOAuthDataStore[T]) StoreState(state OAuthState) (string, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.stateIDCounter++
+ stateID := fmt.Sprintf("state_%d", m.stateIDCounter)
+ m.states[stateID] = state
+ return stateID, nil
+}
+
+func (m *mockOAuthDataStore[T]) GetAndClearState(id string) (*OAuthState, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ state, exists := m.states[id]
+ if !exists {
+ return nil, nil
+ }
+
+ delete(m.states, id)
+ return &state, nil
+}
+
+func TestOAuthLoginHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+ config.Redirects.Default = "/dashboard"
+ config.Redirects.AllowedPrefixes = []string{"/app/", "/admin/"}
+
+ handler := auth.OAuthLoginHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ queryParams url.Values
+ formData url.Values
+ expectedStatus int
+ checkRedirect bool
+ expectedContains []string
+ }{
+ {
+ name: "GET request initiates OAuth flow",
+ method: "GET",
+ queryParams: url.Values{},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback",
+ "response_type=code",
+ "state=",
+ "scope=email",
+ },
+ },
+ {
+ name: "GET request with redirect parameter",
+ method: "GET",
+ queryParams: url.Values{"redirect": {"/app/settings"}},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ },
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ formData: url.Values{"redirect": {"/app/profile"}},
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkRedirect: false,
+ },
+ {
+ name: "invalid redirect falls back to default",
+ method: "GET",
+ queryParams: url.Values{"redirect": {"/evil/"}},
+ expectedStatus: http.StatusSeeOther,
+ checkRedirect: true,
+ expectedContains: []string{
+ "accounts.google.com",
+ "client_id=test-client-id",
+ },
+ },
+ {
+ name: "PUT method not allowed",
+ method: "PUT",
+ expectedStatus: http.StatusMethodNotAllowed,
+ checkRedirect: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var req *http.Request
+ reqURL := "/oauth/login"
+
+ if len(tt.queryParams) > 0 {
+ reqURL += "?" + tt.queryParams.Encode()
+ }
+
+ if tt.method == "POST" && tt.formData != nil {
+ req = httptest.NewRequest(tt.method, reqURL, strings.NewReader(tt.formData.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ } else {
+ req = httptest.NewRequest(tt.method, reqURL, nil)
+ }
+
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+
+ if tt.checkRedirect {
+ location := rr.Header().Get("Location")
+ if location == "" {
+ t.Error("expected redirect location to be set")
+ }
+
+ for _, expected := range tt.expectedContains {
+ if !strings.Contains(location, expected) {
+ t.Errorf("expected redirect URL to contain %q, got %q", expected, location)
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestOAuthCallbackHandler(t *testing.T) {
+ auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
+ SerDes: testUserSerDes{},
+ CookieName: "test_session",
+ })
+
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+ config.Redirects.Default = "/dashboard"
+
+ handler := auth.OAuthCallbackHandler(config)
+
+ tests := []struct {
+ name string
+ method string
+ queryParams url.Values
+ expectedStatus int
+ }{
+ {
+ name: "GET callback with invalid state",
+ method: "GET",
+ queryParams: url.Values{"code": {"test-code"}, "state": {"test-state"}},
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "POST method not allowed",
+ method: "POST",
+ expectedStatus: http.StatusMethodNotAllowed,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ reqURL := "/oauth/callback"
+ if len(tt.queryParams) > 0 {
+ reqURL += "?" + tt.queryParams.Encode()
+ }
+
+ req := httptest.NewRequest(tt.method, reqURL, nil)
+ rr := httptest.NewRecorder()
+ handler.ServeHTTP(rr, req)
+
+ if rr.Code != tt.expectedStatus {
+ t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
+ }
+ })
+ }
+}
+
+func TestGoogleOAuthConfig(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GoogleOAuthConfig(
+ "test-client-id",
+ "test-client-secret",
+ "http://localhost:8080/auth/callback",
+ mockStore,
+ )
+
+ if config.OAuth2Config.ClientID != "test-client-id" {
+ t.Errorf("expected ClientID 'test-client-id', got %q", config.OAuth2Config.ClientID)
+ }
+ if config.OAuth2Config.ClientSecret != "test-client-secret" {
+ t.Errorf("expected ClientSecret 'test-client-secret', got %q", config.OAuth2Config.ClientSecret)
+ }
+ if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/callback" {
+ t.Errorf("expected RedirectURL 'http://localhost:8080/auth/callback', got %q", config.OAuth2Config.RedirectURL)
+ }
+ if config.OAuth2Config.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/auth" {
+ t.Errorf("expected AuthURL 'https://accounts.google.com/o/oauth2/auth', got %q", config.OAuth2Config.Endpoint.AuthURL)
+ }
+ if config.OAuth2Config.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" {
+ t.Errorf("expected TokenURL 'https://oauth2.googleapis.com/token', got %q", config.OAuth2Config.Endpoint.TokenURL)
+ }
+ if config.UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
+ t.Errorf("expected UserInfoURL 'https://www.googleapis.com/oauth2/v2/userinfo', got %q", config.UserInfoURL)
+ }
+ expectedScopes := []string{"email"}
+ if len(config.OAuth2Config.Scopes) != len(expectedScopes) {
+ t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes))
+ } else {
+ for i, scope := range expectedScopes {
+ if config.OAuth2Config.Scopes[i] != scope {
+ t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i])
+ }
+ }
+ }
+}
+
+func TestGitHubOAuthConfig(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ config := GitHubOAuthConfig(
+ "test-github-client-id",
+ "test-github-client-secret",
+ "http://localhost:8080/auth/github/callback",
+ mockStore,
+ )
+
+ if config.OAuth2Config.ClientID != "test-github-client-id" {
+ t.Errorf("expected ClientID 'test-github-client-id', got %q", config.OAuth2Config.ClientID)
+ }
+ if config.OAuth2Config.ClientSecret != "test-github-client-secret" {
+ t.Errorf("expected ClientSecret 'test-github-client-secret', got %q", config.OAuth2Config.ClientSecret)
+ }
+ if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/github/callback" {
+ t.Errorf("expected RedirectURL 'http://localhost:8080/auth/github/callback', got %q", config.OAuth2Config.RedirectURL)
+ }
+ if config.OAuth2Config.Endpoint.AuthURL != "https://github.com/login/oauth/authorize" {
+ t.Errorf("expected AuthURL 'https://github.com/login/oauth/authorize', got %q", config.OAuth2Config.Endpoint.AuthURL)
+ }
+ if config.OAuth2Config.Endpoint.TokenURL != "https://github.com/login/oauth/access_token" {
+ t.Errorf("expected TokenURL 'https://github.com/login/oauth/access_token', got %q", config.OAuth2Config.Endpoint.TokenURL)
+ }
+ if config.UserInfoURL != "https://api.github.com/user" {
+ t.Errorf("expected UserInfoURL 'https://api.github.com/user', got %q", config.UserInfoURL)
+ }
+ expectedScopes := []string{"user:email"}
+ if len(config.OAuth2Config.Scopes) != len(expectedScopes) {
+ t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes))
+ } else {
+ for i, scope := range expectedScopes {
+ if config.OAuth2Config.Scopes[i] != scope {
+ t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i])
+ }
+ }
+ }
+}
+
+func TestOAuthStateStore(t *testing.T) {
+ mockStore := newMockOAuthDataStore(func(email string) (testUser, error) {
+ return testUser{Username: email, ID: 1}, nil
+ })
+
+ now := time.Now().Truncate(time.Second)
+ state := OAuthState{
+ Redirect: "/dashboard",
+ ExpiresAt: now,
+ }
+
+ // Test storing state
+ stateID, err := mockStore.StoreState(state)
+ if err != nil {
+ t.Errorf("StoreState() error = %v", err)
+ return
+ }
+
+ if stateID == "" {
+ t.Error("StoreState() should return non-empty state ID")
+ }
+
+ // Test retrieving state
+ retrievedState, err := mockStore.GetAndClearState(stateID)
+ if err != nil {
+ t.Errorf("GetAndClearState() error = %v", err)
+ return
+ }
+
+ if retrievedState == nil {
+ t.Error("GetAndClearState() should return state data")
+ return
+ }
+
+ if retrievedState.Redirect != state.Redirect {
+ t.Errorf("retrieved Redirect = %q, want %q", retrievedState.Redirect, state.Redirect)
+ }
+ if !retrievedState.ExpiresAt.Equal(state.ExpiresAt) {
+ t.Errorf("retrieved ExpiresAt = %v, want %v", retrievedState.ExpiresAt, state.ExpiresAt)
+ }
+
+ // Test that state is cleared (one-time use)
+ clearedState, err := mockStore.GetAndClearState(stateID)
+ if err != nil {
+ t.Errorf("GetAndClearState() on cleared state error = %v", err)
+ }
+ if clearedState != nil {
+ t.Error("GetAndClearState() should return nil for already used state")
+ }
+}