From bbb0fab1de45fda4724f7d9a84508ac098dce5b0 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 20 Jun 2023 11:24:33 -0500 Subject: [PATCH] chore: merge database gen scripts (#8073) * chore: merge database gen scripts * Fix type params gen * Merge enum into dbgen --- Makefile | 2 +- coderd/database/db.go | 7 + coderd/database/gen/authz/main.go | 199 ---------- coderd/database/gen/enum/main.go | 120 ------ coderd/database/gen/fake/main.go | 285 --------------- coderd/database/gen/metrics/main.go | 210 ----------- coderd/database/generate.sh | 13 +- coderd/database/lock.go | 8 - scripts/dbgen/main.go | 542 ++++++++++++++++++++++++++++ 9 files changed, 551 insertions(+), 835 deletions(-) delete mode 100644 coderd/database/gen/authz/main.go delete mode 100644 coderd/database/gen/enum/main.go delete mode 100644 coderd/database/gen/fake/main.go delete mode 100644 coderd/database/gen/metrics/main.go delete mode 100644 coderd/database/lock.go create mode 100644 scripts/dbgen/main.go diff --git a/Makefile b/Makefile index a96f216ee0..01024643b0 100644 --- a/Makefile +++ b/Makefile @@ -486,7 +486,7 @@ coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/dat go run ./coderd/database/gen/dump/main.go # Generates Go code for querying the database. -coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go coderd/database/gen/fake/main.go +coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) ./coderd/database/generate.sh diff --git a/coderd/database/db.go b/coderd/database/db.go index 9ad1234070..bcf4de9a35 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -18,6 +18,13 @@ import ( "golang.org/x/xerrors" ) +// Well-known lock IDs for lock functions in the database. These should not +// change. If locks are deprecated, they should be kept to avoid reusing the +// same ID. +const ( + LockIDDeploymentSetup = iota + 1 +) + // Store contains all queryable database functions. // It extends the generated interface to add transaction support. type Store interface { diff --git a/coderd/database/gen/authz/main.go b/coderd/database/gen/authz/main.go deleted file mode 100644 index 2c781faedf..0000000000 --- a/coderd/database/gen/authz/main.go +++ /dev/null @@ -1,199 +0,0 @@ -package main - -import ( - "go/format" - "go/token" - "log" - "os" - - "github.com/dave/dst" - "github.com/dave/dst/decorator" - "github.com/dave/dst/decorator/resolver/goast" - "github.com/dave/dst/decorator/resolver/guess" - "golang.org/x/xerrors" -) - -func main() { - err := run() - if err != nil { - log.Fatal(err) - } -} - -func run() error { - funcs, err := readStoreInterface() - if err != nil { - return err - } - funcByName := map[string]struct{}{} - for _, f := range funcs { - funcByName[f.Name] = struct{}{} - } - declByName := map[string]*dst.FuncDecl{} - - dbauthz, err := os.ReadFile("./dbauthz/dbauthz.go") - if err != nil { - return xerrors.Errorf("read dbauthz: %w", err) - } - - // Required to preserve imports! - f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbauthz", goast.New()).Parse(dbauthz) - if err != nil { - return xerrors.Errorf("parse dbauthz: %w", err) - } - - for i := 0; i < len(f.Decls); i++ { - funcDecl, ok := f.Decls[i].(*dst.FuncDecl) - if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { - continue - } - // Check if the receiver is the struct we're interested in - starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr) - if !ok { - continue - } - ident, ok := starExpr.X.(*dst.Ident) - if !ok || ident.Name != "querier" { - continue - } - if _, ok := funcByName[funcDecl.Name.Name]; !ok { - continue - } - declByName[funcDecl.Name.Name] = funcDecl - f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) - i-- - } - - for _, fn := range funcs { - decl, ok := declByName[fn.Name] - if !ok { - // Not implemented! - decl = &dst.FuncDecl{ - Name: dst.NewIdent(fn.Name), - Type: &dst.FuncType{ - Func: true, - TypeParams: fn.Func.TypeParams, - Params: fn.Func.Params, - Results: fn.Func.Results, - Decs: fn.Func.Decs, - }, - Recv: &dst.FieldList{ - List: []*dst.Field{{ - Names: []*dst.Ident{dst.NewIdent("q")}, - Type: dst.NewIdent("*querier"), - }}, - }, - Decs: dst.FuncDeclDecorations{ - NodeDecs: dst.NodeDecs{ - Before: dst.EmptyLine, - After: dst.EmptyLine, - }, - }, - Body: &dst.BlockStmt{ - List: []dst.Stmt{ - &dst.ExprStmt{ - X: &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "panic", - }, - Args: []dst.Expr{ - &dst.BasicLit{ - Kind: token.STRING, - Value: "\"Not implemented\"", - }, - }, - }, - }, - }, - }, - } - } - f.Decls = append(f.Decls, decl) - } - - file, err := os.OpenFile("./dbauthz/dbauthz.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return xerrors.Errorf("open dbauthz: %w", err) - } - defer file.Close() - - // Required to preserve imports! - restorer := decorator.NewRestorerWithImports("dbauthz", guess.New()) - restored, err := restorer.RestoreFile(f) - if err != nil { - return xerrors.Errorf("restore dbauthz: %w", err) - } - err = format.Node(file, restorer.Fset, restored) - return err -} - -type storeMethod struct { - Name string - Func *dst.FuncType -} - -func readStoreInterface() ([]storeMethod, error) { - querier, err := os.ReadFile("./querier.go") - if err != nil { - return nil, xerrors.Errorf("read querier: %w", err) - } - f, err := decorator.Parse(querier) - if err != nil { - return nil, err - } - - var sqlcQuerier *dst.InterfaceType - for _, decl := range f.Decls { - genDecl, ok := decl.(*dst.GenDecl) - if !ok { - continue - } - - for _, spec := range genDecl.Specs { - typeSpec, ok := spec.(*dst.TypeSpec) - if !ok { - continue - } - if typeSpec.Name.Name != "sqlcQuerier" { - continue - } - sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType) - if !ok { - return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type) - } - break - } - } - if sqlcQuerier == nil { - return nil, xerrors.Errorf("sqlcQuerier not found") - } - funcs := []storeMethod{} - for _, method := range sqlcQuerier.Methods.List { - funcType, ok := method.Type.(*dst.FuncType) - if !ok { - continue - } - - for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} { - if t == nil { - continue - } - for _, f := range t.List { - ident, ok := f.Type.(*dst.Ident) - if !ok { - continue - } - if !ident.IsExported() { - continue - } - ident.Path = "github.com/coder/coder/coderd/database" - } - } - - funcs = append(funcs, storeMethod{ - Name: method.Names[0].Name, - Func: funcType, - }) - } - return funcs, nil -} diff --git a/coderd/database/gen/enum/main.go b/coderd/database/gen/enum/main.go deleted file mode 100644 index 960bac922a..0000000000 --- a/coderd/database/gen/enum/main.go +++ /dev/null @@ -1,120 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "fmt" - "os" - "os/exec" - "strings" - - "golang.org/x/xerrors" -) - -const header = `// Code generated by gen/enum. DO NOT EDIT. -package database -` - -func main() { - if err := run(); err != nil { - panic(err) - } -} - -func run() error { - dump, err := os.Open("dump.sql") - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "error: %s must be run in the database directory with dump.sql present\n", os.Args[0]) - return err - } - defer dump.Close() - - var uniqueConstraints []string - - s := bufio.NewScanner(dump) - query := "" - for s.Scan() { - line := strings.TrimSpace(s.Text()) - switch { - case strings.HasPrefix(line, "--"): - case line == "": - case strings.HasSuffix(line, ";"): - query += line - if isUniqueConstraint(query) { - uniqueConstraints = append(uniqueConstraints, query) - } - query = "" - default: - query += line + " " - } - } - if err = s.Err(); err != nil { - return err - } - - return writeContents("unique_constraint.go", uniqueConstraints, generateUniqueConstraints) -} - -func isUniqueConstraint(query string) bool { - return strings.Contains(query, "UNIQUE") -} - -func generateUniqueConstraints(queries []string) ([]byte, error) { - s := &bytes.Buffer{} - - _, _ = fmt.Fprint(s, header) - _, _ = fmt.Fprint(s, ` -// UniqueConstraint represents a named unique constraint on a table. -type UniqueConstraint string - -// UniqueConstraint enums. -const ( -`) - for _, query := range queries { - name := "" - switch { - case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"): - name = strings.Split(query, " ")[6] - case strings.Contains(query, "CREATE UNIQUE INDEX"): - name = strings.Split(query, " ")[3] - default: - return nil, xerrors.Errorf("unknown unique constraint format: %s", query) - } - _, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query) - } - _, _ = fmt.Fprint(s, ")\n") - - return s.Bytes(), nil -} - -func writeContents[T any](dest string, arg T, fn func(T) ([]byte, error)) error { - b, err := fn(arg) - if err != nil { - return err - } - err = os.WriteFile(dest, b, 0o600) - if err != nil { - return err - } - cmd := exec.Command("go", "run", "golang.org/x/tools/cmd/goimports@latest", "-w", dest) - return cmd.Run() -} - -func nameFromSnakeCase(s string) string { - var ret string - for _, ss := range strings.Split(s, "_") { - switch ss { - case "id": - ret += "ID" - case "ids": - ret += "IDs" - case "jwt": - ret += "JWT" - case "idx": - ret += "Index" - default: - ret += strings.Title(ss) - } - } - return ret -} diff --git a/coderd/database/gen/fake/main.go b/coderd/database/gen/fake/main.go deleted file mode 100644 index 2d399192fa..0000000000 --- a/coderd/database/gen/fake/main.go +++ /dev/null @@ -1,285 +0,0 @@ -package main - -import ( - "fmt" - "go/format" - "go/token" - "log" - "os" - "path" - - "github.com/dave/dst" - "github.com/dave/dst/decorator" - "github.com/dave/dst/decorator/resolver/goast" - "github.com/dave/dst/decorator/resolver/guess" - "golang.org/x/xerrors" -) - -func main() { - err := run() - if err != nil { - log.Fatal(err) - } -} - -func run() error { - funcs, err := readStoreInterface() - if err != nil { - return err - } - funcByName := map[string]struct{}{} - for _, f := range funcs { - funcByName[f.Name] = struct{}{} - } - declByName := map[string]*dst.FuncDecl{} - - dbfake, err := os.ReadFile("./dbfake/dbfake.go") - if err != nil { - return xerrors.Errorf("read dbfake: %w", err) - } - - // Required to preserve imports! - f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbfake", goast.New()).Parse(dbfake) - if err != nil { - return xerrors.Errorf("parse dbfake: %w", err) - } - - for i := 0; i < len(f.Decls); i++ { - funcDecl, ok := f.Decls[i].(*dst.FuncDecl) - if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { - continue - } - // Check if the receiver is the struct we're interested in - starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr) - if !ok { - continue - } - ident, ok := starExpr.X.(*dst.Ident) - if !ok || ident.Name != "fakeQuerier" { - continue - } - if _, ok := funcByName[funcDecl.Name.Name]; !ok { - continue - } - declByName[funcDecl.Name.Name] = funcDecl - f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) - i-- - } - - for _, fn := range funcs { - var bodyStmts []dst.Stmt - if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" { - /* - err := validateDatabaseType(arg) - if err != nil { - return database.User{}, err - } - */ - bodyStmts = append(bodyStmts, &dst.AssignStmt{ - Lhs: []dst.Expr{dst.NewIdent("err")}, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "validateDatabaseType", - }, - Args: []dst.Expr{dst.NewIdent("arg")}, - }, - }, - }) - returnStmt := &dst.ReturnStmt{ - Results: []dst.Expr{}, // Filled below. - } - bodyStmts = append(bodyStmts, &dst.IfStmt{ - Cond: &dst.BinaryExpr{ - X: dst.NewIdent("err"), - Op: token.NEQ, - Y: dst.NewIdent("nil"), - }, - Body: &dst.BlockStmt{ - List: []dst.Stmt{ - returnStmt, - }, - }, - Decs: dst.IfStmtDecorations{ - NodeDecs: dst.NodeDecs{ - After: dst.EmptyLine, - }, - }, - }) - for _, r := range fn.Func.Results.List { - switch typ := r.Type.(type) { - case *dst.StarExpr, *dst.ArrayType: - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil")) - case *dst.Ident: - if typ.Path != "" { - returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name))) - } else { - switch typ.Name { - case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr", - "int8", "int16", "int32", "int64", "int", - "byte", "rune", - "float32", "float64", - "complex64", "complex128": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0")) - case "string": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\"")) - case "bool": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false")) - case "error": - returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err")) - default: - panic(fmt.Sprintf("unknown ident: %#v", r.Type)) - } - } - default: - panic(fmt.Sprintf("unknown return type: %T", r.Type)) - } - } - } - decl, ok := declByName[fn.Name] - if !ok { - // Not implemented! - decl = &dst.FuncDecl{ - Name: dst.NewIdent(fn.Name), - Type: &dst.FuncType{ - Func: true, - TypeParams: fn.Func.TypeParams, - Params: fn.Func.Params, - Results: fn.Func.Results, - Decs: fn.Func.Decs, - }, - Recv: &dst.FieldList{ - List: []*dst.Field{{ - Names: []*dst.Ident{dst.NewIdent("q")}, - Type: dst.NewIdent("*fakeQuerier"), - }}, - }, - Decs: dst.FuncDeclDecorations{ - NodeDecs: dst.NodeDecs{ - Before: dst.EmptyLine, - After: dst.EmptyLine, - }, - }, - Body: &dst.BlockStmt{ - List: append(bodyStmts, &dst.ExprStmt{ - X: &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "panic", - }, - Args: []dst.Expr{ - &dst.BasicLit{ - Kind: token.STRING, - Value: "\"Not implemented\"", - }, - }, - }, - }), - }, - } - } - f.Decls = append(f.Decls, decl) - } - - file, err := os.OpenFile("./dbfake/dbfake.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return xerrors.Errorf("open dbfake: %w", err) - } - defer file.Close() - - // Required to preserve imports! - restorer := decorator.NewRestorerWithImports("dbfake", guess.New()) - restored, err := restorer.RestoreFile(f) - if err != nil { - return xerrors.Errorf("restore dbfake: %w", err) - } - err = format.Node(file, restorer.Fset, restored) - return err -} - -type storeMethod struct { - Name string - Func *dst.FuncType -} - -func readStoreInterface() ([]storeMethod, error) { - querier, err := os.ReadFile("./querier.go") - if err != nil { - return nil, xerrors.Errorf("read querier: %w", err) - } - f, err := decorator.Parse(querier) - if err != nil { - return nil, err - } - - var sqlcQuerier *dst.InterfaceType - for _, decl := range f.Decls { - genDecl, ok := decl.(*dst.GenDecl) - if !ok { - continue - } - - for _, spec := range genDecl.Specs { - typeSpec, ok := spec.(*dst.TypeSpec) - if !ok { - continue - } - if typeSpec.Name.Name != "sqlcQuerier" { - continue - } - sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType) - if !ok { - return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type) - } - break - } - } - if sqlcQuerier == nil { - return nil, xerrors.Errorf("sqlcQuerier not found") - } - funcs := []storeMethod{} - for _, method := range sqlcQuerier.Methods.List { - funcType, ok := method.Type.(*dst.FuncType) - if !ok { - continue - } - - for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} { - if t == nil { - continue - } - var ( - ident *dst.Ident - ok bool - ) - for _, f := range t.List { - switch typ := f.Type.(type) { - case *dst.StarExpr: - ident, ok = typ.X.(*dst.Ident) - if !ok { - continue - } - case *dst.ArrayType: - ident, ok = typ.Elt.(*dst.Ident) - if !ok { - continue - } - case *dst.Ident: - ident = typ - default: - continue - } - if !ident.IsExported() { - continue - } - ident.Path = "github.com/coder/coder/coderd/database" - } - } - - funcs = append(funcs, storeMethod{ - Name: method.Names[0].Name, - Func: funcType, - }) - } - return funcs, nil -} diff --git a/coderd/database/gen/metrics/main.go b/coderd/database/gen/metrics/main.go deleted file mode 100644 index ded12c0eae..0000000000 --- a/coderd/database/gen/metrics/main.go +++ /dev/null @@ -1,210 +0,0 @@ -package main - -import ( - "fmt" - "go/format" - "go/token" - "log" - "os" - "strings" - - "github.com/dave/dst" - "github.com/dave/dst/decorator" - "github.com/dave/dst/decorator/resolver/goast" - "github.com/dave/dst/decorator/resolver/guess" - "golang.org/x/xerrors" -) - -func main() { - err := run() - if err != nil { - log.Fatal(err) - } -} - -func run() error { - funcs, err := readStoreInterface() - if err != nil { - return err - } - funcByName := map[string]struct{}{} - for _, f := range funcs { - funcByName[f.Name] = struct{}{} - } - declByName := map[string]*dst.FuncDecl{} - - dbmetrics, err := os.ReadFile("./dbmetrics/dbmetrics.go") - if err != nil { - return xerrors.Errorf("read dbfake: %w", err) - } - - // Required to preserve imports! - f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbmetrics", goast.New()).Parse(dbmetrics) - if err != nil { - return xerrors.Errorf("parse dbfake: %w", err) - } - - for i := 0; i < len(f.Decls); i++ { - funcDecl, ok := f.Decls[i].(*dst.FuncDecl) - if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { - continue - } - // Check if the receiver is the struct we're interested in - _, ok = funcDecl.Recv.List[0].Type.(*dst.Ident) - if !ok { - continue - } - if _, ok := funcByName[funcDecl.Name.Name]; !ok { - continue - } - declByName[funcDecl.Name.Name] = funcDecl - f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) - i-- - } - - for _, fn := range funcs { - decl, ok := declByName[fn.Name] - if !ok { - params := make([]string, 0) - if fn.Func.Params != nil { - for _, p := range fn.Func.Params.List { - for _, name := range p.Names { - params = append(params, name.Name) - } - } - } - returns := make([]string, 0) - if fn.Func.Results != nil { - for i := range fn.Func.Results.List { - returns = append(returns, fmt.Sprintf("r%d", i)) - } - } - - code := fmt.Sprintf(` -package stub - -func stub() { - start := time.Now() - %s := m.s.%s(%s) - m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds()) - return %s -} -`, strings.Join(returns, ","), fn.Name, strings.Join(params, ","), fn.Name, strings.Join(returns, ",")) - file, err := decorator.Parse(code) - if err != nil { - return xerrors.Errorf("parse code: %w", err) - } - stmt, ok := file.Decls[0].(*dst.FuncDecl) - if !ok { - return xerrors.Errorf("not ok %T", file.Decls[0]) - } - - // Not implemented! - // When a function isn't implemented, we automatically stub it! - decl = &dst.FuncDecl{ - Name: dst.NewIdent(fn.Name), - Type: fn.Func, - Recv: &dst.FieldList{ - List: []*dst.Field{{ - Names: []*dst.Ident{dst.NewIdent("m")}, - Type: dst.NewIdent("metricsStore"), - }}, - }, - Decs: dst.FuncDeclDecorations{ - NodeDecs: dst.NodeDecs{ - Before: dst.EmptyLine, - After: dst.EmptyLine, - }, - }, - Body: stmt.Body, - } - } - f.Decls = append(f.Decls, decl) - } - - file, err := os.OpenFile("./dbmetrics/dbmetrics.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return xerrors.Errorf("open dbfake: %w", err) - } - defer file.Close() - - // Required to preserve imports! - restorer := decorator.NewRestorerWithImports("dbmetrics", guess.New()) - restored, err := restorer.RestoreFile(f) - if err != nil { - return xerrors.Errorf("restore dbfake: %w", err) - } - err = format.Node(file, restorer.Fset, restored) - return err -} - -type storeMethod struct { - Name string - Func *dst.FuncType -} - -func readStoreInterface() ([]storeMethod, error) { - querier, err := os.ReadFile("./querier.go") - if err != nil { - return nil, xerrors.Errorf("read querier: %w", err) - } - f, err := decorator.Parse(querier) - if err != nil { - return nil, err - } - - var sqlcQuerier *dst.InterfaceType - for _, decl := range f.Decls { - genDecl, ok := decl.(*dst.GenDecl) - if !ok { - continue - } - - for _, spec := range genDecl.Specs { - typeSpec, ok := spec.(*dst.TypeSpec) - if !ok { - continue - } - if typeSpec.Name.Name != "sqlcQuerier" { - continue - } - sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType) - if !ok { - return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type) - } - break - } - } - if sqlcQuerier == nil { - return nil, xerrors.Errorf("sqlcQuerier not found") - } - funcs := []storeMethod{} - for _, method := range sqlcQuerier.Methods.List { - funcType, ok := method.Type.(*dst.FuncType) - if !ok { - continue - } - - for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} { - if t == nil { - continue - } - for _, f := range t.List { - ident, ok := f.Type.(*dst.Ident) - if !ok { - continue - } - if !ident.IsExported() { - continue - } - ident.Path = "github.com/coder/coder/coderd/database" - } - } - - funcs = append(funcs, storeMethod{ - Name: method.Names[0].Name, - Func: funcType, - }) - } - return funcs, nil -} diff --git a/coderd/database/generate.sh b/coderd/database/generate.sh index f94ba151c0..ffbf290949 100755 --- a/coderd/database/generate.sh +++ b/coderd/database/generate.sh @@ -56,16 +56,5 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") go mod download go run golang.org/x/tools/cmd/goimports@latest -w queries.sql.go - # Generate enums (e.g. unique constraints). - go run gen/enum/main.go - - # Generate the database fake! - go run gen/fake/main.go - go run golang.org/x/tools/cmd/goimports@latest -w ./dbfake/dbfake.go - - go run gen/authz/main.go - go run golang.org/x/tools/cmd/goimports@latest -w ./dbauthz/dbauthz.go - - go run gen/metrics/main.go - go run golang.org/x/tools/cmd/goimports@latest -w ./dbmetrics/dbmetrics.go + go run ../../scripts/dbgen/main.go ) diff --git a/coderd/database/lock.go b/coderd/database/lock.go deleted file mode 100644 index 56675282f9..0000000000 --- a/coderd/database/lock.go +++ /dev/null @@ -1,8 +0,0 @@ -package database - -// Well-known lock IDs for lock functions in the database. These should not -// change. If locks are deprecated, they should be kept to avoid reusing the -// same ID. -const ( - LockIDDeploymentSetup = iota + 1 -) diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go new file mode 100644 index 0000000000..8f87892510 --- /dev/null +++ b/scripts/dbgen/main.go @@ -0,0 +1,542 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "go/format" + "go/token" + "os" + "path" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/dave/dst" + "github.com/dave/dst/decorator" + "github.com/dave/dst/decorator/resolver/goast" + "github.com/dave/dst/decorator/resolver/guess" + "golang.org/x/tools/imports" + "golang.org/x/xerrors" +) + +var ( + funcs []querierFunction + funcByName map[string]struct{} +) + +func init() { + var err error + funcs, err = readQuerierFunctions() + if err != nil { + panic(err) + } + funcByName = map[string]struct{}{} + for _, f := range funcs { + funcByName[f.Name] = struct{}{} + } +} + +func main() { + err := run() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %s\n", err) + os.Exit(1) + } +} + +func run() error { + localPath, err := localFilePath() + if err != nil { + return err + } + databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") + + err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbfake", "dbfake.go"), "q", "fakeQuerier", func(params stubParams) string { + return `panic("not implemented")` + }) + if err != nil { + return xerrors.Errorf("stub dbfake: %w", err) + } + + err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmetrics", "dbmetrics.go"), "m", "metricsStore", func(params stubParams) string { + return fmt.Sprintf(` +start := time.Now() +%s := m.s.%s(%s) +m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds()) +return %s +`, params.Returns, params.FuncName, params.Parameters, params.FuncName, params.Returns) + }) + if err != nil { + return xerrors.Errorf("stub dbmetrics: %w", err) + } + + err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbauthz", "dbauthz.go"), "q", "querier", func(params stubParams) string { + return `panic("not implemented")` + }) + if err != nil { + return xerrors.Errorf("stub dbauthz: %w", err) + } + + err = generateUniqueConstraints() + if err != nil { + return xerrors.Errorf("generate unique constraints: %w", err) + } + + return nil +} + +// generateUniqueConstraints generates the UniqueConstraint enum. +func generateUniqueConstraints() error { + localPath, err := localFilePath() + if err != nil { + return err + } + databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") + + dump, err := os.Open(filepath.Join(databasePath, "dump.sql")) + if err != nil { + return err + } + defer dump.Close() + + var uniqueConstraints []string + dumpScanner := bufio.NewScanner(dump) + query := "" + for dumpScanner.Scan() { + line := strings.TrimSpace(dumpScanner.Text()) + switch { + case strings.HasPrefix(line, "--"): + case line == "": + case strings.HasSuffix(line, ";"): + query += line + if strings.Contains(query, "UNIQUE") { + uniqueConstraints = append(uniqueConstraints, query) + } + query = "" + default: + query += line + " " + } + } + if err = dumpScanner.Err(); err != nil { + return err + } + + s := &bytes.Buffer{} + + _, _ = fmt.Fprint(s, `// Code generated by gen/enum. DO NOT EDIT. +package database +`) + _, _ = fmt.Fprint(s, ` +// UniqueConstraint represents a named unique constraint on a table. +type UniqueConstraint string + +// UniqueConstraint enums. +const ( +`) + for _, query := range uniqueConstraints { + name := "" + switch { + case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"): + name = strings.Split(query, " ")[6] + case strings.Contains(query, "CREATE UNIQUE INDEX"): + name = strings.Split(query, " ")[3] + default: + return xerrors.Errorf("unknown unique constraint format: %s", query) + } + _, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query) + } + _, _ = fmt.Fprint(s, ")\n") + + outputPath := filepath.Join(databasePath, "unique_constraint.go") + + data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{ + Comments: true, + }) + if err != nil { + return err + } + return os.WriteFile(outputPath, data, 0o600) +} + +type stubParams struct { + FuncName string + Parameters string + Returns string +} + +// orderAndStubDatabaseFunctions orders the functions in the file and stubs them. +// This is useful for when we want to add a new function to the database and +// we want to make sure that it's ordered correctly. +// +// querierFuncs is a list of functions that are in the database. +// file is the path to the file that contains all the functions. +// structName is the name of the struct that contains the functions. +// stub is a string that will be used to stub the functions. +func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub func(params stubParams) string) error { + declByName := map[string]*dst.FuncDecl{} + packageName := filepath.Base(filepath.Dir(filePath)) + + contents, err := os.ReadFile(filePath) + if err != nil { + return xerrors.Errorf("read dbfake: %w", err) + } + + // Required to preserve imports! + f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents) + if err != nil { + return xerrors.Errorf("parse dbfake: %w", err) + } + + pointer := false + for i := 0; i < len(f.Decls); i++ { + funcDecl, ok := f.Decls[i].(*dst.FuncDecl) + if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 { + continue + } + + var ident *dst.Ident + switch t := funcDecl.Recv.List[0].Type.(type) { + case *dst.Ident: + ident = t + case *dst.StarExpr: + ident, ok = t.X.(*dst.Ident) + if !ok { + continue + } + pointer = true + } + if ident == nil || ident.Name != structName { + continue + } + if _, ok := funcByName[funcDecl.Name.Name]; !ok { + continue + } + declByName[funcDecl.Name.Name] = funcDecl + f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) + i-- + } + + for _, fn := range funcs { + var bodyStmts []dst.Stmt + if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" { + /* + err := validateDatabaseType(arg) + if err != nil { + return database.User{}, err + } + */ + bodyStmts = append(bodyStmts, &dst.AssignStmt{ + Lhs: []dst.Expr{dst.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []dst.Expr{ + &dst.CallExpr{ + Fun: &dst.Ident{ + Name: "validateDatabaseType", + }, + Args: []dst.Expr{dst.NewIdent("arg")}, + }, + }, + }) + returnStmt := &dst.ReturnStmt{ + Results: []dst.Expr{}, // Filled below. + } + bodyStmts = append(bodyStmts, &dst.IfStmt{ + Cond: &dst.BinaryExpr{ + X: dst.NewIdent("err"), + Op: token.NEQ, + Y: dst.NewIdent("nil"), + }, + Body: &dst.BlockStmt{ + List: []dst.Stmt{ + returnStmt, + }, + }, + Decs: dst.IfStmtDecorations{ + NodeDecs: dst.NodeDecs{ + After: dst.EmptyLine, + }, + }, + }) + for _, r := range fn.Func.Results.List { + switch typ := r.Type.(type) { + case *dst.StarExpr, *dst.ArrayType: + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil")) + case *dst.Ident: + if typ.Path != "" { + returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name))) + } else { + switch typ.Name { + case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr", + "int8", "int16", "int32", "int64", "int", + "byte", "rune", + "float32", "float64", + "complex64", "complex128": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0")) + case "string": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\"")) + case "bool": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false")) + case "error": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err")) + default: + panic(fmt.Sprintf("unknown ident: %#v", r.Type)) + } + } + default: + panic(fmt.Sprintf("unknown return type: %T", r.Type)) + } + } + } + decl, ok := declByName[fn.Name] + if !ok { + typeName := structName + if pointer { + typeName = "*" + typeName + } + params := make([]string, 0) + if fn.Func.Params != nil { + for _, p := range fn.Func.Params.List { + for _, name := range p.Names { + params = append(params, name.Name) + } + } + } + returns := make([]string, 0) + if fn.Func.Results != nil { + for i := range fn.Func.Results.List { + returns = append(returns, fmt.Sprintf("r%d", i)) + } + } + + funcDecl, err := compileFuncDecl(stub(stubParams{ + FuncName: fn.Name, + Parameters: strings.Join(params, ","), + Returns: strings.Join(returns, ","), + })) + if err != nil { + return xerrors.Errorf("compile func decl: %w", err) + } + + // Not implemented! + decl = &dst.FuncDecl{ + Name: dst.NewIdent(fn.Name), + Type: &dst.FuncType{ + Func: true, + TypeParams: fn.Func.TypeParams, + Params: fn.Func.Params, + Results: fn.Func.Results, + Decs: fn.Func.Decs, + }, + Recv: &dst.FieldList{ + List: []*dst.Field{{ + Names: []*dst.Ident{dst.NewIdent(receiver)}, + Type: dst.NewIdent(typeName), + }}, + }, + Decs: dst.FuncDeclDecorations{ + NodeDecs: dst.NodeDecs{ + Before: dst.EmptyLine, + After: dst.EmptyLine, + }, + }, + Body: &dst.BlockStmt{ + List: append(bodyStmts, funcDecl.Body.List...), + }, + } + } + if ok { + for i, pm := range fn.Func.Params.List { + if len(decl.Type.Params.List) < i+1 { + decl.Type.Params.List = append(decl.Type.Params.List, pm) + } + if !reflect.DeepEqual(decl.Type.Params.List[i].Type, pm.Type) { + decl.Type.Params.List[i].Type = pm.Type + } + } + + for i, res := range fn.Func.Results.List { + if len(decl.Type.Results.List) < i+1 { + decl.Type.Results.List = append(decl.Type.Results.List, res) + } + if !reflect.DeepEqual(decl.Type.Results.List[i].Type, res.Type) { + decl.Type.Results.List[i].Type = res.Type + } + } + } + f.Decls = append(f.Decls, decl) + } + + // Required to preserve imports! + restorer := decorator.NewRestorerWithImports(packageName, guess.New()) + restored, err := restorer.RestoreFile(f) + if err != nil { + return xerrors.Errorf("restore package: %w", err) + } + var buf bytes.Buffer + err = format.Node(&buf, restorer.Fset, restored) + if err != nil { + return xerrors.Errorf("format package: %w", err) + } + data, err := imports.Process(filePath, buf.Bytes(), &imports.Options{ + Comments: true, + FormatOnly: true, + }) + if err != nil { + return xerrors.Errorf("process imports: %w", err) + } + return os.WriteFile(filePath, data, 0o600) +} + +// compileFuncDecl extracts the function declaration from the given code. +func compileFuncDecl(code string) (*dst.FuncDecl, error) { + f, err := decorator.Parse(fmt.Sprintf(`package stub + +func stub() { + %s +}`, strings.TrimSpace(code))) + if err != nil { + return nil, err + } + if len(f.Decls) != 1 { + return nil, xerrors.Errorf("expected 1 decl, got %d", len(f.Decls)) + } + decl, ok := f.Decls[0].(*dst.FuncDecl) + if !ok { + return nil, xerrors.Errorf("expected func decl, got %T", f.Decls[0]) + } + return decl, nil +} + +type querierFunction struct { + // Name is the name of the function. Like "GetUserByID" + Name string + // Func is the AST representation of a function. + Func *dst.FuncType +} + +// readQuerierFunctions reads the functions from coderd/database/querier.go +func readQuerierFunctions() ([]querierFunction, error) { + localPath, err := localFilePath() + if err != nil { + return nil, err + } + querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go") + + querierData, err := os.ReadFile(querierPath) + if err != nil { + return nil, xerrors.Errorf("read querier: %w", err) + } + f, err := decorator.Parse(querierData) + if err != nil { + return nil, err + } + + var querier *dst.InterfaceType + for _, decl := range f.Decls { + genDecl, ok := decl.(*dst.GenDecl) + if !ok { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*dst.TypeSpec) + if !ok { + continue + } + // This is the name of the interface. If that ever changes, + // this will need to be updated. + if typeSpec.Name.Name != "sqlcQuerier" { + continue + } + querier, ok = typeSpec.Type.(*dst.InterfaceType) + if !ok { + return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type) + } + break + } + } + if querier == nil { + return nil, xerrors.Errorf("querier not found") + } + funcs := []querierFunction{} + for _, method := range querier.Methods.List { + funcType, ok := method.Type.(*dst.FuncType) + if !ok { + continue + } + + for _, t := range []*dst.FieldList{funcType.Params, funcType.Results, funcType.TypeParams} { + if t == nil { + continue + } + for _, f := range t.List { + var ident *dst.Ident + switch t := f.Type.(type) { + case *dst.Ident: + ident = t + case *dst.StarExpr: + ident, ok = t.X.(*dst.Ident) + if !ok { + continue + } + case *dst.SelectorExpr: + ident, ok = t.X.(*dst.Ident) + if !ok { + continue + } + case *dst.ArrayType: + ident, ok = t.Elt.(*dst.Ident) + if !ok { + continue + } + } + if ident == nil { + continue + } + // If the type is exported then we should be able to find it + // in the database package! + if !ident.IsExported() { + continue + } + ident.Path = "github.com/coder/coder/coderd/database" + } + } + + funcs = append(funcs, querierFunction{ + Name: method.Names[0].Name, + Func: funcType, + }) + } + return funcs, nil +} + +// localFilePath returns the location of `main.go` in the dbgen package. +func localFilePath() (string, error) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + return "", xerrors.Errorf("failed to get caller") + } + return filename, nil +} + +// nameFromSnakeCase converts snake_case to CamelCase. +func nameFromSnakeCase(s string) string { + var ret string + for _, ss := range strings.Split(s, "_") { + switch ss { + case "id": + ret += "ID" + case "ids": + ret += "IDs" + case "jwt": + ret += "JWT" + case "idx": + ret += "Index" + default: + ret += strings.Title(ss) + } + } + return ret +}