diff options
author | T <t@tjp.lol> | 2025-06-26 11:42:17 -0600 |
---|---|---|
committer | T <t@tjp.lol> | 2025-07-01 17:50:49 -0600 |
commit | 639ad6a02cbb4b713434671ec09f309aa5410921 (patch) | |
tree | 7dde9cce8136636d11f2f7c961072984cfc705e7 /oauth_login_test.go |
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'oauth_login_test.go')
-rw-r--r-- | oauth_login_test.go | 369 |
1 files changed, 369 insertions, 0 deletions
diff --git a/oauth_login_test.go b/oauth_login_test.go new file mode 100644 index 0000000..caec89b --- /dev/null +++ b/oauth_login_test.go @@ -0,0 +1,369 @@ +package kate + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" +) + +// Mock implementation of OAuthDataStore for testing +type mockOAuthDataStore[T any] struct { + userData map[string]T + states map[string]OAuthState + mu sync.RWMutex + getOrCreateUser func(email string) (T, error) + stateIDCounter int +} + +func newMockOAuthDataStore[T any](getOrCreateUser func(email string) (T, error)) *mockOAuthDataStore[T] { + return &mockOAuthDataStore[T]{ + userData: make(map[string]T), + states: make(map[string]OAuthState), + getOrCreateUser: getOrCreateUser, + } +} + +func (m *mockOAuthDataStore[T]) GetOrCreateUser(email string) (T, error) { + return m.getOrCreateUser(email) +} + +func (m *mockOAuthDataStore[T]) StoreState(state OAuthState) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.stateIDCounter++ + stateID := fmt.Sprintf("state_%d", m.stateIDCounter) + m.states[stateID] = state + return stateID, nil +} + +func (m *mockOAuthDataStore[T]) GetAndClearState(id string) (*OAuthState, error) { + m.mu.Lock() + defer m.mu.Unlock() + + state, exists := m.states[id] + if !exists { + return nil, nil + } + + delete(m.states, id) + return &state, nil +} + +func TestOAuthLoginHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + config.Redirects.Default = "/dashboard" + config.Redirects.AllowedPrefixes = []string{"/app/", "/admin/"} + + handler := auth.OAuthLoginHandler(config) + + tests := []struct { + name string + method string + queryParams url.Values + formData url.Values + expectedStatus int + checkRedirect bool + expectedContains []string + }{ + { + name: "GET request initiates OAuth flow", + method: "GET", + queryParams: url.Values{}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback", + "response_type=code", + "state=", + "scope=email", + }, + }, + { + name: "GET request with redirect parameter", + method: "GET", + queryParams: url.Values{"redirect": {"/app/settings"}}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + }, + }, + { + name: "POST method not allowed", + method: "POST", + formData: url.Values{"redirect": {"/app/profile"}}, + expectedStatus: http.StatusMethodNotAllowed, + checkRedirect: false, + }, + { + name: "invalid redirect falls back to default", + method: "GET", + queryParams: url.Values{"redirect": {"/evil/"}}, + expectedStatus: http.StatusSeeOther, + checkRedirect: true, + expectedContains: []string{ + "accounts.google.com", + "client_id=test-client-id", + }, + }, + { + name: "PUT method not allowed", + method: "PUT", + expectedStatus: http.StatusMethodNotAllowed, + checkRedirect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + reqURL := "/oauth/login" + + if len(tt.queryParams) > 0 { + reqURL += "?" + tt.queryParams.Encode() + } + + if tt.method == "POST" && tt.formData != nil { + req = httptest.NewRequest(tt.method, reqURL, strings.NewReader(tt.formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } else { + req = httptest.NewRequest(tt.method, reqURL, nil) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.checkRedirect { + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected redirect location to be set") + } + + for _, expected := range tt.expectedContains { + if !strings.Contains(location, expected) { + t.Errorf("expected redirect URL to contain %q, got %q", expected, location) + } + } + } + }) + } +} + +func TestOAuthCallbackHandler(t *testing.T) { + auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ + SerDes: testUserSerDes{}, + CookieName: "test_session", + }) + + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + config.Redirects.Default = "/dashboard" + + handler := auth.OAuthCallbackHandler(config) + + tests := []struct { + name string + method string + queryParams url.Values + expectedStatus int + }{ + { + name: "GET callback with invalid state", + method: "GET", + queryParams: url.Values{"code": {"test-code"}, "state": {"test-state"}}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "POST method not allowed", + method: "POST", + expectedStatus: http.StatusMethodNotAllowed, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reqURL := "/oauth/callback" + if len(tt.queryParams) > 0 { + reqURL += "?" + tt.queryParams.Encode() + } + + req := httptest.NewRequest(tt.method, reqURL, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + }) + } +} + +func TestGoogleOAuthConfig(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GoogleOAuthConfig( + "test-client-id", + "test-client-secret", + "http://localhost:8080/auth/callback", + mockStore, + ) + + if config.OAuth2Config.ClientID != "test-client-id" { + t.Errorf("expected ClientID 'test-client-id', got %q", config.OAuth2Config.ClientID) + } + if config.OAuth2Config.ClientSecret != "test-client-secret" { + t.Errorf("expected ClientSecret 'test-client-secret', got %q", config.OAuth2Config.ClientSecret) + } + if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/callback" { + t.Errorf("expected RedirectURL 'http://localhost:8080/auth/callback', got %q", config.OAuth2Config.RedirectURL) + } + if config.OAuth2Config.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/auth" { + t.Errorf("expected AuthURL 'https://accounts.google.com/o/oauth2/auth', got %q", config.OAuth2Config.Endpoint.AuthURL) + } + if config.OAuth2Config.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("expected TokenURL 'https://oauth2.googleapis.com/token', got %q", config.OAuth2Config.Endpoint.TokenURL) + } + if config.UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("expected UserInfoURL 'https://www.googleapis.com/oauth2/v2/userinfo', got %q", config.UserInfoURL) + } + expectedScopes := []string{"email"} + if len(config.OAuth2Config.Scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes)) + } else { + for i, scope := range expectedScopes { + if config.OAuth2Config.Scopes[i] != scope { + t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i]) + } + } + } +} + +func TestGitHubOAuthConfig(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + config := GitHubOAuthConfig( + "test-github-client-id", + "test-github-client-secret", + "http://localhost:8080/auth/github/callback", + mockStore, + ) + + if config.OAuth2Config.ClientID != "test-github-client-id" { + t.Errorf("expected ClientID 'test-github-client-id', got %q", config.OAuth2Config.ClientID) + } + if config.OAuth2Config.ClientSecret != "test-github-client-secret" { + t.Errorf("expected ClientSecret 'test-github-client-secret', got %q", config.OAuth2Config.ClientSecret) + } + if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/github/callback" { + t.Errorf("expected RedirectURL 'http://localhost:8080/auth/github/callback', got %q", config.OAuth2Config.RedirectURL) + } + if config.OAuth2Config.Endpoint.AuthURL != "https://github.com/login/oauth/authorize" { + t.Errorf("expected AuthURL 'https://github.com/login/oauth/authorize', got %q", config.OAuth2Config.Endpoint.AuthURL) + } + if config.OAuth2Config.Endpoint.TokenURL != "https://github.com/login/oauth/access_token" { + t.Errorf("expected TokenURL 'https://github.com/login/oauth/access_token', got %q", config.OAuth2Config.Endpoint.TokenURL) + } + if config.UserInfoURL != "https://api.github.com/user" { + t.Errorf("expected UserInfoURL 'https://api.github.com/user', got %q", config.UserInfoURL) + } + expectedScopes := []string{"user:email"} + if len(config.OAuth2Config.Scopes) != len(expectedScopes) { + t.Errorf("expected %d scopes, got %d", len(expectedScopes), len(config.OAuth2Config.Scopes)) + } else { + for i, scope := range expectedScopes { + if config.OAuth2Config.Scopes[i] != scope { + t.Errorf("expected scope[%d] %q, got %q", i, scope, config.OAuth2Config.Scopes[i]) + } + } + } +} + +func TestOAuthStateStore(t *testing.T) { + mockStore := newMockOAuthDataStore(func(email string) (testUser, error) { + return testUser{Username: email, ID: 1}, nil + }) + + now := time.Now().Truncate(time.Second) + state := OAuthState{ + Redirect: "/dashboard", + ExpiresAt: now, + } + + // Test storing state + stateID, err := mockStore.StoreState(state) + if err != nil { + t.Errorf("StoreState() error = %v", err) + return + } + + if stateID == "" { + t.Error("StoreState() should return non-empty state ID") + } + + // Test retrieving state + retrievedState, err := mockStore.GetAndClearState(stateID) + if err != nil { + t.Errorf("GetAndClearState() error = %v", err) + return + } + + if retrievedState == nil { + t.Error("GetAndClearState() should return state data") + return + } + + if retrievedState.Redirect != state.Redirect { + t.Errorf("retrieved Redirect = %q, want %q", retrievedState.Redirect, state.Redirect) + } + if !retrievedState.ExpiresAt.Equal(state.ExpiresAt) { + t.Errorf("retrieved ExpiresAt = %v, want %v", retrievedState.ExpiresAt, state.ExpiresAt) + } + + // Test that state is cleared (one-time use) + clearedState, err := mockStore.GetAndClearState(stateID) + if err != nil { + t.Errorf("GetAndClearState() on cleared state error = %v", err) + } + if clearedState != nil { + t.Error("GetAndClearState() should return nil for already used state") + } +} |