chore: merge database gen scripts (#8073)

* chore: merge database gen scripts

* Fix type params gen

* Merge enum into dbgen
This commit is contained in:
Kyle Carberry 2023-06-20 11:24:33 -05:00 committed by GitHub
parent f3b2009499
commit bbb0fab1de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 551 additions and 835 deletions

View File

@ -486,7 +486,7 @@ coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/dat
go run ./coderd/database/gen/dump/main.go
# Generates Go code for querying the database.
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go coderd/database/gen/fake/main.go
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql)
./coderd/database/generate.sh

View File

@ -18,6 +18,13 @@ import (
"golang.org/x/xerrors"
)
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept to avoid reusing the
// same ID.
const (
LockIDDeploymentSetup = iota + 1
)
// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {

View File

@ -1,199 +0,0 @@
package main
import (
"go/format"
"go/token"
"log"
"os"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/xerrors"
)
func main() {
err := run()
if err != nil {
log.Fatal(err)
}
}
func run() error {
funcs, err := readStoreInterface()
if err != nil {
return err
}
funcByName := map[string]struct{}{}
for _, f := range funcs {
funcByName[f.Name] = struct{}{}
}
declByName := map[string]*dst.FuncDecl{}
dbauthz, err := os.ReadFile("./dbauthz/dbauthz.go")
if err != nil {
return xerrors.Errorf("read dbauthz: %w", err)
}
// Required to preserve imports!
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbauthz", goast.New()).Parse(dbauthz)
if err != nil {
return xerrors.Errorf("parse dbauthz: %w", err)
}
for i := 0; i < len(f.Decls); i++ {
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
// Check if the receiver is the struct we're interested in
starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr)
if !ok {
continue
}
ident, ok := starExpr.X.(*dst.Ident)
if !ok || ident.Name != "querier" {
continue
}
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
continue
}
declByName[funcDecl.Name.Name] = funcDecl
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
i--
}
for _, fn := range funcs {
decl, ok := declByName[fn.Name]
if !ok {
// Not implemented!
decl = &dst.FuncDecl{
Name: dst.NewIdent(fn.Name),
Type: &dst.FuncType{
Func: true,
TypeParams: fn.Func.TypeParams,
Params: fn.Func.Params,
Results: fn.Func.Results,
Decs: fn.Func.Decs,
},
Recv: &dst.FieldList{
List: []*dst.Field{{
Names: []*dst.Ident{dst.NewIdent("q")},
Type: dst.NewIdent("*querier"),
}},
},
Decs: dst.FuncDeclDecorations{
NodeDecs: dst.NodeDecs{
Before: dst.EmptyLine,
After: dst.EmptyLine,
},
},
Body: &dst.BlockStmt{
List: []dst.Stmt{
&dst.ExprStmt{
X: &dst.CallExpr{
Fun: &dst.Ident{
Name: "panic",
},
Args: []dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: "\"Not implemented\"",
},
},
},
},
},
},
}
}
f.Decls = append(f.Decls, decl)
}
file, err := os.OpenFile("./dbauthz/dbauthz.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return xerrors.Errorf("open dbauthz: %w", err)
}
defer file.Close()
// Required to preserve imports!
restorer := decorator.NewRestorerWithImports("dbauthz", guess.New())
restored, err := restorer.RestoreFile(f)
if err != nil {
return xerrors.Errorf("restore dbauthz: %w", err)
}
err = format.Node(file, restorer.Fset, restored)
return err
}
type storeMethod struct {
Name string
Func *dst.FuncType
}
func readStoreInterface() ([]storeMethod, error) {
querier, err := os.ReadFile("./querier.go")
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
}
f, err := decorator.Parse(querier)
if err != nil {
return nil, err
}
var sqlcQuerier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
if !ok {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*dst.TypeSpec)
if !ok {
continue
}
if typeSpec.Name.Name != "sqlcQuerier" {
continue
}
sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType)
if !ok {
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
}
break
}
}
if sqlcQuerier == nil {
return nil, xerrors.Errorf("sqlcQuerier not found")
}
funcs := []storeMethod{}
for _, method := range sqlcQuerier.Methods.List {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
}
for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} {
if t == nil {
continue
}
for _, f := range t.List {
ident, ok := f.Type.(*dst.Ident)
if !ok {
continue
}
if !ident.IsExported() {
continue
}
ident.Path = "github.com/coder/coder/coderd/database"
}
}
funcs = append(funcs, storeMethod{
Name: method.Names[0].Name,
Func: funcType,
})
}
return funcs, nil
}

View File

@ -1,120 +0,0 @@
package main
import (
"bufio"
"bytes"
"fmt"
"os"
"os/exec"
"strings"
"golang.org/x/xerrors"
)
const header = `// Code generated by gen/enum. DO NOT EDIT.
package database
`
func main() {
if err := run(); err != nil {
panic(err)
}
}
func run() error {
dump, err := os.Open("dump.sql")
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "error: %s must be run in the database directory with dump.sql present\n", os.Args[0])
return err
}
defer dump.Close()
var uniqueConstraints []string
s := bufio.NewScanner(dump)
query := ""
for s.Scan() {
line := strings.TrimSpace(s.Text())
switch {
case strings.HasPrefix(line, "--"):
case line == "":
case strings.HasSuffix(line, ";"):
query += line
if isUniqueConstraint(query) {
uniqueConstraints = append(uniqueConstraints, query)
}
query = ""
default:
query += line + " "
}
}
if err = s.Err(); err != nil {
return err
}
return writeContents("unique_constraint.go", uniqueConstraints, generateUniqueConstraints)
}
func isUniqueConstraint(query string) bool {
return strings.Contains(query, "UNIQUE")
}
func generateUniqueConstraints(queries []string) ([]byte, error) {
s := &bytes.Buffer{}
_, _ = fmt.Fprint(s, header)
_, _ = fmt.Fprint(s, `
// UniqueConstraint represents a named unique constraint on a table.
type UniqueConstraint string
// UniqueConstraint enums.
const (
`)
for _, query := range queries {
name := ""
switch {
case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
name = strings.Split(query, " ")[6]
case strings.Contains(query, "CREATE UNIQUE INDEX"):
name = strings.Split(query, " ")[3]
default:
return nil, xerrors.Errorf("unknown unique constraint format: %s", query)
}
_, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
}
_, _ = fmt.Fprint(s, ")\n")
return s.Bytes(), nil
}
func writeContents[T any](dest string, arg T, fn func(T) ([]byte, error)) error {
b, err := fn(arg)
if err != nil {
return err
}
err = os.WriteFile(dest, b, 0o600)
if err != nil {
return err
}
cmd := exec.Command("go", "run", "golang.org/x/tools/cmd/goimports@latest", "-w", dest)
return cmd.Run()
}
func nameFromSnakeCase(s string) string {
var ret string
for _, ss := range strings.Split(s, "_") {
switch ss {
case "id":
ret += "ID"
case "ids":
ret += "IDs"
case "jwt":
ret += "JWT"
case "idx":
ret += "Index"
default:
ret += strings.Title(ss)
}
}
return ret
}

View File

@ -1,285 +0,0 @@
package main
import (
"fmt"
"go/format"
"go/token"
"log"
"os"
"path"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/xerrors"
)
func main() {
err := run()
if err != nil {
log.Fatal(err)
}
}
func run() error {
funcs, err := readStoreInterface()
if err != nil {
return err
}
funcByName := map[string]struct{}{}
for _, f := range funcs {
funcByName[f.Name] = struct{}{}
}
declByName := map[string]*dst.FuncDecl{}
dbfake, err := os.ReadFile("./dbfake/dbfake.go")
if err != nil {
return xerrors.Errorf("read dbfake: %w", err)
}
// Required to preserve imports!
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbfake", goast.New()).Parse(dbfake)
if err != nil {
return xerrors.Errorf("parse dbfake: %w", err)
}
for i := 0; i < len(f.Decls); i++ {
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
// Check if the receiver is the struct we're interested in
starExpr, ok := funcDecl.Recv.List[0].Type.(*dst.StarExpr)
if !ok {
continue
}
ident, ok := starExpr.X.(*dst.Ident)
if !ok || ident.Name != "fakeQuerier" {
continue
}
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
continue
}
declByName[funcDecl.Name.Name] = funcDecl
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
i--
}
for _, fn := range funcs {
var bodyStmts []dst.Stmt
if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" {
/*
err := validateDatabaseType(arg)
if err != nil {
return database.User{}, err
}
*/
bodyStmts = append(bodyStmts, &dst.AssignStmt{
Lhs: []dst.Expr{dst.NewIdent("err")},
Tok: token.DEFINE,
Rhs: []dst.Expr{
&dst.CallExpr{
Fun: &dst.Ident{
Name: "validateDatabaseType",
},
Args: []dst.Expr{dst.NewIdent("arg")},
},
},
})
returnStmt := &dst.ReturnStmt{
Results: []dst.Expr{}, // Filled below.
}
bodyStmts = append(bodyStmts, &dst.IfStmt{
Cond: &dst.BinaryExpr{
X: dst.NewIdent("err"),
Op: token.NEQ,
Y: dst.NewIdent("nil"),
},
Body: &dst.BlockStmt{
List: []dst.Stmt{
returnStmt,
},
},
Decs: dst.IfStmtDecorations{
NodeDecs: dst.NodeDecs{
After: dst.EmptyLine,
},
},
})
for _, r := range fn.Func.Results.List {
switch typ := r.Type.(type) {
case *dst.StarExpr, *dst.ArrayType:
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil"))
case *dst.Ident:
if typ.Path != "" {
returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
} else {
switch typ.Name {
case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
"int8", "int16", "int32", "int64", "int",
"byte", "rune",
"float32", "float64",
"complex64", "complex128":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0"))
case "string":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\""))
case "bool":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false"))
case "error":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err"))
default:
panic(fmt.Sprintf("unknown ident: %#v", r.Type))
}
}
default:
panic(fmt.Sprintf("unknown return type: %T", r.Type))
}
}
}
decl, ok := declByName[fn.Name]
if !ok {
// Not implemented!
decl = &dst.FuncDecl{
Name: dst.NewIdent(fn.Name),
Type: &dst.FuncType{
Func: true,
TypeParams: fn.Func.TypeParams,
Params: fn.Func.Params,
Results: fn.Func.Results,
Decs: fn.Func.Decs,
},
Recv: &dst.FieldList{
List: []*dst.Field{{
Names: []*dst.Ident{dst.NewIdent("q")},
Type: dst.NewIdent("*fakeQuerier"),
}},
},
Decs: dst.FuncDeclDecorations{
NodeDecs: dst.NodeDecs{
Before: dst.EmptyLine,
After: dst.EmptyLine,
},
},
Body: &dst.BlockStmt{
List: append(bodyStmts, &dst.ExprStmt{
X: &dst.CallExpr{
Fun: &dst.Ident{
Name: "panic",
},
Args: []dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: "\"Not implemented\"",
},
},
},
}),
},
}
}
f.Decls = append(f.Decls, decl)
}
file, err := os.OpenFile("./dbfake/dbfake.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return xerrors.Errorf("open dbfake: %w", err)
}
defer file.Close()
// Required to preserve imports!
restorer := decorator.NewRestorerWithImports("dbfake", guess.New())
restored, err := restorer.RestoreFile(f)
if err != nil {
return xerrors.Errorf("restore dbfake: %w", err)
}
err = format.Node(file, restorer.Fset, restored)
return err
}
type storeMethod struct {
Name string
Func *dst.FuncType
}
func readStoreInterface() ([]storeMethod, error) {
querier, err := os.ReadFile("./querier.go")
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
}
f, err := decorator.Parse(querier)
if err != nil {
return nil, err
}
var sqlcQuerier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
if !ok {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*dst.TypeSpec)
if !ok {
continue
}
if typeSpec.Name.Name != "sqlcQuerier" {
continue
}
sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType)
if !ok {
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
}
break
}
}
if sqlcQuerier == nil {
return nil, xerrors.Errorf("sqlcQuerier not found")
}
funcs := []storeMethod{}
for _, method := range sqlcQuerier.Methods.List {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
}
for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} {
if t == nil {
continue
}
var (
ident *dst.Ident
ok bool
)
for _, f := range t.List {
switch typ := f.Type.(type) {
case *dst.StarExpr:
ident, ok = typ.X.(*dst.Ident)
if !ok {
continue
}
case *dst.ArrayType:
ident, ok = typ.Elt.(*dst.Ident)
if !ok {
continue
}
case *dst.Ident:
ident = typ
default:
continue
}
if !ident.IsExported() {
continue
}
ident.Path = "github.com/coder/coder/coderd/database"
}
}
funcs = append(funcs, storeMethod{
Name: method.Names[0].Name,
Func: funcType,
})
}
return funcs, nil
}

View File

