mirror of https://github.com/coder/coder.git
681 lines
17 KiB
Go
681 lines
17 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"go/format"
|
|
"go/token"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"reflect"
|
|
"runtime"
|
|
"strings"
|
|
|
|
"github.com/dave/dst"
|
|
"github.com/dave/dst/decorator"
|
|
"github.com/dave/dst/decorator/resolver/goast"
|
|
"github.com/dave/dst/decorator/resolver/guess"
|
|
"golang.org/x/tools/imports"
|
|
"golang.org/x/xerrors"
|
|
)
|
|
|
|
var (
|
|
funcs []querierFunction
|
|
funcByName map[string]struct{}
|
|
)
|
|
|
|
func init() {
|
|
var err error
|
|
funcs, err = readQuerierFunctions()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
funcByName = map[string]struct{}{}
|
|
for _, f := range funcs {
|
|
funcByName[f.Name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
err := run()
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "error: %s\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func run() error {
|
|
localPath, err := localFilePath()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")
|
|
|
|
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmem", "dbmem.go"), "q", "FakeQuerier", func(params stubParams) string {
|
|
return `panic("not implemented")`
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("stub dbmem: %w", err)
|
|
}
|
|
|
|
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbmetrics", "dbmetrics.go"), "m", "metricsStore", func(params stubParams) string {
|
|
return fmt.Sprintf(`
|
|
start := time.Now()
|
|
%s := m.s.%s(%s)
|
|
m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds())
|
|
return %s
|
|
`, params.Returns, params.FuncName, params.Parameters, params.FuncName, params.Returns)
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("stub dbmetrics: %w", err)
|
|
}
|
|
|
|
err = orderAndStubDatabaseFunctions(filepath.Join(databasePath, "dbauthz", "dbauthz.go"), "q", "querier", func(params stubParams) string {
|
|
return `panic("not implemented")`
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("stub dbauthz: %w", err)
|
|
}
|
|
|
|
err = generateUniqueConstraints()
|
|
if err != nil {
|
|
return xerrors.Errorf("generate unique constraints: %w", err)
|
|
}
|
|
|
|
err = generateForeignKeyConstraints()
|
|
if err != nil {
|
|
return xerrors.Errorf("generate foreign key constraints: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateUniqueConstraints generates the UniqueConstraint enum.
//
// It reads coderd/database/dump.sql, skipping SQL comment lines and blank
// lines and joining the remaining lines with spaces until a line ending in
// ";" completes a statement. Every statement mentioning UNIQUE or PRIMARY KEY
// becomes one constant in coderd/database/unique_constraint.go.
//
// An error is returned when dump.sql cannot be read, when a collected
// statement has an unrecognized shape, or when the output cannot be formatted
// or written.
func generateUniqueConstraints() error {
	localPath, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")

	dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
	if err != nil {
		return err
	}
	defer dump.Close()

	var uniqueConstraints []string
	dumpScanner := bufio.NewScanner(dump)
	query := ""
	for dumpScanner.Scan() {
		line := strings.TrimSpace(dumpScanner.Text())
		switch {
		case strings.HasPrefix(line, "--"):
			// SQL comment: ignored.
		case line == "":
			// Blank line: ignored.
		case strings.HasSuffix(line, ";"):
			// End of statement: keep it if it declares uniqueness.
			query += line
			if strings.Contains(query, "UNIQUE") || strings.Contains(query, "PRIMARY KEY") {
				uniqueConstraints = append(uniqueConstraints, query)
			}
			query = ""
		default:
			query += line + " "
		}
	}
	if err = dumpScanner.Err(); err != nil {
		return err
	}

	s := &bytes.Buffer{}

	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// UniqueConstraint represents a named unique constraint on a table.
type UniqueConstraint string

// UniqueConstraint enums.
const (
`)
	for _, query := range uniqueConstraints {
		name, err := uniqueConstraintName(query)
		if err != nil {
			return err
		}
		_, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	outputPath := filepath.Join(databasePath, "unique_constraint.go")

	data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
		Comments: true,
	})
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, data, 0o600)
}

// uniqueConstraintName extracts the constraint or index name from one
// statement collected by generateUniqueConstraints.
//
// For "ALTER TABLE ... ADD CONSTRAINT <name> ..." the name is the token
// following CONSTRAINT; for "CREATE UNIQUE INDEX <name> ..." it is the token
// following INDEX. Locating the name by keyword rather than by a fixed token
// index avoids an out-of-range panic on short statements and a wrong result
// when the optional ONLY keyword is absent or the spacing differs.
func uniqueConstraintName(query string) (string, error) {
	var keyword string
	switch {
	case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
		keyword = "CONSTRAINT"
	case strings.Contains(query, "CREATE UNIQUE INDEX"):
		keyword = "INDEX"
	default:
		return "", xerrors.Errorf("unknown unique constraint format: %s", query)
	}

	tokens := strings.Fields(query)
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] == keyword {
			return tokens[i+1], nil
		}
	}
	return "", xerrors.Errorf("unknown unique constraint format: %s", query)
}
|
|
|
|
// generateForeignKeyConstraints generates the ForeignKeyConstraint enum.
//
// It reads coderd/database/dump.sql, skipping SQL comment lines and blank
// lines and joining the remaining lines with spaces until a line ending in
// ";" completes a statement. Every statement mentioning FOREIGN KEY becomes
// one constant in coderd/database/foreign_key_constraint.go.
//
// An error is returned when dump.sql cannot be read, when a collected
// statement has an unrecognized shape, or when the output cannot be formatted
// or written.
func generateForeignKeyConstraints() error {
	localPath, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database")

	dump, err := os.Open(filepath.Join(databasePath, "dump.sql"))
	if err != nil {
		return err
	}
	defer dump.Close()

	var foreignKeyConstraints []string
	dumpScanner := bufio.NewScanner(dump)
	query := ""
	for dumpScanner.Scan() {
		line := strings.TrimSpace(dumpScanner.Text())
		switch {
		case strings.HasPrefix(line, "--"):
			// SQL comment: ignored.
		case line == "":
			// Blank line: ignored.
		case strings.HasSuffix(line, ";"):
			// End of statement: keep it if it declares a foreign key.
			query += line
			if strings.Contains(query, "FOREIGN KEY") {
				foreignKeyConstraints = append(foreignKeyConstraints, query)
			}
			query = ""
		default:
			query += line + " "
		}
	}

	if err := dumpScanner.Err(); err != nil {
		return err
	}

	s := &bytes.Buffer{}

	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// ForeignKeyConstraint represents a named foreign key constraint on a table.
type ForeignKeyConstraint string

// ForeignKeyConstraint enums.
const (
`)
	for _, query := range foreignKeyConstraints {
		name, err := foreignKeyConstraintName(query)
		if err != nil {
			return err
		}
		_, _ = fmt.Fprintf(s, "\tForeignKey%s ForeignKeyConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	outputPath := filepath.Join(databasePath, "foreign_key_constraint.go")

	data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{
		Comments: true,
	})
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, data, 0o600)
}

// foreignKeyConstraintName extracts the constraint name from one statement
// collected by generateForeignKeyConstraints.
//
// Only "ALTER TABLE ... ADD CONSTRAINT <name> ..." statements are recognized;
// the name is the token following CONSTRAINT. Locating the name by keyword
// rather than by a fixed token index avoids an out-of-range panic on short
// statements and a wrong result when the optional ONLY keyword is absent or
// the spacing differs.
func foreignKeyConstraintName(query string) (string, error) {
	if !strings.Contains(query, "ALTER TABLE") || !strings.Contains(query, "ADD CONSTRAINT") {
		return "", xerrors.Errorf("unknown foreign key constraint format: %s", query)
	}

	tokens := strings.Fields(query)
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] == "CONSTRAINT" {
			return tokens[i+1], nil
		}
	}
	return "", xerrors.Errorf("unknown foreign key constraint format: %s", query)
}
|
|
|
|
// stubParams is the input passed to the stub callback of
// orderAndStubDatabaseFunctions when a missing method has to be generated.
type stubParams struct {
	// FuncName is the name of the method being stubbed.
	FuncName string
	// Parameters is the comma-separated list of the method's parameter names.
	Parameters string
	// Returns is the comma-separated list of placeholder result names, one
	// per result, in the form "r0,r1,...".
	Returns string
}
|
|
|
|
// orderAndStubDatabaseFunctions orders the functions in the file and stubs them.
|
|
// This is useful for when we want to add a new function to the database and
|
|
// we want to make sure that it's ordered correctly.
|
|
//
|
|
// querierFuncs is a list of functions that are in the database.
|
|
// file is the path to the file that contains all the functions.
|
|
// structName is the name of the struct that contains the functions.
|
|
// stub is a string that will be used to stub the functions.
|
|
func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub func(params stubParams) string) error {
|
|
declByName := map[string]*dst.FuncDecl{}
|
|
packageName := filepath.Base(filepath.Dir(filePath))
|
|
|
|
contents, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return xerrors.Errorf("read dbmem: %w", err)
|
|
}
|
|
|
|
// Required to preserve imports!
|
|
f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents)
|
|
if err != nil {
|
|
return xerrors.Errorf("parse dbmem: %w", err)
|
|
}
|
|
|
|
pointer := false
|
|
for i := 0; i < len(f.Decls); i++ {
|
|
funcDecl, ok := f.Decls[i].(*dst.FuncDecl)
|
|
if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
|
|
continue
|
|
}
|
|
|
|
var ident *dst.Ident
|
|
switch t := funcDecl.Recv.List[0].Type.(type) {
|
|
case *dst.Ident:
|
|
ident = t
|
|
case *dst.StarExpr:
|
|
ident, ok = t.X.(*dst.Ident)
|
|
if !ok {
|
|
continue
|
|
}
|
|
pointer = true
|
|
}
|
|
if ident == nil || ident.Name != structName {
|
|
continue
|
|
}
|
|
if _, ok := funcByName[funcDecl.Name.Name]; !ok {
|
|
continue
|
|
}
|
|
declByName[funcDecl.Name.Name] = funcDecl
|
|
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
|
|
i--
|
|
}
|
|
|
|
for _, fn := range funcs {
|
|
var bodyStmts []dst.Stmt
|
|
|
|
// Add input validation, only relevant for dbmem.
|
|
if strings.Contains(filePath, "dbmem") && len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" {
|
|
/*
|
|
err := validateDatabaseType(arg)
|
|
if err != nil {
|
|
return database.User{}, err
|
|
}
|
|
*/
|
|
bodyStmts = append(bodyStmts, &dst.AssignStmt{
|
|
Lhs: []dst.Expr{dst.NewIdent("err")},
|
|
Tok: token.DEFINE,
|
|
Rhs: []dst.Expr{
|
|
&dst.CallExpr{
|
|
Fun: &dst.Ident{
|
|
Name: "validateDatabaseType",
|
|
},
|
|
Args: []dst.Expr{dst.NewIdent("arg")},
|
|
},
|
|
},
|
|
})
|
|
returnStmt := &dst.ReturnStmt{
|
|
Results: []dst.Expr{}, // Filled below.
|
|
}
|
|
bodyStmts = append(bodyStmts, &dst.IfStmt{
|
|
Cond: &dst.BinaryExpr{
|
|
X: dst.NewIdent("err"),
|
|
Op: token.NEQ,
|
|
Y: dst.NewIdent("nil"),
|
|
},
|
|
Body: &dst.BlockStmt{
|
|
List: []dst.Stmt{
|
|
returnStmt,
|
|
},
|
|
},
|
|
Decs: dst.IfStmtDecorations{
|
|
NodeDecs: dst.NodeDecs{
|
|
After: dst.EmptyLine,
|
|
},
|
|
},
|
|
})
|
|
for _, r := range fn.Func.Results.List {
|
|
switch typ := r.Type.(type) {
|
|
case *dst.StarExpr, *dst.ArrayType:
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil"))
|
|
case *dst.Ident:
|
|
if typ.Path != "" {
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
|
|
} else {
|
|
switch typ.Name {
|
|
case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
|
|
"int8", "int16", "int32", "int64", "int",
|
|
"byte", "rune",
|
|
"float32", "float64",
|
|
"complex64", "complex128":
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0"))
|
|
case "string":
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\""))
|
|
case "bool":
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false"))
|
|
case "error":
|
|
returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err"))
|
|
default:
|
|
panic(fmt.Sprintf("unknown ident: %#v", r.Type))
|
|
}
|
|
}
|
|
default:
|
|
panic(fmt.Sprintf("unknown return type: %T", r.Type))
|
|
}
|
|
}
|
|
}
|
|
decl, ok := declByName[fn.Name]
|
|
if !ok {
|
|
typeName := structName
|
|
if pointer {
|
|
typeName = "*" + typeName
|
|
}
|
|
params := make([]string, 0)
|
|
if fn.Func.Params != nil {
|
|
for _, p := range fn.Func.Params.List {
|
|
for _, name := range p.Names {
|
|
params = append(params, name.Name)
|
|
}
|
|
}
|
|
}
|
|
returns := make([]string, 0)
|
|
if fn.Func.Results != nil {
|
|
for i := range fn.Func.Results.List {
|
|
returns = append(returns, fmt.Sprintf("r%d", i))
|
|
}
|
|
}
|
|
|
|
funcDecl, err := compileFuncDecl(stub(stubParams{
|
|
FuncName: fn.Name,
|
|
Parameters: strings.Join(params, ","),
|
|
Returns: strings.Join(returns, ","),
|
|
}))
|
|
if err != nil {
|
|
return xerrors.Errorf("compile func decl: %w", err)
|
|
}
|
|
|
|
// Not implemented!
|
|
decl = &dst.FuncDecl{
|
|
Name: dst.NewIdent(fn.Name),
|
|
Type: &dst.FuncType{
|
|
Func: true,
|
|
TypeParams: fn.Func.TypeParams,
|
|
Params: fn.Func.Params,
|
|
Results: fn.Func.Results,
|
|
Decs: fn.Func.Decs,
|
|
},
|
|
Recv: &dst.FieldList{
|
|
List: []*dst.Field{{
|
|
Names: []*dst.Ident{dst.NewIdent(receiver)},
|
|
Type: dst.NewIdent(typeName),
|
|
}},
|
|
},
|
|
Decs: dst.FuncDeclDecorations{
|
|
NodeDecs: dst.NodeDecs{
|
|
Before: dst.EmptyLine,
|
|
After: dst.EmptyLine,
|
|
},
|
|
},
|
|
Body: &dst.BlockStmt{
|
|
List: append(bodyStmts, funcDecl.Body.List...),
|
|
},
|
|
}
|
|
}
|
|
if ok {
|
|
for i, pm := range fn.Func.Params.List {
|
|
if len(decl.Type.Params.List) < i+1 {
|
|
decl.Type.Params.List = append(decl.Type.Params.List, pm)
|
|
}
|
|
if !reflect.DeepEqual(decl.Type.Params.List[i].Type, pm.Type) {
|
|
decl.Type.Params.List[i].Type = pm.Type
|
|
}
|
|
}
|
|
|
|
for i, res := range fn.Func.Results.List {
|
|
if len(decl.Type.Results.List) < i+1 {
|
|
decl.Type.Results.List = append(decl.Type.Results.List, res)
|
|
}
|
|
if !reflect.DeepEqual(decl.Type.Results.List[i].Type, res.Type) {
|
|
decl.Type.Results.List[i].Type = res.Type
|
|
}
|
|
}
|
|
}
|
|
f.Decls = append(f.Decls, decl)
|
|
}
|
|
|
|
// Required to preserve imports!
|
|
restorer := decorator.NewRestorerWithImports(packageName, guess.New())
|
|
restored, err := restorer.RestoreFile(f)
|
|
if err != nil {
|
|
return xerrors.Errorf("restore package: %w", err)
|
|
}
|
|
var buf bytes.Buffer
|
|
err = format.Node(&buf, restorer.Fset, restored)
|
|
if err != nil {
|
|
return xerrors.Errorf("format package: %w", err)
|
|
}
|
|
data, err := imports.Process(filePath, buf.Bytes(), &imports.Options{
|
|
Comments: true,
|
|
FormatOnly: true,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("process imports: %w", err)
|
|
}
|
|
return os.WriteFile(filePath, data, 0o600)
|
|
}
|
|
|
|
// compileFuncDecl extracts the function declaration from the given code.
|
|
func compileFuncDecl(code string) (*dst.FuncDecl, error) {
|
|
f, err := decorator.Parse(fmt.Sprintf(`package stub
|
|
|
|
func stub() {
|
|
%s
|
|
}`, strings.TrimSpace(code)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(f.Decls) != 1 {
|
|
return nil, xerrors.Errorf("expected 1 decl, got %d", len(f.Decls))
|
|
}
|
|
decl, ok := f.Decls[0].(*dst.FuncDecl)
|
|
if !ok {
|
|
return nil, xerrors.Errorf("expected func decl, got %T", f.Decls[0])
|
|
}
|
|
return decl, nil
|
|
}
|
|
|
|
type querierFunction struct {
|
|
// Name is the name of the function. Like "GetUserByID"
|
|
Name string
|
|
// Func is the AST representation of a function.
|
|
Func *dst.FuncType
|
|
}
|
|
|
|
// readQuerierFunctions reads the functions from coderd/database/querier.go
|
|
func readQuerierFunctions() ([]querierFunction, error) {
|
|
f, err := parseDBFile("querier.go")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse querier.go: %w", err)
|
|
}
|
|
funcs, err := loadInterfaceFuncs(f, "sqlcQuerier")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
|
|
}
|
|
|
|
customFile, err := parseDBFile("modelqueries.go")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse modelqueriers.go: %w", err)
|
|
}
|
|
// Custom funcs should be appended after the regular functions
|
|
customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
|
|
}
|
|
|
|
return append(funcs, customFuncs...), nil
|
|
}
|
|
|
|
func parseDBFile(filename string) (*dst.File, error) {
|
|
localPath, err := localFilePath()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename)
|
|
querierData, err := os.ReadFile(querierPath)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("read %s: %w", filename, err)
|
|
}
|
|
f, err := decorator.Parse(querierData)
|
|
return f, err
|
|
}
|
|
|
|
func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
|
|
var querier *dst.InterfaceType
|
|
for _, decl := range f.Decls {
|
|
genDecl, ok := decl.(*dst.GenDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
for _, spec := range genDecl.Specs {
|
|
typeSpec, ok := spec.(*dst.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
// This is the name of the interface. If that ever changes,
|
|
// this will need to be updated.
|
|
if typeSpec.Name.Name != interfaceName {
|
|
continue
|
|
}
|
|
querier, ok = typeSpec.Type.(*dst.InterfaceType)
|
|
if !ok {
|
|
return nil, xerrors.Errorf("unexpected sqlcQuerier type: %T", typeSpec.Type)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if querier == nil {
|
|
return nil, xerrors.Errorf("querier not found")
|
|
}
|
|
funcs := []querierFunction{}
|
|
allMethods := interfaceMethods(querier)
|
|
for _, method := range allMethods {
|
|
funcType, ok := method.Type.(*dst.FuncType)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
for _, t := range []*dst.FieldList{funcType.Params, funcType.Results, funcType.TypeParams} {
|
|
if t == nil {
|
|
continue
|
|
}
|
|
for _, f := range t.List {
|
|
var ident *dst.Ident
|
|
switch t := f.Type.(type) {
|
|
case *dst.Ident:
|
|
ident = t
|
|
case *dst.StarExpr:
|
|
ident, ok = t.X.(*dst.Ident)
|
|
if !ok {
|
|
continue
|
|
}
|
|
case *dst.SelectorExpr:
|
|
ident, ok = t.X.(*dst.Ident)
|
|
if !ok {
|
|
continue
|
|
}
|
|
case *dst.ArrayType:
|
|
ident, ok = t.Elt.(*dst.Ident)
|
|
if !ok {
|
|
continue
|
|
}
|
|
}
|
|
if ident == nil {
|
|
continue
|
|
}
|
|
// If the type is exported then we should be able to find it
|
|
// in the database package!
|
|
if !ident.IsExported() {
|
|
continue
|
|
}
|
|
ident.Path = "github.com/coder/coder/v2/coderd/database"
|
|
}
|
|
}
|
|
|
|
funcs = append(funcs, querierFunction{
|
|
Name: method.Names[0].Name,
|
|
Func: funcType,
|
|
})
|
|
}
|
|
return funcs, nil
|
|
}
|
|
|
|
// localFilePath returns the location of `main.go` in the dbgen package.
|
|
func localFilePath() (string, error) {
|
|
_, filename, _, ok := runtime.Caller(0)
|
|
if !ok {
|
|
return "", xerrors.Errorf("failed to get caller")
|
|
}
|
|
return filename, nil
|
|
}
|
|
|
|
// nameFromSnakeCase turns a snake_case identifier into CamelCase. A handful
// of segments receive fixed spellings (for example "id" -> "ID" and
// "idx" -> "Index"), the "fkey" segment is dropped, and every other segment
// is title-cased.
func nameFromSnakeCase(s string) string {
	// Segments with a fixed replacement; an empty value drops the segment.
	fixed := map[string]string{
		"id":         "ID",
		"ids":        "IDs",
		"jwt":        "JWT",
		"idx":        "Index",
		"api":        "API",
		"uuid":       "UUID",
		"gitsshkeys": "GitSSHKeys",
		"fkey":       "",
	}

	var out strings.Builder
	for _, segment := range strings.Split(s, "_") {
		if replacement, known := fixed[segment]; known {
			out.WriteString(replacement)
			continue
		}
		out.WriteString(strings.Title(segment))
	}
	return out.String()
}
|
|
|
|
// interfaceMethods returns all embedded methods of an interface.
|
|
func interfaceMethods(i *dst.InterfaceType) []*dst.Field {
|
|
var allMethods []*dst.Field
|
|
for _, field := range i.Methods.List {
|
|
switch fieldType := field.Type.(type) {
|
|
case *dst.FuncType:
|
|
allMethods = append(allMethods, field)
|
|
case *dst.InterfaceType:
|
|
allMethods = append(allMethods, interfaceMethods(fieldType)...)
|
|
case *dst.Ident:
|
|
// Embedded interfaces are Idents -> TypeSpec -> InterfaceType
|
|
// If the embedded interface is not in the parsed file, then
|
|
// the Obj will be nil.
|
|
if fieldType.Obj != nil {
|
|
objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec)
|
|
if ok {
|
|
isInterface, ok := objDecl.Type.(*dst.InterfaceType)
|
|
if ok {
|
|
allMethods = append(allMethods, interfaceMethods(isInterface)...)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return allMethods
|
|
}
|