diff options
author | T <t@tjp.lol> | 2025-06-26 11:42:17 -0600 |
---|---|---|
committer | T <t@tjp.lol> | 2025-07-01 17:50:49 -0600 |
commit | 639ad6a02cbb4b713434671ec09f309aa5410921 (patch) | |
tree | 7dde9cce8136636d11f2f7c961072984cfc705e7 /oauth_login.go |
Create authentic_kate: user authentication for go HTTP applications
Diffstat (limited to 'oauth_login.go')
-rw-r--r-- | oauth_login.go | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/oauth_login.go b/oauth_login.go new file mode 100644 index 0000000..6c16275 --- /dev/null +++ b/oauth_login.go @@ -0,0 +1,257 @@ +package kate + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/google" +) + +// OAuthConfig configures the OAuth login handler behavior. +type OAuthConfig[T any] struct { + // OAuth2Config is the standard OAuth2 configuration + OAuth2Config *oauth2.Config + + // UserInfoURL is the provider endpoint to get user information + UserInfoURL string + + // DataStore provides user data lookup and temporary state storage + DataStore OAuthDataStore[T] + + // Redirects configures post-authentication redirect behavior + Redirects Redirects + + // StateExpiry is how long the OAuth state is valid + StateExpiry time.Duration + + // LogError is an optional function to log errors + LogError func(error) +} + +// OAuthDataStore provides user data lookup and temporary storage for OAuth state data. +type OAuthDataStore[T any] interface { + // GetOrCreateUser retrieves or creates user data by email address. + // + // If the user exists, returns their data. If not, creates a new user + // and returns the new user data. This handles both OAuth login and + // registration in a single operation. + GetOrCreateUser(email string) (T, error) + + // StoreState stores OAuth state data temporarily and returns a unique state ID. + // + // The returned state ID will be used as the OAuth state parameter. + // Implementations should generate a unique, unguessable ID for security. + StoreState(state OAuthState) (string, error) + + // GetAndClearState retrieves and deletes OAuth state data by ID. + // + // Returns the state data if found, otherwise nil. + // The state should be deleted after retrieval to ensure one-time use. + GetAndClearState(id string) (*OAuthState, error) +} + +// GoogleOAuthConfig creates an OAuthConfig for Google OAuth with sensible defaults. +func GoogleOAuthConfig[T any]( + clientID, + clientSecret, + callbackURI string, + dataStore OAuthDataStore[T], +) OAuthConfig[T] { + return OAuthConfig[T]{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: callbackURI, + Scopes: []string{"email"}, + Endpoint: google.Endpoint, + }, + UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + DataStore: dataStore, + Redirects: Redirects{Default: "/", FieldName: "redirect"}, + StateExpiry: 10 * time.Minute, + } +} + +// GitHubOAuthConfig creates an OAuthConfig for GitHub OAuth with sensible defaults. +func GitHubOAuthConfig[T any]( + clientID, + clientSecret, + callbackURI string, + dataStore OAuthDataStore[T], +) OAuthConfig[T] { + return OAuthConfig[T]{ + OAuth2Config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: callbackURI, + Scopes: []string{"user:email"}, + Endpoint: github.Endpoint, + }, + UserInfoURL: "https://api.github.com/user", + DataStore: dataStore, + Redirects: Redirects{Default: "/", FieldName: "redirect"}, + StateExpiry: 10 * time.Minute, + } +} + +func (oc *OAuthConfig[T]) setDefaults() { + if oc.StateExpiry == 0 { + oc.StateExpiry = 10 * time.Minute + } +} + +func (oc OAuthConfig[T]) logError(err error) { + if oc.LogError != nil { + oc.LogError(err) + } +} + +// OAuthState represents the OAuth state data that needs to be temporarily stored. +type OAuthState struct { + Redirect string + ExpiresAt time.Time +} + +// OAuthLoginHandler returns an HTTP handler that initiates OAuth authentication. +func (a Auth[T]) OAuthLoginHandler(config OAuthConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + state := OAuthState{ + Redirect: config.Redirects.target(r), + ExpiresAt: time.Now().Add(config.StateExpiry), + } + + stateID, err := config.DataStore.StoreState(state) + if err != nil { + config.logError(err) + http.Error(w, "Failed to store state", http.StatusInternalServerError) + return + } + + authURL := config.OAuth2Config.AuthCodeURL(stateID) + http.Redirect(w, r, authURL, http.StatusSeeOther) + }) +} + +// OAuthCallbackHandler returns an HTTP handler that handles the OAuth callback. +func (a Auth[T]) OAuthCallbackHandler(config OAuthConfig[T]) http.Handler { + config.setDefaults() + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + stateID := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + errorParam := r.URL.Query().Get("error") + + if errorParam != "" { + config.logError(fmt.Errorf("OAuth error: %s", errorParam)) + http.Error(w, "OAuth authorization failed", http.StatusBadRequest) + return + } + + if code == "" { + http.Error(w, "Missing authorization code", http.StatusBadRequest) + return + } + + if stateID == "" { + http.Error(w, "Missing state parameter", http.StatusBadRequest) + return + } + + state, err := config.DataStore.GetAndClearState(stateID) + if err != nil { + config.logError(err) + http.Error(w, "Error finding login state", http.StatusInternalServerError) + return + } + if state == nil { + http.Error(w, "Invalid login state", http.StatusBadRequest) + return + } + + if time.Now().After(state.ExpiresAt) { + http.Error(w, "Login expired", http.StatusBadRequest) + return + } + + ctx := context.Background() + token, err := config.OAuth2Config.Exchange(ctx, code) + if err != nil { + config.logError(err) + http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError) + return + } + + client := config.OAuth2Config.Client(ctx, token) + resp, err := client.Get(config.UserInfoURL) + if err != nil { + config.logError(err) + http.Error(w, "Failed to get user info", http.StatusInternalServerError) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + config.logError(err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + config.logError(err) + http.Error(w, "Failed to read user info response", http.StatusInternalServerError) + return + } + + var userInfo struct { + Email string `json:"email"` + } + if err := json.Unmarshal(body, &userInfo); err != nil { + config.logError(err) + http.Error(w, "Failed to parse user info", http.StatusInternalServerError) + return + } + email := userInfo.Email + if email == "" { + http.Error(w, "No email address found", http.StatusBadRequest) + return + } + + userData, err := config.DataStore.GetOrCreateUser(email) + if err != nil { + config.logError(err) + http.Error(w, "Failed to get or create user", http.StatusInternalServerError) + return + } + + if err := a.Set(w, userData); err != nil { + config.logError(err) + http.Error(w, "Failed to set authentication cookie", http.StatusInternalServerError) + return + } + + redirectTarget := state.Redirect + if redirectTarget == "" { + redirectTarget = config.Redirects.Default + if redirectTarget == "" { + redirectTarget = "/" + } + } + + http.Redirect(w, r, redirectTarget, http.StatusSeeOther) + }) +} |