mirror of https://github.com/coder/coder.git
chore: unit test to enforce authorized queries match args (#11211)
* chore: unit test to enforce authorized queries match args * Also check querycontext arguments
This commit is contained in:
parent
7924bb2a56
commit
3f6096b0d7
|
@ -0,0 +1,182 @@
|
|||
package gentest_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
|
||||
// are synced with the autogenerated queries.sql.go. This should probably be
|
||||
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
|
||||
// error message.
|
||||
//
|
||||
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
|
||||
// test. Ping @Emyrk to fix it again.
|
||||
func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
funcsToTrack := map[string]string{
|
||||
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
||||
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
||||
"GetUsers": "GetAuthorizedUsers",
|
||||
}
|
||||
|
||||
// Scan custom
|
||||
var custom []string
|
||||
for _, fn := range funcsToTrack {
|
||||
custom = append(custom, fn)
|
||||
}
|
||||
|
||||
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
|
||||
return slices.Contains(custom, name)
|
||||
})
|
||||
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
|
||||
_, ok := funcsToTrack[name]
|
||||
return ok
|
||||
})
|
||||
merged := customFns
|
||||
for k, v := range generatedFns {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
for a, b := range funcsToTrack {
|
||||
a, b := a, b
|
||||
if !compareFns(t, a, b, merged[a], merged[b]) {
|
||||
//nolint:revive
|
||||
defer func() {
|
||||
// Run this at the end so the suggested fix is the last thing printed.
|
||||
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
|
||||
"and 'db.QueryContext()' arguments in their function bodies. "+
|
||||
"Make sure to copy the function body from the autogenerated %q body. "+
|
||||
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type parsedFunc struct {
|
||||
RowScanArgs []ast.Expr
|
||||
QueryArgs []ast.Expr
|
||||
}
|
||||
|
||||
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
|
||||
require.NoErrorf(t, err, "failed to parse file %q", filename)
|
||||
|
||||
parsed := make(map[string]*parsedFunc)
|
||||
for _, decl := range f.Decls {
|
||||
if fn, ok := decl.(*ast.FuncDecl); ok {
|
||||
if trackFunc(fn.Name.Name) {
|
||||
parsed[fn.Name.String()] = &parsedFunc{
|
||||
RowScanArgs: pullRowScanArgs(fn),
|
||||
QueryArgs: pullQueryArgs(fn),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
|
||||
if a == nil {
|
||||
t.Errorf("The function %q is missing", aName)
|
||||
return false
|
||||
}
|
||||
if b == nil {
|
||||
t.Errorf("The function %q is missing", bName)
|
||||
return false
|
||||
}
|
||||
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
|
||||
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
|
||||
// This is because the actual query param name is different. One uses the
|
||||
// const, the other uses a variable that is a mutation of the original query.
|
||||
a.QueryArgs[1] = b.QueryArgs[1]
|
||||
}
|
||||
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
|
||||
return r && q
|
||||
}
|
||||
|
||||
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
|
||||
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
|
||||
}
|
||||
|
||||
func argList(t *testing.T, args []ast.Expr) []string {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Recovered in f reading arg names: %s", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var argNames []string
|
||||
for _, arg := range args {
|
||||
argname := "unknown"
|
||||
// This is "&i.Arg" style stuff
|
||||
if unary, ok := arg.(*ast.UnaryExpr); ok {
|
||||
argname = unary.X.(*ast.SelectorExpr).Sel.Name
|
||||
}
|
||||
if ident, ok := arg.(*ast.Ident); ok {
|
||||
argname = ident.Name
|
||||
}
|
||||
if sel, ok := arg.(*ast.SelectorExpr); ok {
|
||||
argname = sel.Sel.Name
|
||||
}
|
||||
if call, ok := arg.(*ast.CallExpr); ok {
|
||||
// Eh, this is pg.Array style stuff. Do a best effort.
|
||||
argname = fmt.Sprintf("call(%d)", len(call.Args))
|
||||
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
|
||||
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
|
||||
}
|
||||
}
|
||||
|
||||
if argname == "unknown" {
|
||||
t.Errorf("Unknown arg, cannot parse: %T", arg)
|
||||
}
|
||||
argNames = append(argNames, argname)
|
||||
}
|
||||
return argNames
|
||||
}
|
||||
|
||||
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
|
||||
for _, exp := range fn.Body.List {
|
||||
// find "rows, err :="
|
||||
if assign, ok := exp.(*ast.AssignStmt); ok {
|
||||
if len(assign.Lhs) == 2 {
|
||||
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
|
||||
// This is rows, err :=
|
||||
query := assign.Rhs[0].(*ast.CallExpr)
|
||||
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
|
||||
return query.Args
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
|
||||
for _, exp := range fn.Body.List {
|
||||
if forStmt, ok := exp.(*ast.ForStmt); ok {
|
||||
// This came from the debugger window and tracking it down.
|
||||
rowScan := (forStmt.Body.
|
||||
// Second statement in the for loop is the if statement
|
||||
// with rows.can
|
||||
List[1].(*ast.IfStmt).
|
||||
// This is the err := rows.Scan()
|
||||
Init.(*ast.AssignStmt).
|
||||
// Rhs is the row.Scan part
|
||||
Rhs)[0].(*ast.CallExpr)
|
||||
return rowScan.Args
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue