mirror of https://github.com/coder/coder.git
183 lines
5.2 KiB
Go
183 lines
5.2 KiB
Go
package gentest_test
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/exp/slices"
|
|
)
|
|
|
|
// TestCustomQueriesSyncedRowScan makes sure the manual custom queries in
// modelqueries.go are synced with the autogenerated queries.sql.go. This should
// probably be autogenerated, but it's not atm and this is easy to throw in to
// elevate a better error message.
//
// For every autogenerated query that has a hand-written counterpart, the
// arguments passed to rows.Scan() and db.QueryContext() must match (apart from
// the query string argument itself).
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
	t.Parallel()

	// Autogenerated function name (queries.sql.go) -> hand-written
	// counterpart (modelqueries.go).
	funcsToTrack := map[string]string{
		"GetTemplatesWithFilter": "GetAuthorizedTemplates",
		"GetWorkspaces":          "GetAuthorizedWorkspaces",
		"GetUsers":               "GetAuthorizedUsers",
	}

	generated := make([]string, 0, len(funcsToTrack))
	custom := make([]string, 0, len(funcsToTrack))
	for gen, cust := range funcsToTrack {
		generated = append(generated, gen)
		custom = append(custom, cust)
	}
	// Map iteration order is randomized; sort so failures are reported in a
	// stable order across runs.
	slices.Sort(generated)

	customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
		return slices.Contains(custom, name)
	})
	generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
		_, ok := funcsToTrack[name]
		return ok
	})

	// Build a fresh map instead of aliasing (and mutating) customFns.
	merged := make(map[string]*parsedFunc, len(customFns)+len(generatedFns))
	for name, fn := range customFns {
		merged[name] = fn
	}
	for name, fn := range generatedFns {
		merged[name] = fn
	}

	var mismatched [][2]string
	for _, a := range generated {
		b := funcsToTrack[a]
		if !compareFns(t, a, b, merged[a], merged[b]) {
			mismatched = append(mismatched, [2]string{a, b})
		}
	}

	// Report at the end so the suggested fix is the last thing printed.
	for _, pair := range mismatched {
		a, b := pair[0], pair[1]
		t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
			"and 'db.QueryContext()' arguments in their function bodies. "+
			"Make sure to copy the function body from the autogenerated %q body. "+
			"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
	}
}
|
|
|
|
type parsedFunc struct {
|
|
RowScanArgs []ast.Expr
|
|
QueryArgs []ast.Expr
|
|
}
|
|
|
|
// parseFile parses the Go source file at filename. For every top-level function
// or method whose name satisfies trackFunc, it extracts the arguments passed to
// rows.Scan() and db.QueryContext(). The result is keyed by function name.
// The test fails immediately if the file cannot be parsed.
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
	t.Helper()

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
	require.NoErrorf(t, err, "failed to parse file %q", filename)

	parsed := make(map[string]*parsedFunc)
	for _, decl := range f.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || !trackFunc(fn.Name.Name) {
			continue
		}
		parsed[fn.Name.Name] = &parsedFunc{
			RowScanArgs: pullRowScanArgs(fn),
			QueryArgs:   pullQueryArgs(fn),
		}
	}

	return parsed
}
|
|
|
|
// compareFns checks that function a (aName) and function b (bName) pass the
// same arguments to rows.Scan() and to db.QueryContext(). Any mismatch, or a
// missing function (nil a or b), is reported on t. It returns true only if
// every check passed.
//
// Neither a nor b is modified.
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
	t.Helper()

	if a == nil {
		t.Errorf("The function %q is missing", aName)
		return false
	}
	if b == nil {
		t.Errorf("The function %q is missing", bName)
		return false
	}

	scanOK := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)

	// The second QueryContext argument is the query itself, and its name is
	// expected to differ. One side uses the const; the other uses a variable
	// that is a mutation of the original query. Neutralize it whenever both
	// sides have it (len >= 2, not > 2). Do this on a copy so the parsed AST
	// is not mutated.
	aQuery := a.QueryArgs
	if len(a.QueryArgs) >= 2 && len(b.QueryArgs) >= 2 {
		aQuery = append([]ast.Expr(nil), a.QueryArgs...)
		aQuery[1] = b.QueryArgs[1]
	}
	queryOK := compareArgs(t, "db.QueryContext() arguments", aName, bName, aQuery, b.QueryArgs)

	return scanOK && queryOK
}
|
|
|
|
// compareArgs asserts that the argument lists a and b render to the same names
// (see argList). On mismatch it reports argType together with both function
// names. It returns whether the lists matched.
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
	t.Helper()

	return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}
