mirror of https://gitlab.com/gitlab-org/cli.git
141 lines
3.4 KiB
Go
141 lines
3.4 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gitlab.com/gitlab-org/cli/pkg/glinstance"
|
|
"gitlab.com/gitlab-org/cli/pkg/iostreams"
|
|
)
|
|
|
|
func TestHandleAuthRedirect(t *testing.T) {
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{
|
|
"access_token": "at",
|
|
"refresh_token": "rt",
|
|
"expiresIn": 60
|
|
}`))
|
|
}))
|
|
|
|
cfg := stubConfig{
|
|
hosts: map[string]map[string]string{},
|
|
}
|
|
|
|
hostname := strings.Split(svr.URL, "://")[1]
|
|
cfg.hosts[hostname] = map[string]string{
|
|
"is_oauth2": "true",
|
|
"oauth2_refresh_token": "refresh_token",
|
|
"token": "access_token",
|
|
"oauth2_code_verifier": "123",
|
|
"oauth2_expiry_date": "13 Mar 23 15:47 GMT",
|
|
"client_id": "321",
|
|
}
|
|
|
|
ios, _, _, _ := iostreams.Test()
|
|
|
|
tokenCh := handleAuthRedirect(ios, "123", hostname, "http", "abc")
|
|
defer close(tokenCh)
|
|
time.Sleep(1 * time.Second)
|
|
|
|
go func() {
|
|
_, err := http.Get("http://localhost:7171/auth/redirect?code=123")
|
|
require.Nil(t, err)
|
|
}()
|
|
|
|
token := <-tokenCh
|
|
assert.Equal(t, "at", token.AccessToken)
|
|
}
|
|
|
|
func TestRefreshToken(t *testing.T) {
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte(`{
|
|
"access_token": "at",
|
|
"refresh_token": "rt",
|
|
"expiresIn": 60
|
|
}`))
|
|
}))
|
|
|
|
cfg := stubConfig{
|
|
hosts: map[string]map[string]string{},
|
|
}
|
|
|
|
hostname := strings.Split(svr.URL, "://")[1]
|
|
cfg.hosts[hostname] = map[string]string{
|
|
"is_oauth2": "true",
|
|
"oauth2_refresh_token": "refresh_token",
|
|
"token": "access_token",
|
|
"oauth2_code_verifier": "123",
|
|
"oauth2_expiry_date": "13 Mar 23 15:47 GMT",
|
|
"client_id": "321",
|
|
}
|
|
|
|
err := RefreshToken(hostname, cfg, "http")
|
|
require.Nil(t, err)
|
|
|
|
accessToken, err := cfg.Get(hostname, "token")
|
|
require.Nil(t, err)
|
|
assert.Equal(t, "at", accessToken)
|
|
|
|
refreshToken, err := cfg.Get(hostname, "oauth2_refresh_token")
|
|
require.Nil(t, err)
|
|
assert.Equal(t, "rt", refreshToken)
|
|
|
|
expiryDateString, err := cfg.Get(hostname, "oauth2_expiry_date")
|
|
require.Nil(t, err)
|
|
_, err = time.Parse(time.RFC822, expiryDateString)
|
|
require.Nil(t, err)
|
|
}
|
|
|
|
func TestClientID(t *testing.T) {
|
|
testCasesTable := []struct {
|
|
name string
|
|
hostname string
|
|
configClientID string
|
|
expectedClientID string
|
|
}{
|
|
{
|
|
name: "managed",
|
|
hostname: glinstance.Default(),
|
|
configClientID: "",
|
|
expectedClientID: glinstance.DefaultClientID(),
|
|
},
|
|
{
|
|
name: "self-managed-complete",
|
|
hostname: "salsa.debian.org",
|
|
configClientID: "321",
|
|
expectedClientID: "321",
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCasesTable {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
cfg := stubConfig{
|
|
hosts: map[string]map[string]string{
|
|
testCase.hostname: {
|
|
"client_id": testCase.configClientID,
|
|
},
|
|
},
|
|
}
|
|
clientID, err := oAuthClientID(cfg, testCase.hostname)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, testCase.expectedClientID, clientID)
|
|
})
|
|
}
|
|
|
|
t.Run("invalid self-managed config", func(t *testing.T) {
|
|
cfg := stubConfig{
|
|
hosts: map[string]map[string]string{
|
|
"salsa.debian.org": {},
|
|
},
|
|
}
|
|
clientID, err := oAuthClientID(cfg, "salsa.debian.org")
|
|
assert.Error(t, err)
|
|
assert.Empty(t, clientID)
|
|
})
|
|
}
|