package kate

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"sync"
	"testing"
	"time"
)

// mockOAuthDataStore is an in-memory OAuthDataStore used by the tests in this
// file. States are kept in a map under sequential IDs ("state_1", "state_2",
// ...) guarded by mu; user lookup is delegated to the getOrCreateUser callback
// supplied at construction. userData is allocated by the constructor but is
// not otherwise read or written by the code in this file.
type mockOAuthDataStore[T any] struct {
	userData        map[string]T
	states          map[string]OAuthState
	mu              sync.RWMutex
	getOrCreateUser func(email string) (T, error)
	stateIDCounter  int
}

// newMockOAuthDataStore returns an empty mock store whose GetOrCreateUser
// delegates to getOrCreateUser.
func newMockOAuthDataStore[T any](getOrCreateUser func(email string) (T, error)) *mockOAuthDataStore[T] {
	return &mockOAuthDataStore[T]{
		userData:        make(map[string]T),
		states:          make(map[string]OAuthState),
		getOrCreateUser: getOrCreateUser,
	}
}

// newTestUserOAuthStore returns a mock store whose GetOrCreateUser always
// succeeds with a testUser whose Username is the given email and whose ID is 1.
func newTestUserOAuthStore() *mockOAuthDataStore[testUser] {
	return newMockOAuthDataStore(func(email string) (testUser, error) {
		return testUser{Username: email, ID: 1}, nil
	})
}

// GetOrCreateUser delegates to the callback supplied at construction time.
func (m *mockOAuthDataStore[T]) GetOrCreateUser(email string) (T, error) {
	return m.getOrCreateUser(email)
}

// StoreState saves state under a new sequential ID ("state_1", "state_2", ...)
// and returns that ID. It never returns an error.
func (m *mockOAuthDataStore[T]) StoreState(state OAuthState) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.stateIDCounter++
	stateID := fmt.Sprintf("state_%d", m.stateIDCounter)
	m.states[stateID] = state
	return stateID, nil
}

// GetAndClearState returns the state stored under id and deletes it, so each
// state can be consumed at most once. Unknown or already-consumed IDs yield
// (nil, nil).
func (m *mockOAuthDataStore[T]) GetAndClearState(id string) (*OAuthState, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	state, exists := m.states[id]
	if !exists {
		return nil, nil
	}
	delete(m.states, id)
	return &state, nil
}

// assertOAuthScopes reports a test error unless got equals want element by
// element.
func assertOAuthScopes(t *testing.T, got, want []string) {
	t.Helper()
	if len(got) != len(want) {
		t.Errorf("expected %d scopes, got %d", len(want), len(got))
		return
	}
	for i, scope := range want {
		if got[i] != scope {
			t.Errorf("expected scope[%d] %q, got %q", i, scope, got[i])
		}
	}
}

