mirror of https://gitlab.com/gitlab-org/cli.git
193 lines
4.5 KiB
Go
193 lines
4.5 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"gitlab.com/gitlab-org/cli/internal/config"
|
|
"gitlab.com/gitlab-org/cli/pkg/glinstance"
|
|
"gitlab.com/gitlab-org/cli/pkg/iostreams"
|
|
"gitlab.com/gitlab-org/cli/pkg/utils"
|
|
)
|
|
|
|
// OAuth2 parameters that are fixed for every authorization request made by
// this package.
const (
	// scopes is the "scope" query value placed in the authorization URL. The
	// individual scope names are joined with '+', i.e. it is already in its
	// query-encoded (space-separated) form.
	scopes = "openid+profile+read_user+write_repository+api"

	// redirectURI is the callback address sent as "redirect_uri" in both the
	// authorization request and the token requests. Its port and path match
	// the local callback server started by handleAuthRedirect.
	redirectURI = "http://localhost:7171/auth/redirect"
)
|
|
|
|
func oAuthClientID(cfg config.Config, hostname string) (string, error) {
|
|
if glinstance.IsSelfHosted(hostname) {
|
|
clientID, err := cfg.Get(hostname, "client_id")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if clientID == "" {
|
|
return "", fmt.Errorf("set 'client_id' first with `glab config set client_id <client_id> -g -h %s`", hostname)
|
|
}
|
|
return clientID, nil
|
|
}
|
|
return glinstance.DefaultClientID(), nil
|
|
}
|
|
|
|
func StartFlow(cfg config.Config, io *iostreams.IOStreams, hostname string) (string, error) {
|
|
authURL := fmt.Sprintf("https://%s/oauth/authorize", hostname)
|
|
|
|
clientID, err := oAuthClientID(cfg, hostname)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
state := randomString()
|
|
codeVerifier := randomString()
|
|
codeChallenge := generateCodeChallenge(codeVerifier)
|
|
completeAuthURL := fmt.Sprintf(
|
|
"%s?client_id=%s&redirect_uri=%s&response_type=code&state=%s&scope=%s&code_challenge=%s&code_challenge_method=S256",
|
|
authURL, clientID, redirectURI, state, scopes, codeChallenge)
|
|
|
|
tokenCh := handleAuthRedirect(io, codeVerifier, hostname, "https", clientID)
|
|
defer close(tokenCh)
|
|
|
|
browser, _ := cfg.Get(hostname, "browser")
|
|
if err := utils.OpenInBrowser(completeAuthURL, browser); err != nil {
|
|
fmt.Fprintf(io.StdErr, "Failed opening a browser at %s\n", completeAuthURL)
|
|
fmt.Fprintf(io.StdErr, "Encountered error: %s\n", err)
|
|
fmt.Fprint(io.StdErr, "Please try entering the URL in your browser manually\n")
|
|
}
|
|
token := <-tokenCh
|
|
|
|
err = token.SetConfig(hostname, cfg)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return token.AccessToken, nil
|
|
}
|
|
|
|
func handleAuthRedirect(io *iostreams.IOStreams, codeVerifier, hostname, protocol, clientID string) chan *AuthToken {
|
|
tokenCh := make(chan *AuthToken)
|
|
|
|
server := &http.Server{Addr: ":7171"}
|
|
|
|
http.HandleFunc("/auth/redirect", func(w http.ResponseWriter, r *http.Request) {
|
|
code := r.URL.Query().Get("code")
|
|
token, err := requestToken(hostname, protocol, clientID, code, codeVerifier)
|
|
if err != nil {
|
|
fmt.Fprintf(io.StdErr, "Error occured requesting access token %s", err)
|
|
tokenCh <- nil
|
|
return
|
|
}
|
|
|
|
_, _ = w.Write([]byte("You have authenticated successfully. You can now close this browser window."))
|
|
tokenCh <- token
|
|
_ = server.Shutdown(context.Background())
|
|
})
|
|
|
|
go func() {
|
|
err := http.ListenAndServe(":7171", nil)
|
|
if err != nil {
|
|
fmt.Fprintf(io.StdErr, "Error occured while setting up server %s", err)
|
|
tokenCh <- nil
|
|
}
|
|
}()
|
|
|
|
return tokenCh
|
|
}
|
|
|
|
func requestToken(hostname, protocol, clientID, code, codeVerifier string) (*AuthToken, error) {
|
|
tokenURL := fmt.Sprintf("%s://%s/oauth/token", protocol, hostname)
|
|
|
|
form := url.Values{
|
|
"client_id": []string{clientID},
|
|
"code": []string{code},
|
|
"grant_type": []string{"authorization_code"},
|
|
"redirect_uri": []string{redirectURI},
|
|
"code_verifier": []string{codeVerifier},
|
|
}
|
|
|
|
resp, err := http.PostForm(tokenURL, form)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
respBytes, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
at := AuthToken{}
|
|
|
|
err = json.Unmarshal(respBytes, &at)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
at.CalcExpiresDate()
|
|
at.CodeVerifier = codeVerifier
|
|
|
|
return &at, nil
|
|
}
|
|
|
|
func RefreshToken(hostname string, cfg config.Config, protocol string) error {
|
|
token, err := tokenFromConfig(hostname, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check if token has expired
|
|
if token.ExpiryDate.After(time.Now()) {
|
|
return nil
|
|
}
|
|
|
|
clientID, err := oAuthClientID(cfg, hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
form := url.Values{
|
|
"client_id": []string{clientID},
|
|
"grant_type": []string{"refresh_token"},
|
|
"refresh_token": []string{token.RefreshToken},
|
|
"redirect_uri": []string{redirectURI},
|
|
"code_verifier": []string{token.CodeVerifier},
|
|
}
|
|
|
|
tokenURL := fmt.Sprintf("%s://%s/oauth/token", protocol, hostname)
|
|
resp, err := http.PostForm(tokenURL, form)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
respBytes, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
at := AuthToken{}
|
|
|
|
err = json.Unmarshal(respBytes, &at)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
at.CalcExpiresDate()
|
|
|
|
err = at.SetConfig(hostname, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = cfg.Write()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|