From e57ca3cdaa0650c3fa35411c312c861e018d04b6 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 23 Apr 2024 12:43:14 +0100 Subject: [PATCH] feat(scripts): add script to check schema between migrations (#13037) - migrations: allow passing in a custom migrate.FS - gen/dump: extract some functions to dbtestutil - scripts: write script to test migrations --- Makefile | 7 ++ coderd/database/dbtestutil/db.go | 93 +++++++++++++--- coderd/database/gen/dump/main.go | 85 +-------------- coderd/database/migrations/migrate.go | 22 ++-- scripts/migrate-test/main.go | 147 ++++++++++++++++++++++++++ 5 files changed, 252 insertions(+), 102 deletions(-) create mode 100644 scripts/migrate-test/main.go diff --git a/Makefile b/Makefile index e588279384..5c74420357 100644 --- a/Makefile +++ b/Makefile @@ -783,6 +783,13 @@ test-postgres: test-postgres-docker -count=1 .PHONY: test-postgres +test-migrations: test-postgres-docker + echo "--- test migrations" + COMMIT_FROM=$(shell git rev-parse --short HEAD) + COMMIT_TO=$(shell git rev-parse --short main) + DB_NAME=$(shell go run scripts/migrate-ci/main.go) + go run ./scripts/migrate-test/main.go --from="$$COMMIT_FROM" --to="$$COMMIT_TO" --postgres-url="postgresql://postgres:postgres@localhost:5432/$$DB_NAME?sslmode=disable" + # NOTE: we set --memory to the same size as a GitHub runner. test-postgres-docker: docker rm -f test-postgres-docker || true diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index baea42005b..16eb3393ca 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -10,6 +10,7 @@ import ( "os/exec" "path/filepath" "regexp" + "strconv" "strings" "testing" "time" @@ -184,20 +185,21 @@ func DumpOnFailure(t testing.TB, connectionURL string) { now := time.Now() timeSuffix := fmt.Sprintf("%d%d%d%d%d%d", now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second()) outPath := filepath.Join(cwd, snakeCaseName+"."+timeSuffix+".test.sql") - dump, err := pgDump(connectionURL) + dump, err := PGDump(connectionURL) if err != nil { t.Errorf("dump on failure: failed to run pg_dump") return } - if err := os.WriteFile(outPath, filterDump(dump), 0o600); err != nil { + if err := os.WriteFile(outPath, normalizeDump(dump), 0o600); err != nil { t.Errorf("dump on failure: failed to write: %s", err.Error()) return } t.Logf("Dumped database to %q due to failed test. I hope you find what you're looking for!", outPath) } -// pgDump runs pg_dump against dbURL and returns the output. -func pgDump(dbURL string) ([]byte, error) { +// PGDump runs pg_dump against dbURL and returns the output. +// It is used by DumpOnFailure(). +func PGDump(dbURL string) ([]byte, error) { if _, err := exec.LookPath("pg_dump"); err != nil { return nil, xerrors.Errorf("could not find pg_dump in path: %w", err) } @@ -230,16 +232,79 @@ func pgDump(dbURL string) ([]byte, error) { return stdout.Bytes(), nil } -// Unfortunately, some insert expressions span multiple lines. -// The below may be over-permissive but better that than truncating data. -var insertExpr = regexp.MustCompile(`(?s)\bINSERT[^;]+;`) +const minimumPostgreSQLVersion = 13 -func filterDump(dump []byte) []byte { - var buf bytes.Buffer - matches := insertExpr.FindAll(dump, -1) - for _, m := range matches { - _, _ = buf.Write(m) - _, _ = buf.WriteRune('\n') +// PGDumpSchemaOnly is for use by gen/dump only. +// It runs pg_dump against dbURL and sets a consistent timezone and encoding. +func PGDumpSchemaOnly(dbURL string) ([]byte, error) { + hasPGDump := false + if _, err := exec.LookPath("pg_dump"); err == nil { + out, err := exec.Command("pg_dump", "--version").Output() + if err == nil { + // Parse output: + // pg_dump (PostgreSQL) 14.5 (Ubuntu 14.5-0ubuntu0.22.04.1) + parts := strings.Split(string(out), " ") + if len(parts) > 2 { + version, err := strconv.Atoi(strings.Split(parts[2], ".")[0]) + if err == nil && version >= minimumPostgreSQLVersion { + hasPGDump = true + } + } + } } - return buf.Bytes() + + cmdArgs := []string{ + "pg_dump", + "--schema-only", + dbURL, + "--no-privileges", + "--no-owner", + "--no-privileges", + "--no-publication", + "--no-security-labels", + "--no-subscriptions", + "--no-tablespaces", + + // We never want to manually generate + // queries executing against this table. + "--exclude-table=schema_migrations", + } + + if !hasPGDump { + cmdArgs = append([]string{ + "docker", + "run", + "--rm", + "--network=host", + fmt.Sprintf("gcr.io/coder-dev-1/postgres:%d", minimumPostgreSQLVersion), + }, cmdArgs...) + } + cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //#nosec + cmd.Env = append(os.Environ(), []string{ + "PGTZ=UTC", + "PGCLIENTENCODING=UTF8", + }...) + var output bytes.Buffer + cmd.Stdout = &output + cmd.Stderr = os.Stderr + err := cmd.Run() + if err != nil { + return nil, err + } + return normalizeDump(output.Bytes()), nil +} + +func normalizeDump(schema []byte) []byte { + // Remove all comments. + schema = regexp.MustCompile(`(?im)^(--.*)$`).ReplaceAll(schema, []byte{}) + // Public is implicit in the schema. + schema = regexp.MustCompile(`(?im)( |::|'|\()public\.`).ReplaceAll(schema, []byte(`$1`)) + // Remove database settings. + schema = regexp.MustCompile(`(?im)^(SET.*;)`).ReplaceAll(schema, []byte(``)) + // Remove select statements + schema = regexp.MustCompile(`(?im)^(SELECT.*;)`).ReplaceAll(schema, []byte(``)) + // Removes multiple newlines. + schema = regexp.MustCompile(`(?im)\n{3,}`).ReplaceAll(schema, []byte("\n\n")) + + return schema } diff --git a/coderd/database/gen/dump/main.go b/coderd/database/gen/dump/main.go index 1a9debbae2..f563e11426 100644 --- a/coderd/database/gen/dump/main.go +++ b/coderd/database/gen/dump/main.go @@ -1,21 +1,16 @@ package main import ( - "bytes" "database/sql" - "fmt" "os" - "os/exec" "path/filepath" "runtime" - "strconv" - "strings" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/migrations" ) -const minimumPostgreSQLVersion = 13 +var preamble = []byte("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.") func main() { connection, closeFn, err := dbtestutil.Open() @@ -28,95 +23,23 @@ func main() { if err != nil { panic(err) } + defer db.Close() err = migrations.Up(db) if err != nil { panic(err) } - hasPGDump := false - if _, err = exec.LookPath("pg_dump"); err == nil { - out, err := exec.Command("pg_dump", "--version").Output() - if err == nil { - // Parse output: - // pg_dump (PostgreSQL) 14.5 (Ubuntu 14.5-0ubuntu0.22.04.1) - parts := strings.Split(string(out), " ") - if len(parts) > 2 { - version, err := strconv.Atoi(strings.Split(parts[2], ".")[0]) - if err == nil && version >= minimumPostgreSQLVersion { - hasPGDump = true - } - } - } - } - - cmdArgs := []string{ - "pg_dump", - "--schema-only", - connection, - "--no-privileges", - "--no-owner", - - // We never want to manually generate - // queries executing against this table. - "--exclude-table=schema_migrations", - } - - if !hasPGDump { - cmdArgs = append([]string{ - "docker", - "run", - "--rm", - "--network=host", - fmt.Sprintf("gcr.io/coder-dev-1/postgres:%d", minimumPostgreSQLVersion), - }, cmdArgs...) - } - cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //#nosec - cmd.Env = append(os.Environ(), []string{ - "PGTZ=UTC", - "PGCLIENTENCODING=UTF8", - }...) - var output bytes.Buffer - cmd.Stdout = &output - cmd.Stderr = os.Stderr - err = cmd.Run() + dumpBytes, err := dbtestutil.PGDumpSchemaOnly(connection) if err != nil { panic(err) } - for _, sed := range []string{ - // Remove all comments. - "/^--/d", - // Public is implicit in the schema. - "s/ public\\./ /g", - "s/::public\\./::/g", - "s/'public\\./'/g", - "s/(public\\./(/g", - // Remove database settings. - "s/SET .* = .*;//g", - // Remove select statements. These aren't useful - // to a reader of the dump. - "s/SELECT.*;//g", - // Removes multiple newlines. - "/^$/N;/^\\n$/D", - } { - cmd := exec.Command("sed", "-e", sed) - cmd.Stdin = bytes.NewReader(output.Bytes()) - output = bytes.Buffer{} - cmd.Stdout = &output - cmd.Stderr = os.Stderr - err = cmd.Run() - if err != nil { - panic(err) - } - } - - dump := fmt.Sprintf("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.\n%s", output.Bytes()) _, mainPath, _, ok := runtime.Caller(0) if !ok { panic("couldn't get caller path") } - err = os.WriteFile(filepath.Join(mainPath, "..", "..", "..", "dump.sql"), []byte(dump), 0o600) + err = os.WriteFile(filepath.Join(mainPath, "..", "..", "..", "dump.sql"), append(preamble, dumpBytes...), 0o600) if err != nil { panic(err) } diff --git a/coderd/database/migrations/migrate.go b/coderd/database/migrations/migrate.go index adcb1d4c22..213408bbad 100644 --- a/coderd/database/migrations/migrate.go +++ b/coderd/database/migrations/migrate.go @@ -17,9 +17,12 @@ import ( //go:embed *.sql var migrations embed.FS -func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { +func setup(db *sql.DB, migs fs.FS) (source.Driver, *migrate.Migrate, error) { + if migs == nil { + migs = migrations + } ctx := context.Background() - sourceDriver, err := iofs.New(migrations, ".") + sourceDriver, err := iofs.New(migs, ".") if err != nil { return nil, nil, xerrors.Errorf("create iofs: %w", err) } @@ -47,8 +50,13 @@ func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { } // Up runs SQL migrations to ensure the database schema is up-to-date. -func Up(db *sql.DB) (retErr error) { - _, m, err := setup(db) +func Up(db *sql.DB) error { + return UpWithFS(db, migrations) +} + +// UpWithFS runs SQL migrations in the given fs. +func UpWithFS(db *sql.DB, migs fs.FS) (retErr error) { + _, m, err := setup(db, migs) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -79,7 +87,7 @@ func Up(db *sql.DB) (retErr error) { // Down runs all down SQL migrations. func Down(db *sql.DB) error { - _, m, err := setup(db) + _, m, err := setup(db, migrations) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -101,7 +109,7 @@ func Down(db *sql.DB) error { // applied, without making any changes to the database. If not, returns a // non-nil error. func EnsureClean(db *sql.DB) error { - sourceDriver, m, err := setup(db) + sourceDriver, m, err := setup(db, migrations) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -167,7 +175,7 @@ func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error { // Stepper cannot be closed pre-emptively, it must be run to completion // (or until an error is encountered). func Stepper(db *sql.DB) (next func() (version uint, more bool, err error), err error) { - _, m, err := setup(db) + _, m, err := setup(db, migrations) if err != nil { return nil, xerrors.Errorf("migrate setup: %w", err) } diff --git a/scripts/migrate-test/main.go b/scripts/migrate-test/main.go new file mode 100644 index 0000000000..deaa7a021b --- /dev/null +++ b/scripts/migrate-test/main.go @@ -0,0 +1,147 @@ +package main + +import ( + "archive/zip" + "bytes" + "database/sql" + "flag" + "fmt" + "io/fs" + "os" + "os/exec" + "regexp" + + "github.com/google/go-cmp/cmp" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/migrations" +) + +// This script validates the migration path between two versions. +// It performs the following actions: +// Given OLD_VERSION and NEW_VERSION: +// 1. Checks out $OLD_VERSION and inits schema at that version. +// 2. Checks out $NEW_VERSION and runs migrations. +// 3. Compares database schema post-migrate to that in VCS. +// If any diffs are found, exits with an error. +func main() { + var ( + migrateFromVersion string + migrateToVersion string + postgresURL string + skipCleanup bool + ) + + flag.StringVar(&migrateFromVersion, "from", "", "Migrate from this version") + flag.StringVar(&migrateToVersion, "to", "", "Migrate to this version") + flag.StringVar(&postgresURL, "postgres-url", "postgresql://postgres:postgres@localhost:5432/postgres?sslmode=disable", "Postgres URL to migrate") + flag.BoolVar(&skipCleanup, "skip-cleanup", false, "Do not clean up on exit.") + flag.Parse() + + if migrateFromVersion == "" || migrateToVersion == "" { + _, _ = fmt.Fprintln(os.Stderr, "must specify --from= and --to=") + os.Exit(1) + } + + _, _ = fmt.Fprintf(os.Stderr, "Read schema at version %q\n", migrateToVersion) + expectedSchemaAfter, err := gitShow("coderd/database/dump.sql", migrateToVersion) + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(os.Stderr, "Read migrations for %q\n", migrateFromVersion) + migrateFromFS, err := makeMigrateFS(migrateFromVersion) + if err != nil { + panic(err) + } + _, _ = fmt.Fprintf(os.Stderr, "Read migrations for %q\n", migrateToVersion) + migrateToFS, err := makeMigrateFS(migrateToVersion) + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(os.Stderr, "Connect to postgres\n") + conn, err := sql.Open("postgres", postgresURL) + if err != nil { + panic(err) + } + defer conn.Close() + + ver, err := checkMigrateVersion(conn) + if err != nil { + panic(err) + } + if ver < 0 { + _, _ = fmt.Fprintf(os.Stderr, "No previous migration detected.\n") + } else { + _, _ = fmt.Fprintf(os.Stderr, "Detected migration version %d\n", ver) + } + + _, _ = fmt.Fprintf(os.Stderr, "Init database at version %q\n", migrateFromVersion) + if err := migrations.UpWithFS(conn, migrateFromFS); err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(os.Stderr, "Migrate to version %q\n", migrateToVersion) + if err := migrations.UpWithFS(conn, migrateToFS); err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(os.Stderr, "Dump schema at version %q\n", migrateToVersion) + dumpBytesAfter, err := dbtestutil.PGDumpSchemaOnly(postgresURL) + if err != nil { + panic(err) + } + + if diff := cmp.Diff(string(dumpBytesAfter), string(stripGenPreamble(expectedSchemaAfter))); diff != "" { + _, _ = fmt.Fprintf(os.Stderr, "Schema differs from expected after migration: %s\n", diff) + os.Exit(1) + } + _, _ = fmt.Fprintf(os.Stderr, "OK\n") +} + +func makeMigrateFS(version string) (fs.FS, error) { + // Export the migrations from the requested version to a zip archive + out, err := exec.Command("git", "archive", "--format=zip", version, "coderd/database/migrations").CombinedOutput() + if err != nil { + return nil, xerrors.Errorf("git archive: %s\n", out) + } + // Make a zip.Reader on top of it. This implements fs.fs! + zr, err := zip.NewReader(bytes.NewReader(out), int64(len(out))) + if err != nil { + return nil, xerrors.Errorf("create zip reader: %w", err) + } + // Sub-FS to it's rooted at migrations dir. + subbed, err := fs.Sub(zr, "coderd/database/migrations") + if err != nil { + return nil, xerrors.Errorf("sub fs: %w", err) + } + return subbed, nil +} + +func gitShow(path, version string) ([]byte, error) { + out, err := exec.Command("git", "show", version+":"+path).CombinedOutput() //nolint:gosec + if err != nil { + return nil, xerrors.Errorf("git show: %s\n", out) + } + return out, nil +} + +func stripGenPreamble(bs []byte) []byte { + return regexp.MustCompile(`(?im)^(-- Code generated.*DO NOT EDIT.)$`).ReplaceAll(bs, []byte{}) +} + +func checkMigrateVersion(conn *sql.DB) (int, error) { + var version int + rows, err := conn.Query(`SELECT version FROM schema_migrations LIMIT 1;`) + if err != nil { + return -1, nil // not migrated + } + for rows.Next() { + if err := rows.Scan(&version); err != nil { + return 0, xerrors.Errorf("scan version: %w", err) + } + } + return version, nil +}