|
|
|
|
// argList renders each argument expression as a short name, so that two
// argument lists can be compared without depending on receiver or variable
// names:
//   - &i.Field and i.Field become "Field"; &x and x become "x".
//   - pkg.Fn(...) becomes "Fn(N)" and any other call becomes "call(N)", where
//     N is the number of call arguments (e.g. pq.Array(x) -> "Array(1)").
//
// Any other expression is recorded as "unknown" and reported as a test error.
// All type assertions are checked, so an unexpected shape never panics or
// drops the rest of the list.
func argList(t *testing.T, args []ast.Expr) []string {
	t.Helper()

	var argNames []string
	for _, arg := range args {
		argname := "unknown"
		switch expr := arg.(type) {
		case *ast.UnaryExpr:
			// This is "&i.Arg" / "&x" style stuff.
			switch operand := expr.X.(type) {
			case *ast.SelectorExpr:
				argname = operand.Sel.Name
			case *ast.Ident:
				argname = operand.Name
			}
		case *ast.Ident:
			argname = expr.Name
		case *ast.SelectorExpr:
			argname = expr.Sel.Name
		case *ast.CallExpr:
			// Eh, this is pg.Array style stuff. Do a best effort.
			argname = fmt.Sprintf("call(%d)", len(expr.Args))
			if fnSel, ok := expr.Fun.(*ast.SelectorExpr); ok {
				argname = fmt.Sprintf("%s(%d)", fnSel.Sel.Name, len(expr.Args))
			}
		}

		if argname == "unknown" {
			t.Errorf("Unknown arg, cannot parse: %T", arg)
		}
		argNames = append(argNames, argname)
	}
	return argNames
}
|
|
|
|
// pullQueryArgs returns the arguments of the QueryContext call in the first
// top-level statement of fn's body of the form
//
//	rows, err := <expr>.QueryContext(...)
//
// It returns nil if there is no such statement, or if fn has no body. Only
// the statement's shape is checked: two left-hand names, the first named
// "rows", and a call whose selector is QueryContext. Other shapes are skipped
// rather than causing a panic.
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}

	for _, stmt := range fn.Body.List {
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assign.Lhs) != 2 || len(assign.Rhs) == 0 {
			continue
		}
		// This is "rows, err :=".
		if id, ok := assign.Lhs[0].(*ast.Ident); !ok || id.Name != "rows" {
			continue
		}
		call, ok := assign.Rhs[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		if sel, ok := call.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "QueryContext" {
			return call.Args
		}
	}
	return nil
}
|
|
|
|
// pullRowScanArgs returns the arguments of the Scan call found in the body of a
// top-level for loop in fn. It looks for the pattern
//
//	for rows.Next() {
//		...
//		if err := rows.Scan(...); err != nil {
//
// It returns nil if no such statement exists, or if fn has no body.
// Every type assertion is checked, so a loop that does not match the
// expected shape is skipped instead of panicking. The original code indexed
// List[1] directly and used unchecked assertions.
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}

	for _, stmt := range fn.Body.List {
		forStmt, ok := stmt.(*ast.ForStmt)
		if !ok || forStmt.Body == nil {
			continue
		}
		for _, inner := range forStmt.Body.List {
			ifStmt, ok := inner.(*ast.IfStmt)
			if !ok {
				continue
			}
			// This is the "err := rows.Scan(...)" init statement.
			assign, ok := ifStmt.Init.(*ast.AssignStmt)
			if !ok || len(assign.Rhs) == 0 {
				continue
			}
			call, ok := assign.Rhs[0].(*ast.CallExpr)
			if !ok {
				continue
			}
			if sel, ok := call.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "Scan" {
				return call.Args
			}
		}
	}
	return nil
}
|