coder/scripts/dbgen/main.go

681 lines
17 KiB
Go

package main
import (
"bufio"
"bytes"
"fmt"
"go/format"
"go/token"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/tools/imports"
"golang.org/x/xerrors"
)
// funcs lists every querier method in the order readQuerierFunctions returns
// them, and funcByName indexes those same methods by name so callers can do
// cheap membership checks. Both are filled in once by init.
var (
	funcs      []querierFunction
	funcByName map[string]struct{}
)

// init loads the querier methods. Nothing in the generator works without
// them, so any failure aborts the program with a panic.
func init() {
	loaded, err := readQuerierFunctions()
	if err != nil {
		panic(err)
	}
	funcs = loaded

	byName := make(map[string]struct{}, len(loaded))
	for i := range loaded {
		byName[loaded[i].Name] = struct{}{}
	}
	funcByName = byName
}
// main runs the generator. On failure it prints the error to stderr and exits
// with status 1.
func main() {
	if err := run(); err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "error: %s\n", err)
		os.Exit(1)
	}
}
// run regenerates the derived files under coderd/database.
//
// It first orders and stubs the querier methods in dbmem/dbmem.go,
// dbmetrics/dbmetrics.go and dbauthz/dbauthz.go (in that order). It then
// regenerates the unique and foreign key constraint enums. It stops at the
// first error.
func run() error {
	localPath, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")

	notImplemented := func(params stubParams) string {
		return `panic("not implemented")`
	}
	targets := []struct {
		pkg        string
		receiver   string
		structName string
		stub       func(params stubParams) string
	}{
		{pkg: "dbmem", receiver: "q", structName: "FakeQuerier", stub: notImplemented},
		{pkg: "dbmetrics", receiver: "m", structName: "metricsStore", stub: func(params stubParams) string {
			return fmt.Sprintf(`
start := time.Now()
%s := m.s.%s(%s)
m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds())
return %s
`, params.Returns, params.FuncName, params.Parameters, params.FuncName, params.Returns)
		}},
		{pkg: "dbauthz", receiver: "q", structName: "querier", stub: notImplemented},
	}

	for _, target := range targets {
		// Each package keeps its methods in <pkg>/<pkg>.go.
		file := filepath.Join(databasePath, target.pkg, target.pkg+".go")
		if err := orderAndStubDatabaseFunctions(file, target.receiver, target.structName, target.stub); err != nil {
			return xerrors.Errorf("stub %s: %w", target.pkg, err)
		}
	}

	if err := generateUniqueConstraints(); err != nil {
		return xerrors.Errorf("generate unique constraints: %w", err)
	}
	if err := generateForeignKeyConstraints(); err != nil {
		return xerrors.Errorf("generate foreign key constraints: %w", err)
	}
	return nil
}
// generateUniqueConstraints generates the UniqueConstraint enum in
// coderd/database/unique_constraint.go from coderd/database/dump.sql.
//
// dump.sql is split into statements: lines are trimmed and joined with single
// spaces until one ends with ";", and "--" comment lines and blank lines are
// skipped. Every statement mentioning UNIQUE or PRIMARY KEY must be either an
// "ALTER TABLE ... ADD CONSTRAINT <name> ..." or a
// "CREATE UNIQUE INDEX <name> ..." statement. Any other such statement is
// reported as an error rather than silently ignored.
func generateUniqueConstraints() error {
	localPath, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")

	dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
	if err != nil {
		return err
	}
	defer dump.Close()

	var uniqueConstraints []string
	dumpScanner := bufio.NewScanner(dump)
	// bufio.Scanner rejects lines longer than 64 KiB by default. Allow long
	// lines (e.g. large function bodies) instead of failing with ErrTooLong.
	dumpScanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	var pending strings.Builder
	for dumpScanner.Scan() {
		line := strings.TrimSpace(dumpScanner.Text())
		switch {
		case strings.HasPrefix(line, "--"), line == "":
			// Skip SQL comments and blank lines.
		case strings.HasSuffix(line, ";"):
			pending.WriteString(line)
			query := pending.String()
			if strings.Contains(query, "UNIQUE") || strings.Contains(query, "PRIMARY KEY") {
				uniqueConstraints = append(uniqueConstraints, query)
			}
			pending.Reset()
		default:
			pending.WriteString(line)
			pending.WriteByte(' ')
		}
	}
	if err := dumpScanner.Err(); err != nil {
		return err
	}

	// wordAfter returns the word that directly follows the consecutive words
	// first and second in stmt, or "" if that phrase does not occur.
	// Searching by keyword (rather than by fixed token index) tolerates
	// optional words such as ONLY and irregular whitespace.
	wordAfter := func(stmt, first, second string) string {
		words := strings.Fields(stmt)
		for i := 0; i+2 < len(words); i++ {
			if words[i] == first && words[i+1] == second {
				return words[i+2]
			}
		}
		return ""
	}

	s := &bytes.Buffer{}
	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// UniqueConstraint represents a named unique constraint on a table.
type UniqueConstraint string
// UniqueConstraint enums.
const (
`)
	for _, query := range uniqueConstraints {
		var name string
		switch {
		case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
			name = wordAfter(query, "ADD", "CONSTRAINT")
		case strings.Contains(query, "CREATE UNIQUE INDEX"):
			name = wordAfter(query, "UNIQUE", "INDEX")
		default:
			return xerrors.Errorf("unknown unique constraint format: %s", query)
		}
		if name == "" {
			return xerrors.Errorf("missing unique constraint name: %s", query)
		}
		_, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	outputPath := filepath.Join(databasePath, "unique_constraint.go")
	data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
		Comments: true,
	})
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, data, 0o600)
}
// generateForeignKeyConstraints generates the ForeignKeyConstraint enum in
// coderd/database/foreign_key_constraint.go from coderd/database/dump.sql.
//
// dump.sql is split into statements: lines are trimmed and joined with single
// spaces until one ends with ";", and "--" comment lines and blank lines are
// skipped. Every statement mentioning FOREIGN KEY must be an
// "ALTER TABLE ... ADD CONSTRAINT <name> ..." statement. Any other such
// statement is reported as an error rather than silently ignored.
func generateForeignKeyConstraints() error {
	localPath, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")

	dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
	if err != nil {
		return err
	}
	defer dump.Close()

	var foreignKeyConstraints []string
	dumpScanner := bufio.NewScanner(dump)
	// bufio.Scanner rejects lines longer than 64 KiB by default. Allow long
	// lines (e.g. large function bodies) instead of failing with ErrTooLong.
	dumpScanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
	var pending strings.Builder
	for dumpScanner.Scan() {
		line := strings.TrimSpace(dumpScanner.Text())
		switch {
		case strings.HasPrefix(line, "--"), line == "":
			// Skip SQL comments and blank lines.
		case strings.HasSuffix(line, ";"):
			pending.WriteString(line)
			query := pending.String()
			if strings.Contains(query, "FOREIGN KEY") {
				foreignKeyConstraints = append(foreignKeyConstraints, query)
			}
			pending.Reset()
		default:
			pending.WriteString(line)
			pending.WriteByte(' ')
		}
	}
	if err := dumpScanner.Err(); err != nil {
		return err
	}

	// constraintName returns the word directly following "ADD CONSTRAINT" in
	// stmt, or "" if that phrase does not occur. Searching by keyword (rather
	// than by fixed token index) tolerates optional words such as ONLY and
	// irregular whitespace.
	constraintName := func(stmt string) string {
		words := strings.Fields(stmt)
		for i := 0; i+2 < len(words); i++ {
			if words[i] == "ADD" && words[i+1] == "CONSTRAINT" {
				return words[i+2]
			}
		}
		return ""
	}

	s := &bytes.Buffer{}
	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// ForeignKeyConstraint represents a named foreign key constraint on a table.
type ForeignKeyConstraint string
// ForeignKeyConstraint enums.
const (
`)
	for _, query := range foreignKeyConstraints {
		var name string
		switch {
		case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
			name = constraintName(query)
		default:
			return xerrors.Errorf("unknown foreign key constraint format: %s", query)
		}
		if name == "" {
			return xerrors.Errorf("missing foreign key constraint name: %s", query)
		}
		_, _ = fmt.Fprintf(s, "\tForeignKey%s ForeignKeyConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	outputPath := filepath.Join(databasePath, "foreign_key_constraint.go")
	data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
		Comments: true,
	})
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, data, 0o600)
}
// stubParams describes a querier method that is missing from a struct. It is
// passed to a stub callback so the callback can render the method body.
type stubParams struct {
	// FuncName is the method name, e.g. "GetUserByID".
	FuncName string
	// Parameters is the comma-separated list of the method's parameter names.
	Parameters string
	// Returns is the comma-separated list of result variable names
	// (r0, r1, ...), one per result.
	Returns string
}
// orderAndStubDatabaseFunctions rewrites the Go file at filePath so that the
// querier methods of structName appear in querier order. It also adds a stub
// for every querier function the struct does not implement yet.
//
// filePath is the file that contains the struct's methods. The name of its
// parent directory is used as the package name and in error messages.
// receiver is the receiver variable name given to generated stubs.
// structName is the struct whose methods are ordered and stubbed.
// stub returns the source of the body for a missing method.
//
// Every method of structName whose name is a querier function is removed from
// its current position. The methods are then appended at the end of the file,
// in the order of the package-level funcs slice. All other declarations keep
// their relative order. Existing methods keep their bodies, but their
// parameter and result types are replaced with the querier's when they differ.
// In dbmem files, newly generated stubs taking an "arg" parameter start with a
// validateDatabaseType(arg) check.
func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub func(params stubParams) string) error {
	declByName := map[string]*dst.FuncDecl{}
	packageName := filepath.Base(filepath.Dir(filePath))
	contents, err := os.ReadFile(filePath)
	if err != nil {
		return xerrors.Errorf("read %s: %w", packageName, err)
	}
	// Required to preserve imports!
	f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents)
	if err != nil {
		return xerrors.Errorf("parse %s: %w", packageName, err)
	}

	// pointer records whether an existing method of structName has a pointer
	// receiver. Generated stubs use the same receiver kind. Only receivers of
	// structName count; methods of other types in the file are ignored.
	pointer := false
	for i := 0; i < len(f.Decls); i++ {
		funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
		if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
			continue
		}
		var ident *dst.Ident
		isPointer := false
		switch t := funcDecl.Recv.List[0].Type.(type) {
		case *dst.Ident:
			ident = t
		case *dst.StarExpr:
			ident, ok = t.X.(*dst.Ident)
			if !ok {
				continue
			}
			isPointer = true
		}
		if ident == nil || ident.Name != structName {
			continue
		}
		if isPointer {
			pointer = true
		}
		if _, ok := funcByName[funcDecl.Name.Name]; !ok {
			continue
		}
		// Detach the method; it is re-appended below in querier order.
		declByName[funcDecl.Name.Name] = funcDecl
		f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
		i--
	}

	isDBMem := strings.Contains(filePath, "dbmem")
	for _, fn := range funcs {
		var bodyStmts []dst.Stmt
		// Add input validation, only relevant for dbmem. It applies to methods
		// with exactly two parameters whose second one is named "arg".
		inParams := fn.Func.Params.List
		if isDBMem && len(inParams) == 2 && len(inParams[1].Names) > 0 && inParams[1].Names[0].Name == "arg" {
			/*
				err := validateDatabaseType(arg)
				if err != nil {
					return database.User{}, err
				}
			*/
			bodyStmts = append(bodyStmts, &dst.AssignStmt{
				Lhs: []dst.Expr{dst.NewIdent("err")},
				Tok: token.DEFINE,
				Rhs: []dst.Expr{
					&dst.CallExpr{
						Fun: &dst.Ident{
							Name: "validateDatabaseType",
						},
						Args: []dst.Expr{dst.NewIdent("arg")},
					},
				},
			})
			returnStmt := &dst.ReturnStmt{
				Results: []dst.Expr{}, // Filled below.
			}
			bodyStmts = append(bodyStmts, &dst.IfStmt{
				Cond: &dst.BinaryExpr{
					X:  dst.NewIdent("err"),
					Op: token.NEQ,
					Y:  dst.NewIdent("nil"),
				},
				Body: &dst.BlockStmt{
					List: []dst.Stmt{
						returnStmt,
					},
				},
				Decs: dst.IfStmtDecorations{
					NodeDecs: dst.NodeDecs{
						After: dst.EmptyLine,
					},
				},
			})
			// Build a zero value for every result so the early return compiles.
			if fn.Func.Results != nil {
				for _, r := range fn.Func.Results.List {
					switch typ := r.Type.(type) {
					case *dst.StarExpr, *dst.ArrayType:
						returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil"))
					case *dst.Ident:
						if typ.Path != "" {
							returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
						} else {
							switch typ.Name {
							case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
								"int8", "int16", "int32", "int64", "int",
								"byte", "rune",
								"float32", "float64",
								"complex64", "complex128":
								returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0"))
							case "string":
								returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\""))
							case "bool":
								returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false"))
							case "error":
								returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err"))
							default:
								panic(fmt.Sprintf("unknown ident: %#v", r.Type))
							}
						}
					default:
						panic(fmt.Sprintf("unknown return type: %T", r.Type))
					}
				}
			}
		}

		decl, ok := declByName[fn.Name]
		if !ok {
			// Not implemented yet: generate a stub method.
			typeName := structName
			if pointer {
				typeName = "*" + typeName
			}
			paramNames := make([]string, 0)
			if fn.Func.Params != nil {
				for _, p := range fn.Func.Params.List {
					for _, name := range p.Names {
						paramNames = append(paramNames, name.Name)
					}
				}
			}
			returnNames := make([]string, 0)
			if fn.Func.Results != nil {
				for i := range fn.Func.Results.List {
					returnNames = append(returnNames, fmt.Sprintf("r%d", i))
				}
			}
			funcDecl, err := compileFuncDecl(stub(stubParams{
				FuncName:   fn.Name,
				Parameters: strings.Join(paramNames, ","),
				Returns:    strings.Join(returnNames, ","),
			}))
			if err != nil {
				return xerrors.Errorf("compile func decl: %w", err)
			}
			decl = &dst.FuncDecl{
				Name: dst.NewIdent(fn.Name),
				Type: &dst.FuncType{
					Func:       true,
					TypeParams: fn.Func.TypeParams,
					Params:     fn.Func.Params,
					Results:    fn.Func.Results,
					Decs:       fn.Func.Decs,
				},
				Recv: &dst.FieldList{
					List: []*dst.Field{{
						Names: []*dst.Ident{dst.NewIdent(receiver)},
						Type:  dst.NewIdent(typeName),
					}},
				},
				Decs: dst.FuncDeclDecorations{
					NodeDecs: dst.NodeDecs{
						Before: dst.EmptyLine,
						After:  dst.EmptyLine,
					},
				},
				Body: &dst.BlockStmt{
					List: append(bodyStmts, funcDecl.Body.List...),
				},
			}
		} else {
			// Existing method: keep its body, but make its signature types match
			// the querier interface.
			for i, pm := range fn.Func.Params.List {
				if len(decl.Type.Params.List) < i+1 {
					decl.Type.Params.List = append(decl.Type.Params.List, pm)
				}
				if !reflect.DeepEqual(decl.Type.Params.List[i].Type, pm.Type) {
					decl.Type.Params.List[i].Type = pm.Type
				}
			}
			if fn.Func.Results != nil {
				// The existing method may declare no results at all.
				if decl.Type.Results == nil {
					decl.Type.Results = &dst.FieldList{Opening: true, Closing: true}
				}
				for i, res := range fn.Func.Results.List {
					if len(decl.Type.Results.List) < i+1 {
						decl.Type.Results.List = append(decl.Type.Results.List, res)
					}
					if !reflect.DeepEqual(decl.Type.Results.List[i].Type, res.Type) {
						decl.Type.Results.List[i].Type = res.Type
					}
				}
			}
		}
		f.Decls = append(f.Decls, decl)
	}

	// Required to preserve imports!
	restorer := decorator.NewRestorerWithImports(packageName, guess.New())
	restored, err := restorer.RestoreFile(f)
	if err != nil {
		return xerrors.Errorf("restore package: %w", err)
	}
	var buf bytes.Buffer
	err = format.Node(&buf, restorer.Fset, restored)
	if err != nil {
		return xerrors.Errorf("format package: %w", err)
	}
	data, err := imports.Process(filePath, buf.Bytes(), &imports.Options{
		Comments:   true,
		FormatOnly: true,
	})
	if err != nil {
		return xerrors.Errorf("process imports: %w", err)
	}
	return os.WriteFile(filePath, data, 0o600)
}
// compileFuncDecl parses code as the body of a throwaway function named stub
// in a throwaway package. It returns that function's declaration so callers
// can reuse the parsed statements. Leading and trailing whitespace in code is
// ignored.
func compileFuncDecl(code string) (*dst.FuncDecl, error) {
	src := fmt.Sprintf(`package stub
func stub() {
%s
}`, strings.TrimSpace(code))
	file, err := decorator.Parse(src)
	if err != nil {
		return nil, err
	}
	if n := len(file.Decls); n != 1 {
		return nil, xerrors.Errorf("expected 1 decl, got %d", n)
	}
	first := file.Decls[0]
	decl, isFunc := first.(*dst.FuncDecl)
	if !isFunc {
		return nil, xerrors.Errorf("expected func decl, got %T", first)
	}
	return decl, nil
}
// querierFunction is a single method taken from a querier interface
// declaration.
type querierFunction struct {
	// Name is the method name, e.g. "GetUserByID".
	Name string
	// Func is the method's signature (parameters, results and type
	// parameters) as parsed from the interface declaration.
	Func *dst.FuncType
}
// readQuerierFunctions reads the querier methods from coderd/database. It
// returns the methods of the sqlcQuerier interface in querier.go, followed by
// those of the customQuerier interface in modelqueries.go.
func readQuerierFunctions() ([]querierFunction, error) {
	f, err := parseDBFile("querier.go")
	if err != nil {
		return nil, xerrors.Errorf("parse querier.go: %w", err)
	}
	sqlcFuncs, err := loadInterfaceFuncs(f, "sqlcQuerier")
	if err != nil {
		return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
	}

	customFile, err := parseDBFile("modelqueries.go")
	if err != nil {
		return nil, xerrors.Errorf("parse modelqueries.go: %w", err)
	}
	// Custom funcs should be appended after the regular functions.
	customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier")
	if err != nil {
		return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
	}
	return append(sqlcFuncs, customFuncs...), nil
}
// parseDBFile reads the named Go file from the coderd/database directory and
// parses it with dst. The directory is resolved relative to this source file.
// Read errors are wrapped with the file name; parse errors are returned as-is.
func parseDBFile(filename string) (*dst.File, error) {
	localPath, err := localFilePath()
	if err != nil {
		return nil, err
	}
	src, err := os.ReadFile(filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename))
	if err != nil {
		return nil, xerrors.Errorf("read %s: %w", filename, err)
	}
	return decorator.Parse(src)
}
// loadInterfaceFuncs returns the methods of the interface named interfaceName
// declared in f. It includes methods of embedded interfaces that
// interfaceMethods can resolve.
//
// Exported bare identifiers in the method signatures are tagged with the
// coderd/database import path. This is so they keep referring to the database
// package when the signatures are copied into other packages. The parsed
// signatures are modified in place.
//
// It returns an error if no type named interfaceName is declared in f, or if
// that type is not an interface.
func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
	var querier *dst.InterfaceType
