chore: unit test to enforce authorized queries match args (#11211)

* chore: unit test to enforce authorized queries match args
* Also check querycontext arguments
This commit is contained in:
Steven Masley 2023-12-15 14:31:07 -06:00 committed by GitHub
parent 7924bb2a56
commit 3f6096b0d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 182 additions and 0 deletions

View File

@ -0,0 +1,182 @@
package gentest_test
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
// are synced with the autogenerated queries.sql.go. This should probably be
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
// error message.
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
t.Parallel()
funcsToTrack := map[string]string{
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
}
// Scan custom
var custom []string
for _, fn := range funcsToTrack {
custom = append(custom, fn)
}
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
return slices.Contains(custom, name)
})
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
_, ok := funcsToTrack[name]
return ok
})
merged := customFns
for k, v := range generatedFns {
merged[k] = v
}
for a, b := range funcsToTrack {
a, b := a, b
if !compareFns(t, a, b, merged[a], merged[b]) {
//nolint:revive
defer func() {
// Run this at the end so the suggested fix is the last thing printed.
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
"and 'db.QueryContext()' arguments in their function bodies. "+
"Make sure to copy the function body from the autogenerated %q body. "+
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
}()
}
}
}
type parsedFunc struct {
RowScanArgs []ast.Expr
QueryArgs []ast.Expr
}
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
require.NoErrorf(t, err, "failed to parse file %q", filename)
parsed := make(map[string]*parsedFunc)
for _, decl := range f.Decls {
if fn, ok := decl.(*ast.FuncDecl); ok {
if trackFunc(fn.Name.Name) {
parsed[fn.Name.String()] = &parsedFunc{
RowScanArgs: pullRowScanArgs(fn),
QueryArgs: pullQueryArgs(fn),
}
}
}
}
return parsed
}
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
if a == nil {
t.Errorf("The function %q is missing", aName)
return false
}
if b == nil {
t.Errorf("The function %q is missing", bName)
return false
}
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
// This is because the actual query param name is different. One uses the
// const, the other uses a variable that is a mutation of the original query.
a.QueryArgs[1] = b.QueryArgs[1]
}
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
return r && q
}
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}
func argList(t *testing.T, args []ast.Expr) []string {
defer func() {
if r := recover(); r != nil {
t.Errorf("Recovered in f reading arg names: %s", r)
}
}()
var argNames []string
for _, arg := range args {
argname := "unknown"
// This is "&i.Arg" style stuff
if unary, ok := arg.(*ast.UnaryExpr); ok {
argname = unary.X.(*ast.SelectorExpr).Sel.Name
}
if ident, ok := arg.(*ast.Ident); ok {
argname = ident.Name
}
if sel, ok := arg.(*ast.SelectorExpr); ok {
argname = sel.Sel.Name
}
if call, ok := arg.(*ast.CallExpr); ok {
// Eh, this is pg.Array style stuff. Do a best effort.
argname = fmt.Sprintf("call(%d)", len(call.Args))
if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok {
argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args))
}
}
if argname == "unknown" {
t.Errorf("Unknown arg, cannot parse: %T", arg)
}
argNames = append(argNames, argname)
}
return argNames
}
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
// find "rows, err :="
if assign, ok := exp.(*ast.AssignStmt); ok {
if len(assign.Lhs) == 2 {
if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" {
// This is rows, err :=
query := assign.Rhs[0].(*ast.CallExpr)
if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" {
return query.Args
}
}
}
}
}
return nil
}
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
for _, exp := range fn.Body.List {
if forStmt, ok := exp.(*ast.ForStmt); ok {
// This came from the debugger window and tracking it down.
rowScan := (forStmt.Body.
// Second statement in the for loop is the if statement
// with rows.can
List[1].(*ast.IfStmt).
// This is the err := rows.Scan()
Init.(*ast.AssignStmt).
// Rhs is the row.Scan part
Rhs)[0].(*ast.CallExpr)
return rowScan.Args
}
}
return nil
}