@ -1,210 +0,0 @@
package main
import (
"fmt"
"go/format"
"go/token"
"log"
"os"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/xerrors"
)
func main() {
err := run()
if err != nil {
log.Fatal(err)
}
}
func run() error {
funcs, err := readStoreInterface()
if err != nil {
return err
}
funcByName := map[string]struct{}{}
for _, f := range funcs {
funcByName[f.Name] = struct{}{}
}
declByName := map[string]*dst.FuncDecl{}
dbmetrics, err := os.ReadFile("./dbmetrics/dbmetrics.go")
if err != nil {
return xerrors.Errorf("read dbfake: %w", err)
}
// Required to preserve imports!
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), "dbmetrics", goast.New()).Parse(dbmetrics)
if err != nil {
return xerrors.Errorf("parse dbfake: %w", err)
}
for i := 0; i < len(f.Decls); i++ {
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
// Check if the receiver is the struct we're interested in
_, ok = funcDecl.Recv.List[0].Type.(*dst.Ident)
if !ok {
continue
}
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
continue
}
declByName[funcDecl.Name.Name] = funcDecl
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
i--
}
for _, fn := range funcs {
decl, ok := declByName[fn.Name]
if !ok {
params := make([]string, 0)
if fn.Func.Params != nil {
for _, p := range fn.Func.Params.List {
for _, name := range p.Names {
params = append(params, name.Name)
}
}
}
returns := make([]string, 0)
if fn.Func.Results != nil {
for i := range fn.Func.Results.List {
returns = append(returns, fmt.Sprintf("r%d", i))
}
}
code := fmt.Sprintf(`
package stub
func stub() {
start := time.Now()
%s := m.s.%s(%s)
m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds())
return %s
}
`, strings.Join(returns, ","), fn.Name, strings.Join(params, ","), fn.Name, strings.Join(returns, ","))
file, err := decorator.Parse(code)
if err != nil {
return xerrors.Errorf("parse code: %w", err)
}
stmt, ok := file.Decls[0].(*dst.FuncDecl)
if !ok {
return xerrors.Errorf("not ok %T", file.Decls[0])
}
// Not implemented!
// When a function isn't implemented, we automatically stub it!
decl = &dst.FuncDecl{
Name: dst.NewIdent(fn.Name),
Type: fn.Func,
Recv: &dst.FieldList{
List: []*dst.Field{{
Names: []*dst.Ident{dst.NewIdent("m")},
Type: dst.NewIdent("metricsStore"),
}},
},
Decs: dst.FuncDeclDecorations{
NodeDecs: dst.NodeDecs{
Before: dst.EmptyLine,
After: dst.EmptyLine,
},
},
Body: stmt.Body,
}
}
f.Decls = append(f.Decls, decl)
}
file, err := os.OpenFile("./dbmetrics/dbmetrics.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return xerrors.Errorf("open dbfake: %w", err)
}
defer file.Close()
// Required to preserve imports!
restorer := decorator.NewRestorerWithImports("dbmetrics", guess.New())
restored, err := restorer.RestoreFile(f)
if err != nil {
return xerrors.Errorf("restore dbfake: %w", err)
}
err = format.Node(file, restorer.Fset, restored)
return err
}
type storeMethod struct {
Name string
Func *dst.FuncType
}
func readStoreInterface() ([]storeMethod, error) {
querier, err := os.ReadFile("./querier.go")
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
}
f, err := decorator.Parse(querier)
if err != nil {
return nil, err
}
var sqlcQuerier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
if !ok {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*dst.TypeSpec)
if !ok {
continue
}
if typeSpec.Name.Name != "sqlcQuerier" {
continue
}
sqlcQuerier, ok = typeSpec.Type.(*dst.InterfaceType)
if !ok {
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
}
break
}
}
if sqlcQuerier == nil {
return nil, xerrors.Errorf("sqlcQuerier not found")
}
funcs := []storeMethod{}
for _, method := range sqlcQuerier.Methods.List {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
}
for _, t := range []*dst.FieldList{funcType.Params, funcType.Results} {
if t == nil {
continue
}
for _, f := range t.List {
ident, ok := f.Type.(*dst.Ident)
if !ok {
continue
}
if !ident.IsExported() {
continue
}
ident.Path = "github.com/coder/coder/coderd/database"
}
}
funcs = append(funcs, storeMethod{
Name: method.Names[0].Name,
Func: funcType,
})
}
return funcs, nil
}

View File