findInterface:
	for _, decl := range f.Decls {
		genDecl, ok := decl.(*dst.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*dst.TypeSpec)
			if !ok || typeSpec.Name.Name != interfaceName {
				continue
			}
			querier, ok = typeSpec.Type.(*dst.InterfaceType)
			if !ok {
				return nil, xerrors.Errorf("unexpected %s type: %T", interfaceName, typeSpec.Type)
			}
			break findInterface
		}
	}
	if querier == nil {
		return nil, xerrors.Errorf("interface %s not found", interfaceName)
	}

	funcs := []querierFunction{}
	for _, method := range interfaceMethods(querier) {
		funcType, ok := method.Type.(*dst.FuncType)
		if !ok {
			continue
		}
		for _, fields := range []*dst.FieldList{funcType.Params, funcType.Results, funcType.TypeParams} {
			if fields == nil {
				continue
			}
			for _, field := range fields.List {
				qualifyDatabaseIdents(field.Type)
			}
		}
		funcs = append(funcs, querierFunction{
			Name: method.Names[0].Name,
			Func: funcType,
		})
	}
	return funcs, nil
}

// qualifyDatabaseIdents sets the coderd/database import path on every exported
// bare identifier reachable from typ. It descends through pointer,
// slice/array, variadic, map and channel types. Package-qualified types
// (pkg.T) and unexported identifiers are left unchanged.
func qualifyDatabaseIdents(typ dst.Expr) {
	switch t := typ.(type) {
	case *dst.Ident:
		// If the type is exported then we should be able to find it in the
		// database package!
		if t.IsExported() {
			t.Path = "github.com/coder/coder/v2/coderd/database"
		}
	case *dst.StarExpr:
		qualifyDatabaseIdents(t.X)
	case *dst.ArrayType:
		qualifyDatabaseIdents(t.Elt)
	case *dst.Ellipsis:
		qualifyDatabaseIdents(t.Elt)
	case *dst.MapType:
		qualifyDatabaseIdents(t.Key)
		qualifyDatabaseIdents(t.Value)
	case *dst.ChanType:
		qualifyDatabaseIdents(t.Value)
	}
}
// localFilePath returns the path of this source file (main.go in the dbgen
// package) as recorded by the compiler. It is the anchor for locating
// coderd/database relative to this file.
func localFilePath() (string, error) {
	_, file, _, ok := runtime.Caller(0)
	if ok {
		return file, nil
	}
	return "", xerrors.Errorf("failed to get caller")
}
// nameFromSnakeCase converts a snake_case SQL identifier to a CamelCase Go
// name fragment.
//
// Each "_"-separated segment is handled as follows:
//   - known initialisms are upper-cased (id -> ID, api -> API, ...);
//   - "idx" becomes "Index";
//   - "fkey" is dropped (the foreign key enum constants already carry a
//     ForeignKey prefix);
//   - everything else is title-cased.
func nameFromSnakeCase(s string) string {
	replacements := map[string]string{
		"id":         "ID",
		"ids":        "IDs",
		"jwt":        "JWT",
		"idx":        "Index",
		"api":        "API",
		"uuid":       "UUID",
		"gitsshkeys": "GitSSHKeys",
		"fkey":       "",
	}

	var out strings.Builder
	for _, segment := range strings.Split(s, "_") {
		if replacement, known := replacements[segment]; known {
			out.WriteString(replacement)
			continue
		}
		out.WriteString(strings.Title(segment))
	}
	return out.String()
}
// interfaceMethods flattens an interface into its method fields, in
// declaration order. Embedded interfaces are expanded recursively. Inline
// interface literals are always expanded. A named embedded interface is
// expanded only if the parser resolved it to a type declaration in the same
// parsed file (its Obj is set); otherwise it is skipped.
func interfaceMethods(i *dst.InterfaceType) []*dst.Field {
	var methods []*dst.Field
	for _, field := range i.Methods.List {
		switch typ := field.Type.(type) {
		case *dst.FuncType:
			methods = append(methods, field)
		case *dst.InterfaceType:
			methods = append(methods, interfaceMethods(typ)...)
		case *dst.Ident:
			// Embedded interfaces are Idents -> TypeSpec -> InterfaceType.
			if typ.Obj == nil {
				continue
			}
			spec, isSpec := typ.Obj.Decl.(*dst.TypeSpec)
			if !isSpec {
				continue
			}
			if embedded, isIface := spec.Type.(*dst.InterfaceType); isIface {
				methods = append(methods, interfaceMethods(embedded)...)
			}
		}
	}
	return methods
}