From 3f6096b0d7785494144fa7aff6ae0b9fab39ae01 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 15 Dec 2023 14:31:07 -0600 Subject: [PATCH] chore: unit test to enforce authorized queries match args (#11211) * chore: unit test to enforce authorized queries match args * Also check querycontext arguments --- coderd/database/gentest/modelqueries_test.go | 182 +++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 coderd/database/gentest/modelqueries_test.go diff --git a/coderd/database/gentest/modelqueries_test.go b/coderd/database/gentest/modelqueries_test.go new file mode 100644 index 0000000000..52a99b5440 --- /dev/null +++ b/coderd/database/gentest/modelqueries_test.go @@ -0,0 +1,182 @@ +package gentest_test + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" +) + +// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go +// are synced with the autogenerated queries.sql.go. This should probably be +// autogenerated, but it's not atm and this is easy to throw in to elevate a better +// error message. +// +// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical +// test. Ping @Emyrk to fix it again. +func TestCustomQueriesSyncedRowScan(t *testing.T) { + t.Parallel() + + funcsToTrack := map[string]string{ + "GetTemplatesWithFilter": "GetAuthorizedTemplates", + "GetWorkspaces": "GetAuthorizedWorkspaces", + "GetUsers": "GetAuthorizedUsers", + } + + // Scan custom + var custom []string + for _, fn := range funcsToTrack { + custom = append(custom, fn) + } + + customFns := parseFile(t, "../modelqueries.go", func(name string) bool { + return slices.Contains(custom, name) + }) + generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool { + _, ok := funcsToTrack[name] + return ok + }) + merged := customFns + for k, v := range generatedFns { + merged[k] = v + } + + for a, b := range funcsToTrack { + a, b := a, b + if !compareFns(t, a, b, merged[a], merged[b]) { + //nolint:revive + defer func() { + // Run this at the end so the suggested fix is the last thing printed. + t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+ + "and 'db.QueryContext()' arguments in their function bodies. "+ + "Make sure to copy the function body from the autogenerated %q body. "+ + "Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a) + }() + } + } +} + +type parsedFunc struct { + RowScanArgs []ast.Expr + QueryArgs []ast.Expr +} + +func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution) + require.NoErrorf(t, err, "failed to parse file %q", filename) + + parsed := make(map[string]*parsedFunc) + for _, decl := range f.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + if trackFunc(fn.Name.Name) { + parsed[fn.Name.String()] = &parsedFunc{ + RowScanArgs: pullRowScanArgs(fn), + QueryArgs: pullQueryArgs(fn), + } + } + } + } + + return parsed +} + +func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool { + if a == nil { + t.Errorf("The function %q is missing", aName) + return false + } + if b == nil { + t.Errorf("The function %q is missing", bName) + return false + } + r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs) + if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 { + // This is because the actual query param name is different. One uses the + // const, the other uses a variable that is a mutation of the original query. + a.QueryArgs[1] = b.QueryArgs[1] + } + q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs) + return r && q +} + +func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool { + return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName) +} + +func argList(t *testing.T, args []ast.Expr) []string { + defer func() { + if r := recover(); r != nil { + t.Errorf("Recovered in f reading arg names: %s", r) + } + }() + + var argNames []string + for _, arg := range args { + argname := "unknown" + // This is "&i.Arg" style stuff + if unary, ok := arg.(*ast.UnaryExpr); ok { + argname = unary.X.(*ast.SelectorExpr).Sel.Name + } + if ident, ok := arg.(*ast.Ident); ok { + argname = ident.Name + } + if sel, ok := arg.(*ast.SelectorExpr); ok { + argname = sel.Sel.Name + } + if call, ok := arg.(*ast.CallExpr); ok { + // Eh, this is pg.Array style stuff. Do a best effort. + argname = fmt.Sprintf("call(%d)", len(call.Args)) + if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok { + argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args)) + } + } + + if argname == "unknown" { + t.Errorf("Unknown arg, cannot parse: %T", arg) + } + argNames = append(argNames, argname) + } + return argNames +} + +func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr { + for _, exp := range fn.Body.List { + // find "rows, err :=" + if assign, ok := exp.(*ast.AssignStmt); ok { + if len(assign.Lhs) == 2 { + if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" { + // This is rows, err := + query := assign.Rhs[0].(*ast.CallExpr) + if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" { + return query.Args + } + } + } + } + } + return nil +} + +func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr { + for _, exp := range fn.Body.List { + if forStmt, ok := exp.(*ast.ForStmt); ok { + // This came from the debugger window and tracking it down. + rowScan := (forStmt.Body. + // Second statement in the for loop is the if statement + // with rows.can + List[1].(*ast.IfStmt). + // This is the err := rows.Scan() + Init.(*ast.AssignStmt). + // Rhs is the row.Scan part + Rhs)[0].(*ast.CallExpr) + return rowScan.Args + } + } + return nil +}