diff --git a/mockgen/generic.go b/mockgen/api/generic.go similarity index 99% rename from mockgen/generic.go rename to mockgen/api/generic.go index c2289c2..522b3fd 100644 --- a/mockgen/generic.go +++ b/mockgen/api/generic.go @@ -5,7 +5,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package api import ( "errors" diff --git a/mockgen/api/mockgen.go b/mockgen/api/mockgen.go new file mode 100644 index 0000000..20025d3 --- /dev/null +++ b/mockgen/api/mockgen.go @@ -0,0 +1,868 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// MockGen generates mock implementations of Go interfaces. +package api + +// TODO: This does not support recursive embedded interfaces. +// TODO: This does not support embedding package-local interfaces in a separate file. + +import ( + "bytes" + "encoding/json" + "errors" + "flag" + "fmt" + "go/token" + "log" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "unicode" + + "golang.org/x/mod/modfile" + toolsimports "golang.org/x/tools/imports" + + "go.uber.org/mock/mockgen/model" +) + +const ( + gomockImportPath = "go.uber.org/mock/gomock" +) + +var ( + version = "" + commit = "none" + date = "unknown" +) + +type Config struct { + Source string + Destination string + MockNames string + PackageOut string + SelfPackage string + WriteCmdComment bool + WritePkgComment bool + WriteSourceComment bool + WriteGenerateDirective bool + CopyrightFile string + Typed bool + Imports string + AuxFiles string + ExcludeInterfaces string + DebugParser bool + ShowVersion bool +} + +// Generate generates mock structs based on the given configuration. +func Generate(c Config) { + if c.ShowVersion { + printVersion() + return + } + + var pkg *model.Package + var err error + var packageName string + if c.Source != "" { + pkg, err = sourceMode(c) + } else { + packageName = flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + pkg, err = reflectMode(packageName, interfaces) + } + if err != nil { + log.Fatalf("Loading input failed: %v", err) + } + + if c.DebugParser { + pkg.Print(os.Stdout) + return + } + + outputPackageName := c.PackageOut + if outputPackageName == "" { + // pkg.Name in reflect mode is the base name of the import path, + // which might have characters that are illegal to have in package names. + outputPackageName = "mock_" + sanitize(pkg.Name) + } + + // outputPackagePath represents the fully qualified name of the package of + // the generated code. Its purposes are to prevent the module from importing + // itself and to prevent qualifying type names that come from its own + // package (i.e. if there is a type called X then we want to print "X" not + // "package.X" since "package" is this package). This can happen if the mock + // is output into an already existing package. + outputPackagePath := c.SelfPackage + if outputPackagePath == "" && c.Destination != "" { + dstPath, err := filepath.Abs(filepath.Dir(c.Destination)) + if err == nil { + pkgPath, err := parsePackageImport(dstPath) + if err == nil { + outputPackagePath = pkgPath + } else { + log.Println("Unable to infer -self_package from destination file path:", err) + } + } else { + log.Println("Unable to determine destination file path:", err) + } + } + + g := new(generator) + if c.Source != "" { + g.filename = c.Source + } else { + g.srcPackage = packageName + g.srcInterfaces = flag.Arg(1) + } + g.destination = c.Destination + + if c.MockNames != "" { + g.mockNames = parseMockNames(c.MockNames) + } + if c.CopyrightFile != "" { + header, err := os.ReadFile(c.CopyrightFile) + if err != nil { + log.Fatalf("Failed reading copyright file: %v", err) + } + + g.copyrightHeader = string(header) + } + if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { + log.Fatalf("Failed generating mock: %v", err) + } + output := g.Output() + dst := os.Stdout + if len(c.Destination) > 0 { + if err := os.MkdirAll(filepath.Dir(c.Destination), os.ModePerm); err != nil { + log.Fatalf("Unable to create directory: %v", err) + } + existing, err := os.ReadFile(c.Destination) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Fatalf("Failed reading pre-exiting destination file: %v", err) + } + if len(existing) == len(output) && bytes.Equal(existing, output) { + return + } + f, err := os.Create(c.Destination) + if err != nil { + log.Fatalf("Failed opening destination file: %v", err) + } + defer f.Close() + dst = f + } + if _, err := dst.Write(output); err != nil { + log.Fatalf("Failed writing to destination: %v", err) + } +} + +func parseMockNames(names string) map[string]string { + mocksMap := make(map[string]string) + for _, kv := range strings.Split(names, ",") { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 || parts[1] == "" { + log.Fatalf("bad mock names spec: %v", kv) + } + mocksMap[parts[0]] = parts[1] + } + return mocksMap +} + +func parseExcludeInterfaces(names string) map[string]struct{} { + splitNames := strings.Split(names, ",") + namesSet := make(map[string]struct{}, len(splitNames)) + for _, name := range splitNames { + if name == "" { + continue + } + + namesSet[name] = struct{}{} + } + + if len(namesSet) == 0 { + return nil + } + + return namesSet +} + +type generator struct { + c Config + + buf bytes.Buffer + indent string + mockNames map[string]string // may be empty + filename string // may be empty + destination string // may be empty + srcPackage, srcInterfaces string // may be empty + copyrightHeader string + + packageMap map[string]string // map from import path to package name +} + +func (g *generator) p(format string, args ...any) { + _, _ = fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) +} + +func (g *generator) in() { + g.indent += "\t" +} + +func (g *generator) out() { + if len(g.indent) > 0 { + g.indent = g.indent[0 : len(g.indent)-1] + } +} + +// sanitize cleans up a string to make a suitable package name. +func sanitize(s string) string { + t := "" + for _, r := range s { + if t == "" { + if unicode.IsLetter(r) || r == '_' { + t += string(r) + continue + } + } else { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { + t += string(r) + continue + } + } + t += "_" + } + if t == "_" { + t = "x" + } + return t +} + +func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { + if outputPkgName != pkg.Name && g.c.SelfPackage == "" { + // reset outputPackagePath if it's not passed in through -self_package + outputPackagePath = "" + } + + if g.copyrightHeader != "" { + lines := strings.Split(g.copyrightHeader, "\n") + for _, line := range lines { + g.p("// %s", line) + } + g.p("") + } + + g.p("// Code generated by MockGen. DO NOT EDIT.") + if g.c.WriteSourceComment { + if g.filename != "" { + g.p("// Source: %v", g.filename) + } else { + g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) + } + } + if g.c.WriteCmdComment { + g.p("//") + g.p("// Generated by this command:") + g.p("//") + // only log the name of the executable, not the full path + name := filepath.Base(os.Args[0]) + if runtime.GOOS == "windows" { + name = strings.TrimSuffix(name, ".exe") + } + g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " ")) + g.p("//") + } + + // Get all required imports, and generate unique names for them all. + im := pkg.Imports() + im[gomockImportPath] = true + + // Only import reflect if it's used. We only use reflect in mocked methods + // so only import if any of the mocked interfaces have methods. + for _, intf := range pkg.Interfaces { + if len(intf.Methods) > 0 { + im["reflect"] = true + break + } + } + + // Sort keys to make import alias generation predictable + sortedPaths := make([]string, len(im)) + x := 0 + for pth := range im { + sortedPaths[x] = pth + x++ + } + sort.Strings(sortedPaths) + + packagesName := createPackageMap(sortedPaths) + + definedImports := make(map[string]string, len(im)) + if g.c.Imports != "" { + for _, kv := range strings.Split(g.c.Imports, ",") { + eq := strings.Index(kv, "=") + if k, v := kv[:eq], kv[eq+1:]; k != "." { + definedImports[v] = k + } + } + } + + g.packageMap = make(map[string]string, len(im)) + localNames := make(map[string]bool, len(im)) + for _, pth := range sortedPaths { + base, ok := packagesName[pth] + if !ok { + base = sanitize(path.Base(pth)) + } + + // Local names for an imported package can usually be the basename of the import path. + // A couple of situations don't permit that, such as duplicate local names + // (e.g. importing "html/template" and "text/template"), or where the basename is + // a keyword (e.g. "foo/case") or when defining a name for that by using the -imports flag. + // try base0, base1, ... + pkgName := base + + if _, ok := definedImports[base]; ok { + pkgName = definedImports[base] + } + + i := 0 + for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" { + pkgName = base + strconv.Itoa(i) + i++ + } + + // Avoid importing package if source pkg == output pkg + if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { + continue + } + + g.packageMap[pth] = pkgName + localNames[pkgName] = true + } + + // Ensure there is an empty line between “generated by” block and + // package documentation comments to follow the recommendations: + // https://go.dev/wiki/CodeReviewComments#package-comments + // That is, “generated by” should not be a package comment. + g.p("") + + if g.c.WritePkgComment { + g.p("// Package %v is a generated GoMock package.", outputPkgName) + } + g.p("package %v", outputPkgName) + g.p("") + g.p("import (") + g.in() + for pkgPath, pkgName := range g.packageMap { + if pkgPath == outputPackagePath { + continue + } + g.p("%v %q", pkgName, pkgPath) + } + for _, pkgPath := range pkg.DotImports { + g.p(". %q", pkgPath) + } + g.out() + g.p(")") + + if g.c.WriteGenerateDirective { + g.p("//go:generate %v", strings.Join(os.Args, " ")) + } + + for _, intf := range pkg.Interfaces { + if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { + return err + } + } + + return nil +} + +// The name of the mock type to use for the given interface identifier. +func (g *generator) mockName(typeName string) string { + if mockName, ok := g.mockNames[typeName]; ok { + return mockName + } + + return "Mock" + typeName +} + +// formattedTypeParams returns a long and short form of type param info used for +// printing. If analyzing a interface with type param [I any, O any] the result +// will be: +// "[I any, O any]", "[I, O]" +func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) (string, string) { + if len(it.TypeParams) == 0 { + return "", "" + } + var long, short strings.Builder + long.WriteString("[") + short.WriteString("[") + for i, v := range it.TypeParams { + if i != 0 { + long.WriteString(", ") + short.WriteString(", ") + } + long.WriteString(v.Name) + short.WriteString(v.Name) + long.WriteString(fmt.Sprintf(" %s", v.Type.String(g.packageMap, pkgOverride))) + } + long.WriteString("]") + short.WriteString("]") + return long.String(), short.String() +} + +func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { + mockType := g.mockName(intf.Name) + longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath) + + g.p("") + g.p("// %v is a mock of %v interface.", mockType, intf.Name) + g.p("type %v%v struct {", mockType, longTp) + g.in() + g.p("ctrl *gomock.Controller") + g.p("recorder *%vMockRecorder%v", mockType, shortTp) + g.out() + g.p("}") + g.p("") + + g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) + g.p("type %vMockRecorder%v struct {", mockType, longTp) + g.in() + g.p("mock *%v%v", mockType, shortTp) + g.out() + g.p("}") + g.p("") + + g.p("// New%v creates a new mock instance.", mockType) + g.p("func New%v%v(ctrl *gomock.Controller) *%v%v {", mockType, longTp, mockType, shortTp) + g.in() + g.p("mock := &%v%v{ctrl: ctrl}", mockType, shortTp) + g.p("mock.recorder = &%vMockRecorder%v{mock}", mockType, shortTp) + g.p("return mock") + g.out() + g.p("}") + g.p("") + + // XXX: possible name collision here if someone has EXPECT in their interface. + g.p("// EXPECT returns an object that allows the caller to indicate expected use.") + g.p("func (m *%v%v) EXPECT() *%vMockRecorder%v {", mockType, shortTp, mockType, shortTp) + g.in() + g.p("return m.recorder") + g.out() + g.p("}") + + // XXX: possible name collision here if someone has ISGOMOCK in their interface. + g.p("// ISGOMOCK indicates that this struct is a gomock mock.") + g.p("func (m *%v%v) ISGOMOCK() struct{} {", mockType, shortTp) + g.in() + g.p("return struct{}{}") + g.out() + g.p("}") + + g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, g.c.Typed) + + return nil +} + +type byMethodName []*model.Method + +func (b byMethodName) Len() int { return len(b) } +func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } + +func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride, longTp, shortTp string, typed bool) { + sort.Sort(byMethodName(intf.Methods)) + for _, m := range intf.Methods { + g.p("") + _ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp) + g.p("") + _ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed) + if typed { + g.p("") + _ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp) + } + } +} + +func makeArgString(argNames, argTypes []string) string { + args := make([]string, len(argNames)) + for i, name := range argNames { + // specify the type only once for consecutive args of the same type + if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] { + args[i] = name + } else { + args[i] = name + " " + argTypes[i] + } + } + return strings.Join(args, ", ") +} + +// GenerateMockMethod generates a mock method implementation. +// If non-empty, pkgOverride is the package in which unqualified types reside. +func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride, shortTp string) error { + argNames := g.getArgNames(m, true /* in */) + argTypes := g.getArgTypes(m, pkgOverride, true /* in */) + argString := makeArgString(argNames, argTypes) + + rets := make([]string, len(m.Out)) + for i, p := range m.Out { + rets[i] = p.Type.String(g.packageMap, pkgOverride) + } + retString := strings.Join(rets, ", ") + if len(rets) > 1 { + retString = "(" + retString + ")" + } + if retString != "" { + retString = " " + retString + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("m") + + g.p("// %v mocks base method.", m.Name) + g.p("func (%v *%v%v) %v(%v)%v {", idRecv, mockType, shortTp, m.Name, argString, retString) + g.in() + g.p("%s.ctrl.T.Helper()", idRecv) + + var callArgs string + if m.Variadic == nil { + if len(argNames) > 0 { + callArgs = ", " + strings.Join(argNames, ", ") + } + } else { + // Non-trivial. The generated code must build a []any, + // but the variadic argument may be any type. + idVarArgs := ia.allocateIdentifier("varargs") + idVArg := ia.allocateIdentifier("a") + g.p("%s := []any{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) + g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) + g.in() + g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) + g.out() + g.p("}") + callArgs = ", " + idVarArgs + "..." + } + if len(m.Out) == 0 { + g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) + } else { + idRet := ia.allocateIdentifier("ret") + g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) + + // Go does not allow "naked" type assertions on nil values, so we use the two-value form here. + // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. + // Happily, this coincides with the semantics we want here. + retNames := make([]string, len(rets)) + for i, t := range rets { + retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) + g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) + } + g.p("return " + strings.Join(retNames, ", ")) + } + + g.out() + g.p("}") + return nil +} + +func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error { + mockType := g.mockName(intf.Name) + argNames := g.getArgNames(m, true) + + var argString string + if m.Variadic == nil { + argString = strings.Join(argNames, ", ") + } else { + argString = strings.Join(argNames[:len(argNames)-1], ", ") + } + if argString != "" { + argString += " any" + } + + if m.Variadic != nil { + if argString != "" { + argString += ", " + } + argString += fmt.Sprintf("%s ...any", argNames[len(argNames)-1]) + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("mr") + + g.p("// %v indicates an expected call of %v.", m.Name, m.Name) + if typed { + g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp) + } else { + g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString) + } + + g.in() + g.p("%s.mock.ctrl.T.Helper()", idRecv) + + var callArgs string + if m.Variadic == nil { + if len(argNames) > 0 { + callArgs = ", " + strings.Join(argNames, ", ") + } + } else { + if len(argNames) == 1 { + // Easy: just use ... to push the arguments through. + callArgs = ", " + argNames[0] + "..." + } else { + // Hard: create a temporary slice. + idVarArgs := ia.allocateIdentifier("varargs") + g.p("%s := append([]any{%s}, %s...)", + idVarArgs, + strings.Join(argNames[:len(argNames)-1], ", "), + argNames[len(argNames)-1]) + callArgs = ", " + idVarArgs + "..." + } + } + if typed { + g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) + g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp) + } else { + g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) + } + + g.out() + g.p("}") + return nil +} + +func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error { + mockType := g.mockName(intf.Name) + argNames := g.getArgNames(m, true /* in */) + retNames := g.getArgNames(m, false /* out */) + argTypes := g.getArgTypes(m, pkgOverride, true /* in */) + retTypes := g.getArgTypes(m, pkgOverride, false /* out */) + argString := strings.Join(argTypes, ", ") + + rets := make([]string, len(m.Out)) + for i, p := range m.Out { + rets[i] = p.Type.String(g.packageMap, pkgOverride) + } + + var retString string + switch { + case len(rets) == 1: + retString = " " + rets[0] + case len(rets) > 1: + retString = " (" + strings.Join(rets, ", ") + ")" + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("c") + + recvStructName := mockType + m.Name + + g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name) + g.p("type %s%sCall%s struct{", mockType, m.Name, longTp) + g.in() + g.p("*gomock.Call") + g.out() + g.p("}") + + g.p("// Return rewrite *gomock.Call.Return") + g.p("func (%s *%sCall%s) Return(%v) *%sCall%s {", idRecv, recvStructName, shortTp, makeArgString(retNames, retTypes), recvStructName, shortTp) + g.in() + var retArgs string + if len(retNames) > 0 { + retArgs = strings.Join(retNames, ", ") + } + g.p(`%s.Call = %v.Call.Return(%v)`, idRecv, idRecv, retArgs) + g.p("return %s", idRecv) + g.out() + g.p("}") + + g.p("// Do rewrite *gomock.Call.Do") + g.p("func (%s *%sCall%s) Do(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) + g.in() + g.p(`%s.Call = %v.Call.Do(f)`, idRecv, idRecv) + g.p("return %s", idRecv) + g.out() + g.p("}") + + g.p("// DoAndReturn rewrite *gomock.Call.DoAndReturn") + g.p("func (%s *%sCall%s) DoAndReturn(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) + g.in() + g.p(`%s.Call = %v.Call.DoAndReturn(f)`, idRecv, idRecv) + g.p("return %s", idRecv) + g.out() + g.p("}") + return nil +} + +func (g *generator) getArgNames(m *model.Method, in bool) []string { + var params []*model.Parameter + if in { + params = m.In + } else { + params = m.Out + } + argNames := make([]string, len(params)) + for i, p := range params { + name := p.Name + if name == "" || name == "_" { + name = fmt.Sprintf("arg%d", i) + } + argNames[i] = name + } + if m.Variadic != nil && in { + name := m.Variadic.Name + if name == "" { + name = fmt.Sprintf("arg%d", len(params)) + } + argNames = append(argNames, name) + } + return argNames +} + +func (g *generator) getArgTypes(m *model.Method, pkgOverride string, in bool) []string { + var params []*model.Parameter + if in { + params = m.In + } else { + params = m.Out + } + argTypes := make([]string, len(params)) + for i, p := range params { + argTypes[i] = p.Type.String(g.packageMap, pkgOverride) + } + if m.Variadic != nil { + argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) + } + return argTypes +} + +type identifierAllocator map[string]struct{} + +func newIdentifierAllocator(taken []string) identifierAllocator { + a := make(identifierAllocator, len(taken)) + for _, s := range taken { + a[s] = struct{}{} + } + return a +} + +func (o identifierAllocator) allocateIdentifier(want string) string { + id := want + for i := 2; ; i++ { + if _, ok := o[id]; !ok { + o[id] = struct{}{} + return id + } + id = want + "_" + strconv.Itoa(i) + } +} + +// Output returns the generator's output, formatted in the standard Go style. +func (g *generator) Output() []byte { + src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil) + if err != nil { + log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) + } + return src +} + +// createPackageMap returns a map of import path to package name +// for specified importPaths. +func createPackageMap(importPaths []string) map[string]string { + var pkg struct { + Name string + ImportPath string + } + pkgMap := make(map[string]string) + b := bytes.NewBuffer(nil) + args := []string{"list", "-json"} + args = append(args, importPaths...) + cmd := exec.Command("go", args...) + cmd.Stdout = b + cmd.Run() + dec := json.NewDecoder(b) + for dec.More() { + err := dec.Decode(&pkg) + if err != nil { + log.Printf("failed to decode 'go list' output: %v", err) + continue + } + pkgMap[pkg.ImportPath] = pkg.Name + } + return pkgMap +} + +func printVersion() { + if version != "" { + fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date) + } else { + printModuleVersion() + } +} + +// parseImportPackage get package import path via source file +// an alternative implementation is to use: +// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} +// pkgs, err := packages.Load(cfg, "file="+source) +// However, it will call "go list" and slow down the performance +func parsePackageImport(srcDir string) (string, error) { + moduleMode := os.Getenv("GO111MODULE") + // trying to find the module + if moduleMode != "off" { + currentDir := srcDir + for { + dat, err := os.ReadFile(filepath.Join(currentDir, "go.mod")) + if os.IsNotExist(err) { + if currentDir == filepath.Dir(currentDir) { + // at the root + break + } + currentDir = filepath.Dir(currentDir) + continue + } else if err != nil { + return "", err + } + modulePath := modfile.ModulePath(dat) + return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil + } + } + // fall back to GOPATH mode + goPaths := os.Getenv("GOPATH") + if goPaths == "" { + return "", fmt.Errorf("GOPATH is not set") + } + goPathList := strings.Split(goPaths, string(os.PathListSeparator)) + for _, goPath := range goPathList { + sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) + if strings.HasPrefix(srcDir, sourceRoot) { + return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil + } + } + return "", errOutsideGoPath +} diff --git a/mockgen/mockgen_test.go b/mockgen/api/mockgen_test.go similarity index 99% rename from mockgen/mockgen_test.go rename to mockgen/api/mockgen_test.go index 6b17127..cf55ac0 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/api/mockgen_test.go @@ -1,4 +1,4 @@ -package main +package api import ( "fmt" diff --git a/mockgen/parse.go b/mockgen/api/parse.go similarity index 97% rename from mockgen/parse.go rename to mockgen/api/parse.go index f43321c..9dd1c6d 100644 --- a/mockgen/parse.go +++ b/mockgen/api/parse.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package api // This file contains the model construction by parsing source files. @@ -36,8 +36,8 @@ import ( ) // sourceMode generates mocks via source file. -func sourceMode(source string) (*model.Package, error) { - srcDir, err := filepath.Abs(filepath.Dir(source)) +func sourceMode(c Config) (*model.Package, error) { + srcDir, err := filepath.Abs(filepath.Dir(c.Source)) if err != nil { return nil, fmt.Errorf("failed getting source directory: %v", err) } @@ -48,9 +48,9 @@ func sourceMode(source string) (*model.Package, error) { } fs := token.NewFileSet() - file, err := parser.ParseFile(fs, source, nil, 0) + file, err := parser.ParseFile(fs, c.Source, nil, 0) if err != nil { - return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) + return nil, fmt.Errorf("failed parsing source file %v: %v", c.Source, err) } p := &fileParser{ @@ -63,8 +63,8 @@ func sourceMode(source string) (*model.Package, error) { // Handle -imports. dotImports := make(map[string]bool) - if *imports != "" { - for _, kv := range strings.Split(*imports, ",") { + if c.Imports != "" { + for _, kv := range strings.Split(c.Imports, ",") { eq := strings.Index(kv, "=") k, v := kv[:eq], kv[eq+1:] if k == "." { @@ -75,12 +75,12 @@ func sourceMode(source string) (*model.Package, error) { } } - if *excludeInterfaces != "" { - p.excludeNamesSet = parseExcludeInterfaces(*excludeInterfaces) + if c.ExcludeInterfaces != "" { + p.excludeNamesSet = parseExcludeInterfaces(c.ExcludeInterfaces) } // Handle -aux_files. - if err := p.parseAuxFiles(*auxFiles); err != nil { + if err := p.parseAuxFiles(c.AuxFiles); err != nil { return nil, err } p.addAuxInterfacesFromFile(packageImport, file) // this file diff --git a/mockgen/parse_test.go b/mockgen/api/parse_test.go similarity index 89% rename from mockgen/parse_test.go rename to mockgen/api/parse_test.go index 3c4ba4c..d6a6ca0 100644 --- a/mockgen/parse_test.go +++ b/mockgen/api/parse_test.go @@ -1,4 +1,4 @@ -package main +package api import ( "go/parser" @@ -8,7 +8,7 @@ import ( func TestFileParser_ParseFile(t *testing.T) { fs := token.NewFileSet() - file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + file, err := parser.ParseFile(fs, "../internal/tests/custom_package_name/greeter/greeter.go", nil, 0) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -39,7 +39,7 @@ func TestFileParser_ParseFile(t *testing.T) { func TestFileParser_ParsePackage(t *testing.T) { fs := token.NewFileSet() - _, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + _, err := parser.ParseFile(fs, "../internal/tests/custom_package_name/greeter/greeter.go", nil, 0) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -60,7 +60,7 @@ func TestFileParser_ParsePackage(t *testing.T) { func TestImportsOfFile(t *testing.T) { fs := token.NewFileSet() - file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + file, err := parser.ParseFile(fs, "../internal/tests/custom_package_name/greeter/greeter.go", nil, 0) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -107,15 +107,17 @@ func checkGreeterImports(t *testing.T, imports map[string]importedPackage) { } func Benchmark_parseFile(b *testing.B) { - source := "internal/tests/performance/big_interface/big_interface.go" + source := "../internal/tests/performance/big_interface/big_interface.go" for n := 0; n < b.N; n++ { - sourceMode(source) + sourceMode(Config{ + Source: source, + }) } } func TestParseArrayWithConstLength(t *testing.T) { fs := token.NewFileSet() - srcDir := "internal/tests/const_array_length/input.go" + srcDir := "../internal/tests/const_array_length/input.go" file, err := parser.ParseFile(fs, srcDir, nil, 0) if err != nil { diff --git a/mockgen/reflect.go b/mockgen/api/reflect.go similarity index 98% rename from mockgen/reflect.go rename to mockgen/api/reflect.go index ca80ebb..cc838fb 100644 --- a/mockgen/reflect.go +++ b/mockgen/api/reflect.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package api // This file contains the model construction by reflection. @@ -142,13 +142,13 @@ func runInDir(program []byte, dir string) (*model.Package, error) { } }() const progSource = "prog.go" - var progBinary = "prog.bin" + progBinary := "prog.bin" if runtime.GOOS == "windows" { // Windows won't execute a program unless it has a ".exe" suffix. progBinary += ".exe" } - if err := os.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { + if err := os.WriteFile(filepath.Join(tmpDir, progSource), program, 0o600); err != nil { return nil, err } diff --git a/mockgen/version.go b/mockgen/api/version.go similarity index 98% rename from mockgen/version.go rename to mockgen/api/version.go index 6db160a..400fcb5 100644 --- a/mockgen/version.go +++ b/mockgen/api/version.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package api import ( "fmt" diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 3ffc9fe..dcb6f7d 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -19,28 +19,12 @@ package main // TODO: This does not support embedding package-local interfaces in a separate file. import ( - "bytes" - "encoding/json" - "errors" "flag" - "fmt" - "go/token" "io" "log" "os" - "os/exec" - "path" - "path/filepath" - "runtime" - "sort" - "strconv" - "strings" - "unicode" - "golang.org/x/mod/modfile" - toolsimports "golang.org/x/tools/imports" - - "go.uber.org/mock/mockgen/model" + "go.uber.org/mock/mockgen/api" ) const ( @@ -77,148 +61,31 @@ func main() { flag.Usage = usage flag.Parse() - if *showVersion { - printVersion() - return - } - - var pkg *model.Package - var err error - var packageName string - if *source != "" { - pkg, err = sourceMode(*source) - } else { + if *source == "" { if flag.NArg() != 2 { usage() log.Fatal("Expected exactly two arguments") } - packageName = flag.Arg(0) - interfaces := strings.Split(flag.Arg(1), ",") - if packageName == "." { - dir, err := os.Getwd() - if err != nil { - log.Fatalf("Get current directory failed: %v", err) - } - packageName, err = packageNameOfDir(dir) - if err != nil { - log.Fatalf("Parse package name failed: %v", err) - } - } - pkg, err = reflectMode(packageName, interfaces) - } - if err != nil { - log.Fatalf("Loading input failed: %v", err) - } - - if *debugParser { - pkg.Print(os.Stdout) - return - } - - outputPackageName := *packageOut - if outputPackageName == "" { - // pkg.Name in reflect mode is the base name of the import path, - // which might have characters that are illegal to have in package names. - outputPackageName = "mock_" + sanitize(pkg.Name) - } - - // outputPackagePath represents the fully qualified name of the package of - // the generated code. Its purposes are to prevent the module from importing - // itself and to prevent qualifying type names that come from its own - // package (i.e. if there is a type called X then we want to print "X" not - // "package.X" since "package" is this package). This can happen if the mock - // is output into an already existing package. - outputPackagePath := *selfPackage - if outputPackagePath == "" && *destination != "" { - dstPath, err := filepath.Abs(filepath.Dir(*destination)) - if err == nil { - pkgPath, err := parsePackageImport(dstPath) - if err == nil { - outputPackagePath = pkgPath - } else { - log.Println("Unable to infer -self_package from destination file path:", err) - } - } else { - log.Println("Unable to determine destination file path:", err) - } - } - - g := new(generator) - if *source != "" { - g.filename = *source - } else { - g.srcPackage = packageName - g.srcInterfaces = flag.Arg(1) - } - g.destination = *destination - - if *mockNames != "" { - g.mockNames = parseMockNames(*mockNames) - } - if *copyrightFile != "" { - header, err := os.ReadFile(*copyrightFile) - if err != nil { - log.Fatalf("Failed reading copyright file: %v", err) - } - - g.copyrightHeader = string(header) - } - if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { - log.Fatalf("Failed generating mock: %v", err) - } - output := g.Output() - dst := os.Stdout - if len(*destination) > 0 { - if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { - log.Fatalf("Unable to create directory: %v", err) - } - existing, err := os.ReadFile(*destination) - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Fatalf("Failed reading pre-exiting destination file: %v", err) - } - if len(existing) == len(output) && bytes.Equal(existing, output) { - return - } - f, err := os.Create(*destination) - if err != nil { - log.Fatalf("Failed opening destination file: %v", err) - } - defer f.Close() - dst = f - } - if _, err := dst.Write(output); err != nil { - log.Fatalf("Failed writing to destination: %v", err) - } -} - -func parseMockNames(names string) map[string]string { - mocksMap := make(map[string]string) - for _, kv := range strings.Split(names, ",") { - parts := strings.SplitN(kv, "=", 2) - if len(parts) != 2 || parts[1] == "" { - log.Fatalf("bad mock names spec: %v", kv) - } - mocksMap[parts[0]] = parts[1] } - return mocksMap -} - -func parseExcludeInterfaces(names string) map[string]struct{} { - splitNames := strings.Split(names, ",") - namesSet := make(map[string]struct{}, len(splitNames)) - for _, name := range splitNames { - if name == "" { - continue - } - namesSet[name] = struct{}{} - } - - if len(namesSet) == 0 { - return nil - } - - return namesSet + api.Generate(api.Config{ + Source: *source, + Destination: *destination, + MockNames: *mockNames, + PackageOut: *packageOut, + SelfPackage: *selfPackage, + WriteCmdComment: *writeCmdComment, + WritePkgComment: *writePkgComment, + WriteSourceComment: *writeSourceComment, + WriteGenerateDirective: *writeGenerateDirective, + CopyrightFile: *copyrightFile, + Typed: *typed, + Imports: *imports, + AuxFiles: *auxFiles, + ExcludeInterfaces: *excludeInterfaces, + DebugParser: *debugParser, + ShowVersion: *showVersion, + }) } func usage() { @@ -242,655 +109,3 @@ Example: mockgen database/sql/driver Conn,Driver ` - -type generator struct { - buf bytes.Buffer - indent string - mockNames map[string]string // may be empty - filename string // may be empty - destination string // may be empty - srcPackage, srcInterfaces string // may be empty - copyrightHeader string - - packageMap map[string]string // map from import path to package name -} - -func (g *generator) p(format string, args ...any) { - _, _ = fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) -} - -func (g *generator) in() { - g.indent += "\t" -} - -func (g *generator) out() { - if len(g.indent) > 0 { - g.indent = g.indent[0 : len(g.indent)-1] - } -} - -// sanitize cleans up a string to make a suitable package name. -func sanitize(s string) string { - t := "" - for _, r := range s { - if t == "" { - if unicode.IsLetter(r) || r == '_' { - t += string(r) - continue - } - } else { - if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { - t += string(r) - continue - } - } - t += "_" - } - if t == "_" { - t = "x" - } - return t -} - -func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { - if outputPkgName != pkg.Name && *selfPackage == "" { - // reset outputPackagePath if it's not passed in through -self_package - outputPackagePath = "" - } - - if g.copyrightHeader != "" { - lines := strings.Split(g.copyrightHeader, "\n") - for _, line := range lines { - g.p("// %s", line) - } - g.p("") - } - - g.p("// Code generated by MockGen. DO NOT EDIT.") - if *writeSourceComment { - if g.filename != "" { - g.p("// Source: %v", g.filename) - } else { - g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) - } - } - if *writeCmdComment { - g.p("//") - g.p("// Generated by this command:") - g.p("//") - // only log the name of the executable, not the full path - name := filepath.Base(os.Args[0]) - if runtime.GOOS == "windows" { - name = strings.TrimSuffix(name, ".exe") - } - g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " ")) - g.p("//") - } - - // Get all required imports, and generate unique names for them all. - im := pkg.Imports() - im[gomockImportPath] = true - - // Only import reflect if it's used. We only use reflect in mocked methods - // so only import if any of the mocked interfaces have methods. - for _, intf := range pkg.Interfaces { - if len(intf.Methods) > 0 { - im["reflect"] = true - break - } - } - - // Sort keys to make import alias generation predictable - sortedPaths := make([]string, len(im)) - x := 0 - for pth := range im { - sortedPaths[x] = pth - x++ - } - sort.Strings(sortedPaths) - - packagesName := createPackageMap(sortedPaths) - - definedImports := make(map[string]string, len(im)) - if *imports != "" { - for _, kv := range strings.Split(*imports, ",") { - eq := strings.Index(kv, "=") - if k, v := kv[:eq], kv[eq+1:]; k != "." { - definedImports[v] = k - } - } - } - - g.packageMap = make(map[string]string, len(im)) - localNames := make(map[string]bool, len(im)) - for _, pth := range sortedPaths { - base, ok := packagesName[pth] - if !ok { - base = sanitize(path.Base(pth)) - } - - // Local names for an imported package can usually be the basename of the import path. - // A couple of situations don't permit that, such as duplicate local names - // (e.g. importing "html/template" and "text/template"), or where the basename is - // a keyword (e.g. "foo/case") or when defining a name for that by using the -imports flag. - // try base0, base1, ... - pkgName := base - - if _, ok := definedImports[base]; ok { - pkgName = definedImports[base] - } - - i := 0 - for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" { - pkgName = base + strconv.Itoa(i) - i++ - } - - // Avoid importing package if source pkg == output pkg - if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { - continue - } - - g.packageMap[pth] = pkgName - localNames[pkgName] = true - } - - // Ensure there is an empty line between “generated by” block and - // package documentation comments to follow the recommendations: - // https://go.dev/wiki/CodeReviewComments#package-comments - // That is, “generated by” should not be a package comment. - g.p("") - - if *writePkgComment { - g.p("// Package %v is a generated GoMock package.", outputPkgName) - } - g.p("package %v", outputPkgName) - g.p("") - g.p("import (") - g.in() - for pkgPath, pkgName := range g.packageMap { - if pkgPath == outputPackagePath { - continue - } - g.p("%v %q", pkgName, pkgPath) - } - for _, pkgPath := range pkg.DotImports { - g.p(". %q", pkgPath) - } - g.out() - g.p(")") - - if *writeGenerateDirective { - g.p("//go:generate %v", strings.Join(os.Args, " ")) - } - - for _, intf := range pkg.Interfaces { - if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { - return err - } - } - - return nil -} - -// The name of the mock type to use for the given interface identifier. -func (g *generator) mockName(typeName string) string { - if mockName, ok := g.mockNames[typeName]; ok { - return mockName - } - - return "Mock" + typeName -} - -// formattedTypeParams returns a long and short form of type param info used for -// printing. If analyzing a interface with type param [I any, O any] the result -// will be: -// "[I any, O any]", "[I, O]" -func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) (string, string) { - if len(it.TypeParams) == 0 { - return "", "" - } - var long, short strings.Builder - long.WriteString("[") - short.WriteString("[") - for i, v := range it.TypeParams { - if i != 0 { - long.WriteString(", ") - short.WriteString(", ") - } - long.WriteString(v.Name) - short.WriteString(v.Name) - long.WriteString(fmt.Sprintf(" %s", v.Type.String(g.packageMap, pkgOverride))) - } - long.WriteString("]") - short.WriteString("]") - return long.String(), short.String() -} - -func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { - mockType := g.mockName(intf.Name) - longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath) - - g.p("") - g.p("// %v is a mock of %v interface.", mockType, intf.Name) - g.p("type %v%v struct {", mockType, longTp) - g.in() - g.p("ctrl *gomock.Controller") - g.p("recorder *%vMockRecorder%v", mockType, shortTp) - g.out() - g.p("}") - g.p("") - - g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) - g.p("type %vMockRecorder%v struct {", mockType, longTp) - g.in() - g.p("mock *%v%v", mockType, shortTp) - g.out() - g.p("}") - g.p("") - - g.p("// New%v creates a new mock instance.", mockType) - g.p("func New%v%v(ctrl *gomock.Controller) *%v%v {", mockType, longTp, mockType, shortTp) - g.in() - g.p("mock := &%v%v{ctrl: ctrl}", mockType, shortTp) - g.p("mock.recorder = &%vMockRecorder%v{mock}", mockType, shortTp) - g.p("return mock") - g.out() - g.p("}") - g.p("") - - // XXX: possible name collision here if someone has EXPECT in their interface. - g.p("// EXPECT returns an object that allows the caller to indicate expected use.") - g.p("func (m *%v%v) EXPECT() *%vMockRecorder%v {", mockType, shortTp, mockType, shortTp) - g.in() - g.p("return m.recorder") - g.out() - g.p("}") - - // XXX: possible name collision here if someone has ISGOMOCK in their interface. - g.p("// ISGOMOCK indicates that this struct is a gomock mock.") - g.p("func (m *%v%v) ISGOMOCK() struct{} {", mockType, shortTp) - g.in() - g.p("return struct{}{}") - g.out() - g.p("}") - - g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, *typed) - - return nil -} - -type byMethodName []*model.Method - -func (b byMethodName) Len() int { return len(b) } -func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } - -func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride, longTp, shortTp string, typed bool) { - sort.Sort(byMethodName(intf.Methods)) - for _, m := range intf.Methods { - g.p("") - _ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp) - g.p("") - _ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed) - if typed { - g.p("") - _ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp) - } - } -} - -func makeArgString(argNames, argTypes []string) string { - args := make([]string, len(argNames)) - for i, name := range argNames { - // specify the type only once for consecutive args of the same type - if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] { - args[i] = name - } else { - args[i] = name + " " + argTypes[i] - } - } - return strings.Join(args, ", ") -} - -// GenerateMockMethod generates a mock method implementation. -// If non-empty, pkgOverride is the package in which unqualified types reside. -func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride, shortTp string) error { - argNames := g.getArgNames(m, true /* in */) - argTypes := g.getArgTypes(m, pkgOverride, true /* in */) - argString := makeArgString(argNames, argTypes) - - rets := make([]string, len(m.Out)) - for i, p := range m.Out { - rets[i] = p.Type.String(g.packageMap, pkgOverride) - } - retString := strings.Join(rets, ", ") - if len(rets) > 1 { - retString = "(" + retString + ")" - } - if retString != "" { - retString = " " + retString - } - - ia := newIdentifierAllocator(argNames) - idRecv := ia.allocateIdentifier("m") - - g.p("// %v mocks base method.", m.Name) - g.p("func (%v *%v%v) %v(%v)%v {", idRecv, mockType, shortTp, m.Name, argString, retString) - g.in() - g.p("%s.ctrl.T.Helper()", idRecv) - - var callArgs string - if m.Variadic == nil { - if len(argNames) > 0 { - callArgs = ", " + strings.Join(argNames, ", ") - } - } else { - // Non-trivial. The generated code must build a []any, - // but the variadic argument may be any type. - idVarArgs := ia.allocateIdentifier("varargs") - idVArg := ia.allocateIdentifier("a") - g.p("%s := []any{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) - g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) - g.in() - g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) - g.out() - g.p("}") - callArgs = ", " + idVarArgs + "..." - } - if len(m.Out) == 0 { - g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) - } else { - idRet := ia.allocateIdentifier("ret") - g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) - - // Go does not allow "naked" type assertions on nil values, so we use the two-value form here. - // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. - // Happily, this coincides with the semantics we want here. - retNames := make([]string, len(rets)) - for i, t := range rets { - retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) - g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) - } - g.p("return " + strings.Join(retNames, ", ")) - } - - g.out() - g.p("}") - return nil -} - -func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error { - mockType := g.mockName(intf.Name) - argNames := g.getArgNames(m, true) - - var argString string - if m.Variadic == nil { - argString = strings.Join(argNames, ", ") - } else { - argString = strings.Join(argNames[:len(argNames)-1], ", ") - } - if argString != "" { - argString += " any" - } - - if m.Variadic != nil { - if argString != "" { - argString += ", " - } - argString += fmt.Sprintf("%s ...any", argNames[len(argNames)-1]) - } - - ia := newIdentifierAllocator(argNames) - idRecv := ia.allocateIdentifier("mr") - - g.p("// %v indicates an expected call of %v.", m.Name, m.Name) - if typed { - g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp) - } else { - g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString) - } - - g.in() - g.p("%s.mock.ctrl.T.Helper()", idRecv) - - var callArgs string - if m.Variadic == nil { - if len(argNames) > 0 { - callArgs = ", " + strings.Join(argNames, ", ") - } - } else { - if len(argNames) == 1 { - // Easy: just use ... to push the arguments through. - callArgs = ", " + argNames[0] + "..." - } else { - // Hard: create a temporary slice. - idVarArgs := ia.allocateIdentifier("varargs") - g.p("%s := append([]any{%s}, %s...)", - idVarArgs, - strings.Join(argNames[:len(argNames)-1], ", "), - argNames[len(argNames)-1]) - callArgs = ", " + idVarArgs + "..." - } - } - if typed { - g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) - g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp) - } else { - g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) - } - - g.out() - g.p("}") - return nil -} - -func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error { - mockType := g.mockName(intf.Name) - argNames := g.getArgNames(m, true /* in */) - retNames := g.getArgNames(m, false /* out */) - argTypes := g.getArgTypes(m, pkgOverride, true /* in */) - retTypes := g.getArgTypes(m, pkgOverride, false /* out */) - argString := strings.Join(argTypes, ", ") - - rets := make([]string, len(m.Out)) - for i, p := range m.Out { - rets[i] = p.Type.String(g.packageMap, pkgOverride) - } - - var retString string - switch { - case len(rets) == 1: - retString = " " + rets[0] - case len(rets) > 1: - retString = " (" + strings.Join(rets, ", ") + ")" - } - - ia := newIdentifierAllocator(argNames) - idRecv := ia.allocateIdentifier("c") - - recvStructName := mockType + m.Name - - g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name) - g.p("type %s%sCall%s struct{", mockType, m.Name, longTp) - g.in() - g.p("*gomock.Call") - g.out() - g.p("}") - - g.p("// Return rewrite *gomock.Call.Return") - g.p("func (%s *%sCall%s) Return(%v) *%sCall%s {", idRecv, recvStructName, shortTp, makeArgString(retNames, retTypes), recvStructName, shortTp) - g.in() - var retArgs string - if len(retNames) > 0 { - retArgs = strings.Join(retNames, ", ") - } - g.p(`%s.Call = %v.Call.Return(%v)`, idRecv, idRecv, retArgs) - g.p("return %s", idRecv) - g.out() - g.p("}") - - g.p("// Do rewrite *gomock.Call.Do") - g.p("func (%s *%sCall%s) Do(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) - g.in() - g.p(`%s.Call = %v.Call.Do(f)`, idRecv, idRecv) - g.p("return %s", idRecv) - g.out() - g.p("}") - - g.p("// DoAndReturn rewrite *gomock.Call.DoAndReturn") - g.p("func (%s *%sCall%s) DoAndReturn(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) - g.in() - g.p(`%s.Call = %v.Call.DoAndReturn(f)`, idRecv, idRecv) - g.p("return %s", idRecv) - g.out() - g.p("}") - return nil -} - -func (g *generator) getArgNames(m *model.Method, in bool) []string { - var params []*model.Parameter - if in { - params = m.In - } else { - params = m.Out - } - argNames := make([]string, len(params)) - for i, p := range params { - name := p.Name - if name == "" || name == "_" { - name = fmt.Sprintf("arg%d", i) - } - argNames[i] = name - } - if m.Variadic != nil && in { - name := m.Variadic.Name - if name == "" { - name = fmt.Sprintf("arg%d", len(params)) - } - argNames = append(argNames, name) - } - return argNames -} - -func (g *generator) getArgTypes(m *model.Method, pkgOverride string, in bool) []string { - var params []*model.Parameter - if in { - params = m.In - } else { - params = m.Out - } - argTypes := make([]string, len(params)) - for i, p := range params { - argTypes[i] = p.Type.String(g.packageMap, pkgOverride) - } - if m.Variadic != nil { - argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) - } - return argTypes -} - -type identifierAllocator map[string]struct{} - -func newIdentifierAllocator(taken []string) identifierAllocator { - a := make(identifierAllocator, len(taken)) - for _, s := range taken { - a[s] = struct{}{} - } - return a -} - -func (o identifierAllocator) allocateIdentifier(want string) string { - id := want - for i := 2; ; i++ { - if _, ok := o[id]; !ok { - o[id] = struct{}{} - return id - } - id = want + "_" + strconv.Itoa(i) - } -} - -// Output returns the generator's output, formatted in the standard Go style. -func (g *generator) Output() []byte { - src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil) - if err != nil { - log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) - } - return src -} - -// createPackageMap returns a map of import path to package name -// for specified importPaths. -func createPackageMap(importPaths []string) map[string]string { - var pkg struct { - Name string - ImportPath string - } - pkgMap := make(map[string]string) - b := bytes.NewBuffer(nil) - args := []string{"list", "-json"} - args = append(args, importPaths...) - cmd := exec.Command("go", args...) - cmd.Stdout = b - cmd.Run() - dec := json.NewDecoder(b) - for dec.More() { - err := dec.Decode(&pkg) - if err != nil { - log.Printf("failed to decode 'go list' output: %v", err) - continue - } - pkgMap[pkg.ImportPath] = pkg.Name - } - return pkgMap -} - -func printVersion() { - if version != "" { - fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date) - } else { - printModuleVersion() - } -} - -// parseImportPackage get package import path via source file -// an alternative implementation is to use: -// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} -// pkgs, err := packages.Load(cfg, "file="+source) -// However, it will call "go list" and slow down the performance -func parsePackageImport(srcDir string) (string, error) { - moduleMode := os.Getenv("GO111MODULE") - // trying to find the module - if moduleMode != "off" { - currentDir := srcDir - for { - dat, err := os.ReadFile(filepath.Join(currentDir, "go.mod")) - if os.IsNotExist(err) { - if currentDir == filepath.Dir(currentDir) { - // at the root - break - } - currentDir = filepath.Dir(currentDir) - continue - } else if err != nil { - return "", err - } - modulePath := modfile.ModulePath(dat) - return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil - } - } - // fall back to GOPATH mode - goPaths := os.Getenv("GOPATH") - if goPaths == "" { - return "", fmt.Errorf("GOPATH is not set") - } - goPathList := strings.Split(goPaths, string(os.PathListSeparator)) - for _, goPath := range goPathList { - sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) - if strings.HasPrefix(srcDir, sourceRoot) { - return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil - } - } - return "", errOutsideGoPath -}