@ -56,16 +56,5 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
go mod download
go run golang.org/x/tools/cmd/goimports@latest -w queries.sql.go
# Generate enums (e.g. unique constraints).
go run gen/enum/main.go
# Generate the database fake!
go run gen/fake/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbfake/dbfake.go
go run gen/authz/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbauthz/dbauthz.go
go run gen/metrics/main.go
go run golang.org/x/tools/cmd/goimports@latest -w ./dbmetrics/dbmetrics.go
go run ../../scripts/dbgen/main.go
)

View File

@ -1,8 +0,0 @@
package database
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept to avoid reusing the
// same ID.
const (
LockIDDeploymentSetup = iota + 1
)

542
scripts/dbgen/main.go Normal file
View File

@ -0,0 +1,542 @@
package main
import (
"bufio"
"bytes"
"fmt"
"go/format"
"go/token"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/tools/imports"
"golang.org/x/xerrors"
)
var (
funcs []querierFunction
funcByName map[string]struct{}
)
func init() {
var err error
funcs, err = readQuerierFunctions()
if err != nil {
panic(err)
}
funcByName = map[string]struct{}{}
for _, f := range funcs {
funcByName[f.Name] = struct{}{}
}
}
func main() {
err := run()
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "error: %s\n", err)
os.Exit(1)
}
}
func run() error {
localPath, err := localFilePath()
if err != nil {
return err
}
databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbfake", "dbfake.go"), "q", "fakeQuerier", func(params stubParams) string {
return `panic("not implemented")`
})
if err != nil {
return xerrors.Errorf("stub dbfake: %w", err)
}
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmetrics", "dbmetrics.go"), "m", "metricsStore", func(params stubParams) string {
return fmt.Sprintf(`
start := time.Now()
%s := m.s.%s(%s)
m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds())
return %s
`, params.Returns, params.FuncName, params.Parameters, params.FuncName, params.Returns)
})
if err != nil {
return xerrors.Errorf("stub dbmetrics: %w", err)
}
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbauthz", "dbauthz.go"), "q", "querier", func(params stubParams) string {
return `panic("not implemented")`
})
if err != nil {
return xerrors.Errorf("stub dbauthz: %w", err)
}
err = generateUniqueConstraints()
if err != nil {
return xerrors.Errorf("generate unique constraints: %w", err)
}
return nil
}
// generateUniqueConstraints generates the UniqueConstraint enum.
func generateUniqueConstraints() error {
localPath, err := localFilePath()
if err != nil {
return err
}
databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")
dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
if err != nil {
return err
}
defer dump.Close()
var uniqueConstraints []string
dumpScanner := bufio.NewScanner(dump)
query := ""
for dumpScanner.Scan() {
line := strings.TrimSpace(dumpScanner.Text())
switch {
case strings.HasPrefix(line, "--"):
case line == "":
case strings.HasSuffix(line, ";"):
query += line
if strings.Contains(query, "UNIQUE") {
uniqueConstraints = append(uniqueConstraints, query)
}
query = ""
default:
query += line + " "
}
}
if err = dumpScanner.Err(); err != nil {
return err
}
s := &bytes.Buffer{}
_, _ = fmt.Fprint(s, `// Code generated by gen/enum. DO NOT EDIT.
package database
`)
_, _ = fmt.Fprint(s, `
// UniqueConstraint represents a named unique constraint on a table.
type UniqueConstraint string
// UniqueConstraint enums.
const (
`)
for _, query := range uniqueConstraints {
name := ""
switch {
case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
name = strings.Split(query, " ")[6]
case strings.Contains(query, "CREATE UNIQUE INDEX"):
name = strings.Split(query, " ")[3]
default:
return xerrors.Errorf("unknown unique constraint format: %s", query)
}
_, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
}
_, _ = fmt.Fprint(s, ")\n")
outputPath := filepath.Join(databasePath, "unique_constraint.go")
data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
Comments: true,
})
if err != nil {
return err
}
return os.WriteFile(outputPath, data, 0o600)
}
type stubParams struct {
FuncName string
Parameters string
Returns string
}
// orderAndStubDatabaseFunctions orders the functions in the file and stubs them.
// This is useful for when we want to add a new function to the database and
// we want to make sure that it's ordered correctly.
//
// querierFuncs is a list of functions that are in the database.
// file is the path to the file that contains all the functions.
// structName is the name of the struct that contains the functions.
// stub is a string that will be used to stub the functions.
func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub func(params stubParams) string) error {
declByName := map[string]*dst.FuncDecl{}
packageName := filepath.Base(filepath.Dir(filePath))
contents, err := os.ReadFile(filePath)
if err != nil {
return xerrors.Errorf("read dbfake: %w", err)
}
// Required to preserve imports!
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents)
if err != nil {
return xerrors.Errorf("parse dbfake: %w", err)
}
pointer := false
for i := 0; i < len(f.Decls); i++ {
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
continue
}
var ident *dst.Ident
switch t := funcDecl.Recv.List[0].Type.(type) {
case *dst.Ident:
ident = t
case *dst.StarExpr:
ident, ok = t.X.(*dst.Ident)
if !ok {
continue
}
pointer = true
}
if ident == nil || ident.Name != structName {
continue
}
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
continue
}
declByName[funcDecl.Name.Name] = funcDecl
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
i--
}
for _, fn := range funcs {
var bodyStmts []dst.Stmt
if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" {
/*
err := validateDatabaseType(arg)
if err != nil {
return database.User{}, err
}
*/
bodyStmts = append(bodyStmts, &dst.AssignStmt{
Lhs: []dst.Expr{dst.NewIdent("err")},
Tok: token.DEFINE,
Rhs: []dst.Expr{
&dst.CallExpr{
Fun: &dst.Ident{
Name: "validateDatabaseType",
},
Args: []dst.Expr{dst.NewIdent("arg")},
},
},
})
returnStmt := &dst.ReturnStmt{
Results: []dst.Expr{}, // Filled below.
}
bodyStmts = append(bodyStmts, &dst.IfStmt{
Cond: &dst.BinaryExpr{
X: dst.NewIdent("err"),
Op: token.NEQ,
Y: dst.NewIdent("nil"),
},
Body: &dst.BlockStmt{
List: []dst.Stmt{
returnStmt,
},
},
Decs: dst.IfStmtDecorations{
NodeDecs: dst.NodeDecs{
After: dst.EmptyLine,
},
},
})
for _, r := range fn.Func.Results.List {
switch typ := r.Type.(type) {
case *dst.StarExpr, *dst.ArrayType:
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil"))
case *dst.Ident:
if typ.Path != "" {
returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
} else {
switch typ.Name {
case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
"int8", "int16", "int32", "int64", "int",
"byte", "rune",
"float32", "float64",
"complex64", "complex128":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0"))
case "string":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\""))
case "bool":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false"))
case "error":
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err"))
default:
panic(fmt.Sprintf("unknown ident: %#v", r.Type))
}
}
default:
panic(fmt.Sprintf("unknown return type: %T", r.Type))
}
}
}
decl, ok := declByName[fn.Name]
if !ok {
typeName := structName
if pointer {
typeName = "*" + typeName
}
params := make([]string, 0)
if fn.Func.Params != nil {
for _, p := range fn.Func.Params.List {
for _, name := range p.Names {
params = append(params, name.Name)
}
}
}
returns := make([]string, 0)
if fn.Func.Results != nil {
for i := range fn.Func.Results.List {
returns = append(returns, fmt.Sprintf("r%d", i))
}
}
funcDecl, err := compileFuncDecl(stub(stubParams{
FuncName: fn.Name,
Parameters: strings.Join(params, ","),
Returns: strings.Join(returns, ","),
}))
if err != nil {
return xerrors.Errorf("compile func decl: %w", err)
}
// Not implemented!
decl = &dst.FuncDecl{
Name: dst.NewIdent(fn.Name),
Type: &dst.FuncType{
Func: true,
TypeParams: fn.Func.TypeParams,
Params: fn.Func.Params,
Results: fn.Func.Results,
Decs: fn.Func.Decs,
},
Recv: &dst.FieldList{
List: []*dst.Field{{
Names: []*dst.Ident{dst.NewIdent(receiver)},
Type: dst.NewIdent(typeName),
}},
},
Decs: dst.FuncDeclDecorations{
NodeDecs: dst.NodeDecs{
Before: dst.EmptyLine,
After: dst.EmptyLine,
},
},
Body: &dst.BlockStmt{
List: append(bodyStmts, funcDecl.Body.List...),
},
}
}
if ok {
for i, pm := range fn.Func.Params.List {
if len(decl.Type.Params.List) < i+1 {
decl.Type.Params.List = append(decl.Type.Params.List, pm)
}
if !reflect.DeepEqual(decl.Type.Params.List[i].Type, pm.Type) {
decl.Type.Params.List[i].Type = pm.Type
}
}
for i, res := range fn.Func.Results.List {
if len(decl.Type.Results.List) < i+1 {
decl.Type.Results.List = append(decl.Type.Results.List, res)
}
if !reflect.DeepEqual(decl.Type.Results.List[i].Type, res.Type) {
decl.Type.Results.List[i].Type = res.Type
}
}
}
f.Decls = append(f.Decls, decl)
}
// Required to preserve imports!
restorer := decorator.NewRestorerWithImports(packageName, guess.New())
restored, err := restorer.RestoreFile(f)
if err != nil {
return xerrors.Errorf("restore package: %w", err)
}
var buf bytes.Buffer
err = format.Node(&buf, restorer.Fset, restored)
if err != nil {
return xerrors.Errorf("format package: %w", err)
}
data, err := imports.Process(filePath, buf.Bytes(), &imports.Options{
Comments: true,
FormatOnly: true,
})
if err != nil {
return xerrors.Errorf("process imports: %w", err)
}
return os.WriteFile(filePath, data, 0o600)
}
// compileFuncDecl extracts the function declaration from the given code.
func compileFuncDecl(code string) (*dst.FuncDecl, error) {
f, err := decorator.Parse(fmt.Sprintf(`package stub
func stub() {
%s
}`, strings.TrimSpace(code)))
if err != nil {
return nil, err
}
if len(f.Decls) != 1 {
return nil, xerrors.Errorf("expected 1 decl, got %d", len(f.Decls))
}
decl, ok := f.Decls[0].(*dst.FuncDecl)
if !ok {
return nil, xerrors.Errorf("expected func decl, got %T", f.Decls[0])
}
return decl, nil
}
type querierFunction struct {
// Name is the name of the function. Like "GetUserByID"
Name string
// Func is the AST representation of a function.
Func *dst.FuncType
}
// readQuerierFunctions reads the functions from coderd/database/querier.go
func readQuerierFunctions() ([]querierFunction, error) {
localPath, err := localFilePath()
if err != nil {
return nil, err
}
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go")
querierData, err := os.ReadFile(querierPath)
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
}
f, err := decorator.Parse(querierData)
if err != nil {
return nil, err
}
var querier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
if !ok {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*dst.TypeSpec)
if !ok {
continue
}
// This is the name of the interface. If that ever changes,
// this will need to be updated.
if typeSpec.Name.Name != "sqlcQuerier" {
continue
}
querier, ok = typeSpec.Type.(*dst.InterfaceType)
if !ok {
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
}
break
}
}
if querier == nil {
return nil, xerrors.Errorf("querier not found")
}
funcs := []querierFunction{}
for _, method := range querier.Methods.List {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
}
for _, t := range []*dst.FieldList{funcType.Params, funcType.Results, funcType.TypeParams} {
if t == nil {
continue
}
for _, f := range t.List {
var ident *dst.Ident
switch t := f.Type.(type) {
case *dst.Ident:
ident = t
case *dst.StarExpr:
ident, ok = t.X.(*dst.Ident)
if !ok {
continue
}
case *dst.SelectorExpr:
ident, ok = t.X.(*dst.Ident)
if !ok {
continue
}
case *dst.ArrayType:
ident, ok = t.Elt.(*dst.Ident)
if !ok {
continue
}
}
if ident == nil {
continue
}
// If the type is exported then we should be able to find it
// in the database package!
if !ident.IsExported() {
continue
}
ident.Path = "github.com/coder/coder/coderd/database"
}
}
funcs = append(funcs, querierFunction{
Name: method.Names[0].Name,
Func: funcType,
})
}
return funcs, nil
}
// localFilePath returns the location of `main.go` in the dbgen package.
func localFilePath() (string, error) {
_, filename, _, ok := runtime.Caller(0)
if !ok {
return "", xerrors.Errorf("failed to get caller")
}
return filename, nil
}
// nameFromSnakeCase converts snake_case to CamelCase.
func nameFromSnakeCase(s string) string {
var ret string
for _, ss := range strings.Split(s, "_") {
switch ss {
case "id":
ret += "ID"
case "ids":
ret += "IDs"
case "jwt":
ret += "JWT"
case "idx":
ret += "Index"
default:
ret += strings.Title(ss)
}
}
return ret
}