// TestOAuthLoginHandler checks that the login handler answers GET requests
// with a 303 redirect to Google's authorization endpoint carrying the expected
// query parameters, and rejects POST and PUT with 405.
func TestOAuthLoginHandler(t *testing.T) {
	auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
		SerDes:     testUserSerDes{},
		CookieName: "test_session",
	})
	mockStore := newTestUserOAuthStore()
	config := GoogleOAuthConfig(
		"test-client-id",
		"test-client-secret",
		"http://localhost:8080/auth/callback",
		mockStore,
	)
	config.Redirects.Default = "/dashboard"
	config.Redirects.AllowedPrefixes = []string{"/app/", "/admin/"}

	handler := auth.OAuthLoginHandler(config)

	tests := []struct {
		name             string
		method           string
		queryParams      url.Values
		formData         url.Values
		expectedStatus   int
		checkRedirect    bool
		expectedContains []string
	}{
		{
			name:           "GET request initiates OAuth flow",
			method:         http.MethodGet,
			queryParams:    url.Values{},
			expectedStatus: http.StatusSeeOther,
			checkRedirect:  true,
			expectedContains: []string{
				"accounts.google.com",
				"client_id=test-client-id",
				"redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fauth%2Fcallback",
				"response_type=code",
				"state=",
				"scope=email",
			},
		},
		{
			name:           "GET request with redirect parameter",
			method:         http.MethodGet,
			queryParams:    url.Values{"redirect": {"/app/settings"}},
			expectedStatus: http.StatusSeeOther,
			checkRedirect:  true,
			expectedContains: []string{
				"accounts.google.com",
				"client_id=test-client-id",
			},
		},
		{
			name:           "POST method not allowed",
			method:         http.MethodPost,
			formData:       url.Values{"redirect": {"/app/profile"}},
			expectedStatus: http.StatusMethodNotAllowed,
			checkRedirect:  false,
		},
		{
			// NOTE(review): this case only verifies that the OAuth flow still
			// starts; it does not assert that the redirect actually fell back
			// to config.Redirects.Default. TODO: inspect mockStore's stored
			// state once the handler's use of StoreState is confirmed.
			name:           "invalid redirect falls back to default",
			method:         http.MethodGet,
			queryParams:    url.Values{"redirect": {"/evil/"}},
			expectedStatus: http.StatusSeeOther,
			checkRedirect:  true,
			expectedContains: []string{
				"accounts.google.com",
				"client_id=test-client-id",
			},
		},
		{
			name:           "PUT method not allowed",
			method:         http.MethodPut,
			expectedStatus: http.StatusMethodNotAllowed,
			checkRedirect:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			reqURL := "/oauth/login"
			if len(tt.queryParams) > 0 {
				reqURL += "?" + tt.queryParams.Encode()
			}

			var req *http.Request
			if tt.method == http.MethodPost && tt.formData != nil {
				req = httptest.NewRequest(tt.method, reqURL, strings.NewReader(tt.formData.Encode()))
				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
			} else {
				req = httptest.NewRequest(tt.method, reqURL, nil)
			}

			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
			}

			if !tt.checkRedirect {
				return
			}
			location := rr.Header().Get("Location")
			if location == "" {
				// Without a Location header every substring check below would
				// fail with the same uninformative message, so stop here.
				t.Fatal("expected redirect location to be set")
			}
			for _, expected := range tt.expectedContains {
				if !strings.Contains(location, expected) {
					t.Errorf("expected redirect URL to contain %q, got %q", expected, location)
				}
			}
		})
	}
}

// TestOAuthCallbackHandler checks that the callback handler rejects a GET
// carrying a state ID unknown to the store with 400, and rejects POST with 405.
func TestOAuthCallbackHandler(t *testing.T) {
	auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{
		SerDes:     testUserSerDes{},
		CookieName: "test_session",
	})
	mockStore := newTestUserOAuthStore()
	config := GoogleOAuthConfig(
		"test-client-id",
		"test-client-secret",
		"http://localhost:8080/auth/callback",
		mockStore,
	)
	config.Redirects.Default = "/dashboard"

	handler := auth.OAuthCallbackHandler(config)

	tests := []struct {
		name           string
		method         string
		queryParams    url.Values
		expectedStatus int
	}{
		{
			name:           "GET callback with invalid state",
			method:         http.MethodGet,
			queryParams:    url.Values{"code": {"test-code"}, "state": {"test-state"}},
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "POST method not allowed",
			method:         http.MethodPost,
			expectedStatus: http.StatusMethodNotAllowed,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			reqURL := "/oauth/callback"
			if len(tt.queryParams) > 0 {
				reqURL += "?" + tt.queryParams.Encode()
			}

			req := httptest.NewRequest(tt.method, reqURL, nil)
			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)

			if rr.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code)
			}
		})
	}
}

// TestGoogleOAuthConfig checks that GoogleOAuthConfig wires the credentials,
// redirect URL, Google endpoints, user-info URL and the "email" scope.
func TestGoogleOAuthConfig(t *testing.T) {
	config := GoogleOAuthConfig(
		"test-client-id",
		"test-client-secret",
		"http://localhost:8080/auth/callback",
		newTestUserOAuthStore(),
	)

	if config.OAuth2Config.ClientID != "test-client-id" {
		t.Errorf("expected ClientID 'test-client-id', got %q", config.OAuth2Config.ClientID)
	}
	if config.OAuth2Config.ClientSecret != "test-client-secret" {
		t.Errorf("expected ClientSecret 'test-client-secret', got %q", config.OAuth2Config.ClientSecret)
	}
	if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/callback" {
		t.Errorf("expected RedirectURL 'http://localhost:8080/auth/callback', got %q", config.OAuth2Config.RedirectURL)
	}
	if config.OAuth2Config.Endpoint.AuthURL != "https://accounts.google.com/o/oauth2/auth" {
		t.Errorf("expected AuthURL 'https://accounts.google.com/o/oauth2/auth', got %q", config.OAuth2Config.Endpoint.AuthURL)
	}
	if config.OAuth2Config.Endpoint.TokenURL != "https://oauth2.googleapis.com/token" {
		t.Errorf("expected TokenURL 'https://oauth2.googleapis.com/token', got %q", config.OAuth2Config.Endpoint.TokenURL)
	}
	if config.UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
		t.Errorf("expected UserInfoURL 'https://www.googleapis.com/oauth2/v2/userinfo', got %q", config.UserInfoURL)
	}

	assertOAuthScopes(t, config.OAuth2Config.Scopes, []string{"email"})
}

// TestGitHubOAuthConfig checks that GitHubOAuthConfig wires the credentials,
// redirect URL, GitHub endpoints, user-info URL and the "user:email" scope.
func TestGitHubOAuthConfig(t *testing.T) {
	config := GitHubOAuthConfig(
		"test-github-client-id",
		"test-github-client-secret",
		"http://localhost:8080/auth/github/callback",
		newTestUserOAuthStore(),
	)

	if config.OAuth2Config.ClientID != "test-github-client-id" {
		t.Errorf("expected ClientID 'test-github-client-id', got %q", config.OAuth2Config.ClientID)
	}
	if config.OAuth2Config.ClientSecret != "test-github-client-secret" {
		t.Errorf("expected ClientSecret 'test-github-client-secret', got %q", config.OAuth2Config.ClientSecret)
	}
	if config.OAuth2Config.RedirectURL != "http://localhost:8080/auth/github/callback" {
		t.Errorf("expected RedirectURL 'http://localhost:8080/auth/github/callback', got %q", config.OAuth2Config.RedirectURL)
	}
	if config.OAuth2Config.Endpoint.AuthURL != "https://github.com/login/oauth/authorize" {
		t.Errorf("expected AuthURL 'https://github.com/login/oauth/authorize', got %q", config.OAuth2Config.Endpoint.AuthURL)
	}
	if config.OAuth2Config.Endpoint.TokenURL != "https://github.com/login/oauth/access_token" {
		t.Errorf("expected TokenURL 'https://github.com/login/oauth/access_token', got %q", config.OAuth2Config.Endpoint.TokenURL)
	}
	if config.UserInfoURL != "https://api.github.com/user" {
		t.Errorf("expected UserInfoURL 'https://api.github.com/user', got %q", config.UserInfoURL)
	}

	assertOAuthScopes(t, config.OAuth2Config.Scopes, []string{"user:email"})
}

// TestOAuthStateStore exercises the mock store's state contract that the
// handler tests rely on: StoreState returns a non-empty ID, GetAndClearState
// returns the stored value once, and a second lookup of the same ID yields nil.
func TestOAuthStateStore(t *testing.T) {
	mockStore := newTestUserOAuthStore()

	now := time.Now().Truncate(time.Second)
	state := OAuthState{
		Redirect:  "/dashboard",
		ExpiresAt: now,
	}

	stateID, err := mockStore.StoreState(state)
	if err != nil {
		t.Fatalf("StoreState() error = %v", err)
	}
	if stateID == "" {
		t.Error("StoreState() should return non-empty state ID")
	}

	retrievedState, err := mockStore.GetAndClearState(stateID)
	if err != nil {
		t.Fatalf("GetAndClearState() error = %v", err)
	}
	if retrievedState == nil {
		t.Fatal("GetAndClearState() should return state data")
	}
	if retrievedState.Redirect != state.Redirect {
		t.Errorf("retrieved Redirect = %q, want %q", retrievedState.Redirect, state.Redirect)
	}
	if !retrievedState.ExpiresAt.Equal(state.ExpiresAt) {
		t.Errorf("retrieved ExpiresAt = %v, want %v", retrievedState.ExpiresAt, state.ExpiresAt)
	}

	// A state is single-use: the second lookup must find nothing.
	clearedState, err := mockStore.GetAndClearState(stateID)
	if err != nil {
		t.Errorf("GetAndClearState() on cleared state error = %v", err)
	}
	if clearedState != nil {
		t.Error("GetAndClearState() should return nil for already used state")
	}
}