feat: Add reset-password command (#1380)

* allow non-destructively checking if database needs to be migrated

* feat: Add reset-password command

* fix linter errors

* clean up reset-password usage prompt

* Add confirmation to reset-password command

* Ping database before checking migration, to improve error message
This commit is contained in:
David Wahler 2022-05-12 12:32:56 -05:00 committed by GitHub
parent a629a705d0
commit 20916281d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 325 additions and 7 deletions

90
cli/resetpassword.go Normal file
View File

@ -0,0 +1,90 @@
package cli
import (
"database/sql"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/userpassword"
)
func resetPassword() *cobra.Command {
var (
postgresURL string
)
root := &cobra.Command{
Use: "reset-password <username>",
Short: "Reset a user's password by directly updating the database",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
username := args[0]
sqlDB, err := sql.Open("postgres", postgresURL)
if err != nil {
return xerrors.Errorf("dial postgres: %w", err)
}
defer sqlDB.Close()
err = sqlDB.Ping()
if err != nil {
return xerrors.Errorf("ping postgres: %w", err)
}
err = database.EnsureClean(sqlDB)
if err != nil {
return xerrors.Errorf("database needs migration: %w", err)
}
db := database.New(sqlDB)
user, err := db.GetUserByEmailOrUsername(cmd.Context(), database.GetUserByEmailOrUsernameParams{
Username: username,
})
if err != nil {
return xerrors.Errorf("retrieving user: %w", err)
}
password, err := cliui.Prompt(cmd, cliui.PromptOptions{
Text: "Enter new " + cliui.Styles.Field.Render("password") + ":",
Secret: true,
Validate: cliui.ValidateNotEmpty,
})
if err != nil {
return xerrors.Errorf("password prompt: %w", err)
}
confirmedPassword, err := cliui.Prompt(cmd, cliui.PromptOptions{
Text: "Confirm " + cliui.Styles.Field.Render("password") + ":",
Secret: true,
Validate: cliui.ValidateNotEmpty,
})
if err != nil {
return xerrors.Errorf("confirm password prompt: %w", err)
}
if password != confirmedPassword {
return xerrors.New("Passwords do not match")
}
hashedPassword, err := userpassword.Hash(password)
if err != nil {
return xerrors.Errorf("hash password: %w", err)
}
err = db.UpdateUserHashedPassword(cmd.Context(), database.UpdateUserHashedPasswordParams{
ID: user.ID,
HashedPassword: []byte(hashedPassword),
})
if err != nil {
return xerrors.Errorf("updating password: %w", err)
}
return nil
},
}
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
return root
}

106
cli/resetpassword_test.go Normal file
View File

@ -0,0 +1,106 @@
package cli_test
import (
"context"
"net/url"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/pty/ptytest"
)
func TestResetPassword(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" || testing.Short() {
// Skip on non-Linux because it spawns a PostgreSQL instance.
t.SkipNow()
}
const email = "some@one.com"
const username = "example"
const oldPassword = "password"
const newPassword = "password2"
// start postgres and coder server processes
connectionURL, closeFunc, err := postgres.Open()
require.NoError(t, err)
defer closeFunc()
ctx, cancelFunc := context.WithCancel(context.Background())
serverDone := make(chan struct{})
serverCmd, cfg := clitest.New(t, "server", "--address", ":0", "--postgres-url", connectionURL)
go func() {
defer close(serverDone)
err = serverCmd.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)
}()
var client *codersdk.Client
require.Eventually(t, func() bool {
rawURL, err := cfg.URL().Read()
if err != nil {
return false
}
accessURL, err := url.Parse(rawURL)
require.NoError(t, err)
client = codersdk.New(accessURL)
return true
}, 15*time.Second, 25*time.Millisecond)
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: email,
Username: username,
Password: oldPassword,
OrganizationName: "example",
})
require.NoError(t, err)
// reset the password
resetCmd, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username)
clitest.SetupConfig(t, client, cmdCfg)
cmdDone := make(chan struct{})
pty := ptytest.New(t)
resetCmd.SetIn(pty.Input())
resetCmd.SetOut(pty.Output())
go func() {
defer close(cmdDone)
err = resetCmd.Execute()
require.NoError(t, err)
}()
matches := []struct {
output string
input string
}{
{"Enter new", newPassword},
{"Confirm", newPassword},
}
for _, match := range matches {
pty.ExpectMatch(match.output)
pty.WriteLine(match.input)
}
<-cmdDone
// now try logging in
_, err = client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: email,
Password: oldPassword,
})
require.Error(t, err)
_, err = client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: email,
Password: newPassword,
})
require.NoError(t, err)
cancelFunc()
<-serverDone
}

View File

