coder/coderd/provisionerdserver/provisionerdserver_internal...

86 lines
2.2 KiB
Go

package provisionerdserver
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
)
func TestObtainOIDCAccessToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("NoToken", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
_, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil)
require.NoError(t, err)
})
t.Run("InvalidConfig", func(t *testing.T) {
// We still want OIDC to succeed even if exchanging the token fails.
t.Parallel()
db := dbfake.New()
user := dbgen.User(t, db, database.User{})
dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthExpiry: database.Now().Add(-time.Hour),
})
_, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
require.NoError(t, err)
})
t.Run("Exchange", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
user := dbgen.User(t, db, database.User{})
dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthExpiry: database.Now().Add(-time.Hour),
})
_, err := obtainOIDCAccessToken(ctx, db, &oauth2Config{
tokenSource: func() (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "token",
}, nil
},
}, user.ID)
require.NoError(t, err)
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
})
require.NoError(t, err)
require.Equal(t, "token", link.OAuthAccessToken)
})
}
type oauth2Config struct {
tokenSource oauth2TokenSource
}
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return ""
}
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}
type oauth2TokenSource func() (*oauth2.Token, error)
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
return o()
}