@ -61,6 +61,7 @@ func Root() *cobra.Command {
list(),
login(),
publickey(),
resetPassword(),
server(),
show(),
start(),

View File

@ -4,9 +4,11 @@ import (
"database/sql"
"embed"
"errors"
"os"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/iofs"
"golang.org/x/xerrors"
)
@ -14,28 +16,28 @@ import (
//go:embed migrations/*.sql
var migrations embed.FS
func migrateSetup(db *sql.DB) (*migrate.Migrate, error) {
func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
sourceDriver, err := iofs.New(migrations, "migrations")
if err != nil {
return nil, xerrors.Errorf("create iofs: %w", err)
return nil, nil, xerrors.Errorf("create iofs: %w", err)
}
dbDriver, err := postgres.WithInstance(db, &postgres.Config{})
if err != nil {
return nil, xerrors.Errorf("wrap postgres connection: %w", err)
return nil, nil, xerrors.Errorf("wrap postgres connection: %w", err)
}
m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
if err != nil {
return nil, xerrors.Errorf("new migrate instance: %w", err)
return nil, nil, xerrors.Errorf("new migrate instance: %w", err)
}
return m, nil
return sourceDriver, m, nil
}
// MigrateUp runs SQL migrations to ensure the database schema is up-to-date.
func MigrateUp(db *sql.DB) error {
m, err := migrateSetup(db)
_, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -55,7 +57,7 @@ func MigrateUp(db *sql.DB) error {
// MigrateDown runs all down SQL migrations.
func MigrateDown(db *sql.DB) error {
m, err := migrateSetup(db)
_, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -72,3 +74,68 @@ func MigrateDown(db *sql.DB) error {
return nil
}
// EnsureClean checks whether all migrations for the current version have been
// applied, without making any changes to the database. If not, returns a
// non-nil error.
func EnsureClean(db *sql.DB) error {
sourceDriver, m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
version, dirty, err := m.Version()
if err != nil {
return xerrors.Errorf("get migration version: %w", err)
}
if dirty {
return xerrors.Errorf("database has not been cleanly migrated")
}
// Verify that the database's migration version is "current" by checking
// that a migration with that version exists, but there is no next version.
err = CheckLatestVersion(sourceDriver, version)
if err != nil {
return xerrors.Errorf("database needs migration: %w", err)
}
return nil
}
// Returns nil if currentVersion corresponds to the latest available migration,
// otherwise an error explaining why not.
func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
// This is ugly, but seems like the only way to do it with the public
// interfaces provided by golang-migrate.
// Check that there is no later version
nextVersion, err := sourceDriver.Next(currentVersion)
if err == nil {
return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, nextVersion)
}
if !errors.Is(err, os.ErrNotExist) {
return xerrors.Errorf("get next migration after %d: %w", currentVersion, err)
}
// Once we reach this point, we know that either currentVersion doesn't
// exist, or it has no successor (the return value from
// sourceDriver.Next() is the same in either case). So we need to check
// that either it's the first version, or it has a predecessor.
firstVersion, err := sourceDriver.First()
if err != nil {
// the total number of migrations should be non-zero, so this must be
// an actual error, not just a missing file
return xerrors.Errorf("get first migration: %w", err)
}
if firstVersion == currentVersion {
return nil
}
_, err = sourceDriver.Prev(currentVersion)
if err != nil {
return xerrors.Errorf("get previous migration: %w", err)
}
return nil
}

View File

@ -4,8 +4,11 @@ package database_test
import (
"database/sql"
"fmt"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/stub"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@ -75,3 +78,54 @@ func testSQLDB(t testing.TB) *sql.DB {
return db
}
// paralleltest linter doesn't correctly handle table-driven tests (https://github.com/kunwardeep/paralleltest/issues/8)
// nolint:paralleltest
func TestCheckLatestVersion(t *testing.T) {
t.Parallel()
type test struct {
currentVersion uint
existingVersions []uint
expectedResult string
}
tests := []test{
// successful cases
{1, []uint{1}, ""},
{3, []uint{1, 2, 3}, ""},
{3, []uint{1, 3}, ""},
// failure cases
{1, []uint{1, 2}, "current version is 1, but later version 2 exists"},
{2, []uint{1, 2, 3}, "current version is 2, but later version 3 exists"},
{4, []uint{1, 2, 3}, "get previous migration: prev for version 4 : file does not exist"},
{4, []uint{1, 2, 3, 5}, "get previous migration: prev for version 4 : file does not exist"},
}
for i, tc := range tests {
i, tc := i, tc
t.Run(fmt.Sprintf("entry %d", i), func(t *testing.T) {
t.Parallel()
driver, _ := stub.WithInstance(nil, &stub.Config{})
stub, ok := driver.(*stub.Stub)
require.True(t, ok)
for _, version := range tc.existingVersions {
stub.Migrations.Append(&source.Migration{
Version: version,
Identifier: "",
Direction: source.Up,
Raw: "",
})
}
err := database.CheckLatestVersion(driver, tc.currentVersion)
var errMessage string
if err != nil {
errMessage = err.Error()
}
require.Equal(t, tc.expectedResult, errMessage)
})
}
}