Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a replace-type parameter to allow replacing package and/or type names #540

Merged
merged 15 commits into from
Mar 20, 2023
1 change: 1 addition & 0 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func NewRootCmd() *cobra.Command {
pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.")
pFlags.Bool("exported", false, "Generates public mocks for private interfaces.")
pFlags.Bool("with-expecter", false, "Generate expecter utility around mock's On, Run and Return methods with explicit types. This option is NOT compatible with -unroll-variadic=false")
pFlags.StringArray("replace-type", nil, "Replace types")

viperCfg.BindPFlags(pFlags)

Expand Down
3 changes: 2 additions & 1 deletion docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ Parameter Descriptions
| `exported` | Use `exported: True` to generate public mocks for private interfaces. |
| `with-expecter` | Use `with-expecter: True` to generate `EXPECT()` methods for your mocks. This is the preferred way to setup your mocks. |
| `testonly` | Prepend every mock file with `_test.go`. This is useful in cases where you are generating mocks `inpackage` but don't want the mocks to be visible to code outside of tests. |
| `inpackage-suffix` | When `inpackage-suffix` is set to `True`, mock files are suffixed with `_mock` instead of being prefixed with `mock_` for InPackage mocks |
| `inpackage-suffix` | When `inpackage-suffix` is set to `True`, mock files are suffixed with `_mock` instead of being prefixed with `mock_` for InPackage mocks |
| `replace-type source=destination` | Replaces aliases, packages and/or types during generation.|
81 changes: 81 additions & 0 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,84 @@ Return(
},
)
```

Replace Types
-------------

The `replace-type` parameter allows adding a list of type replacements to be made in package and/or type names.
This can help overcome some parsing problems like type aliases that the Go parser doesn't provide enough information.

This parameter can be specified multiple times.

```shell
mockery --replace-type github.com/vektra/mockery/v2/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/baz.Baz
```

This will replace any imported named `"github.com/vektra/mockery/v2/baz/internal/foo"`
with `baz "github.com/vektra/mockery/v2/baz"`. The alias is defined with `:` before
the package name. Also, the `InternalBaz` type that comes from this package will be renamed to `baz.Baz`.

This next example fixes a common problem of type aliases that point to an internal package.

`cloud.google.com/go/pubsub.Message` is a type alias defined like this:

```go
import (
ipubsub "cloud.google.com/go/internal/pubsub"
)

type Message = ipubsub.Message
```

The Go parser that mockery uses doesn't provide a way to detect this alias and sends the application the package and
type name of the type in the internal package, which will not work.

We can use "replace-type" with only the package part to replace any import of `cloud.google.com/go/internal/pubsub` to
`cloud.google.com/go/pubsub`. We don't need to change the alias or type name in this case, because they are `pubsub`
and `Message` in both cases.

```shell
mockery --replace-type cloud.google.com/go/internal/pubsub=cloud.google.com/go/pubsub
```

Original source:

```go
import (
"cloud.google.com/go/pubsub"
)

type Handler struct {
HandleMessage(m pubsub.Message) error
}
```

Invalid mock generated without this parameter (points to an `internal` folder):

```go
import (
mock "github.com/stretchr/testify/mock"

pubsub "cloud.google.com/go/internal/pubsub"
)

func (_m *Handler) HandleMessage(m pubsub.Message) error {
// ...
return nil
}
```

Correct mock generated with this parameter.

```go
import (
mock "github.com/stretchr/testify/mock"

pubsub "cloud.google.com/go/pubsub"
)

func (_m *Handler) HandleMessage(m pubsub.Message) error {
// ...
return nil
}
```
11 changes: 6 additions & 5 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ type Config struct {
BoilerplateFile string `mapstructure:"boilerplate-file"`
// StructName overrides the name given to the mock struct and should only be nonempty
// when generating for an exact match (non regex expression in -name).
StructName string `mapstructure:"structname"`
TestOnly bool `mapstructure:"testonly"`
UnrollVariadic bool `mapstructure:"unroll-variadic"`
Version bool `mapstructure:"version"`
WithExpecter bool `mapstructure:"with-expecter"`
StructName string `mapstructure:"structname"`
TestOnly bool `mapstructure:"testonly"`
UnrollVariadic bool `mapstructure:"unroll-variadic"`
Version bool `mapstructure:"version"`
WithExpecter bool `mapstructure:"with-expecter"`
ReplaceType []string `mapstructure:"replace-type"`

// Viper throws away case-sensitivity when it marshals into this struct. This
// destroys necessary information we need, specifically around interface names.
Expand Down
12 changes: 12 additions & 0 deletions pkg/fixtures/example_project/baz/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package baz

import (
ifoo "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo"
)

type Baz = ifoo.InternalBaz

type Foo interface {
DoFoo() string
GetBaz() (*Baz, error)
}
6 changes: 6 additions & 0 deletions pkg/fixtures/example_project/baz/internal/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package foo

type InternalBaz struct {
One string
Two int
}
78 changes: 77 additions & 1 deletion pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type GeneratorConfig struct {
StructName string
UnrollVariadic bool
WithExpecter bool
ReplaceType []string
}

// Generator is responsible for generating the string containing
Expand All @@ -76,6 +77,7 @@ type Generator struct {
localizationCache map[string]string
packagePathToName map[string]string
nameToPackagePath map[string]string
replaceTypeCache []*replaceTypeItem
}

// NewGenerator builds a Generator.
Expand All @@ -99,6 +101,7 @@ func NewGenerator(ctx context.Context, c GeneratorConfig, iface *Interface, pkg
nameToPackagePath: make(map[string]string),
}

g.parseReplaceTypes(ctx)
g.addPackageImportWithName(ctx, "github.com/stretchr/testify/mock", "mock")

return g
Expand Down Expand Up @@ -157,14 +160,42 @@ func (g *Generator) getPackageScopedType(ctx context.Context, o *types.TypeName)
(!g.config.KeepTree && g.config.InPackage && o.Pkg() == g.iface.Pkg) {
return o.Name()
}
return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
pkg := g.addPackageImport(ctx, o.Pkg())
name := o.Name()
g.checkReplaceType(ctx, func(from *replaceType, to *replaceType) bool {
if o.Pkg().Path() == from.pkg && name == from.typ {
name = to.typ
return false
}
return true
})
return pkg + "." + name
}

func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string {
return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name())
}

func (g *Generator) checkReplaceType(ctx context.Context, f func(from *replaceType, to *replaceType) bool) {
for _, item := range g.replaceTypeCache {
if !f(item.from, item.to) {
break
}
}
}

func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string {
g.checkReplaceType(ctx, func(from *replaceType, to *replaceType) bool {
if path == from.pkg {
path = to.pkg
if to.alias != "" {
name = to.alias
}
return false
}
return true
})

if existingName, pathExists := g.packagePathToName[path]; pathExists {
return existingName
}
Expand All @@ -175,6 +206,22 @@ func (g *Generator) addPackageImportWithName(ctx context.Context, path, name str
return nonConflictingName
}

func (g *Generator) parseReplaceTypes(ctx context.Context) {
for _, replace := range g.config.ReplaceType {
r := strings.SplitN(replace, "=", 2)
if len(r) != 2 {
log := zerolog.Ctx(ctx)
log.Error().Msgf("invalid replace type value: %s", replace)
continue
}

g.replaceTypeCache = append(g.replaceTypeCache, &replaceTypeItem{
from: parseReplaceType(r[0]),
to: parseReplaceType(r[1]),
})
}
}

func (g *Generator) getNonConflictingName(path, name string) string {
if !g.importNameExists(name) && (!g.config.InPackage || g.iface.Pkg.Name() != name) {
// do not allow imports with the same name as the package when inPackage
Expand Down Expand Up @@ -968,3 +1015,32 @@ func resolveCollision(names []string, variable string) string {

return ret
}

type replaceType struct {
alias string
pkg string
typ string
}

type replaceTypeItem struct {
from *replaceType
to *replaceType
}

func parseReplaceType(t string) *replaceType {
ret := &replaceType{}
r := strings.SplitN(t, ":", 2)
if len(r) > 1 {
ret.alias = r[0]
t = r[1]
}
lastDot := strings.LastIndex(t, ".")
lastSlash := strings.LastIndex(t, "/")
if lastDot == -1 || (lastSlash > -1 && lastDot < lastSlash) {
ret.pkg = t
} else {
ret.pkg = t[:lastDot]
ret.typ = t[lastDot+1:]
}
return ret
}
89 changes: 89 additions & 0 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -85,6 +86,37 @@ func (s *GeneratorSuite) checkGenerationWithConfig(
return generator
}

type regexpExpected struct {
shouldMatch bool
re *regexp.Regexp
}

func (s *GeneratorSuite) checkGenerationRegexWithConfig(
filepath, interfaceName string, cfg GeneratorConfig, expected []regexpExpected,
) *Generator {
generator := s.getGeneratorWithConfig(filepath, interfaceName, cfg)
err := generator.Generate(s.ctx)
s.NoError(err, "The generator ran without errors.")
if err != nil {
return generator
}
// Mirror the formatting done by normally done by golang.org/x/tools/imports in Generator.Write.
//
// While we could possibly reuse Generator.Write here in addition to Generator.Generate,
// it would require changing Write's signature to accept custom options, specifically to
// allow the fragments in preexisting cases. It's assumed that this approximation,
// just formatting the source, is sufficient for the needs of the current test styles.
var actual []byte
actual, fmtErr := format.Source(generator.buf.Bytes())
s.NoError(fmtErr, "The formatter ran without errors.")

for _, re := range expected {
s.Equalf(re.shouldMatch, re.re.Match(actual), "match '%s' should be %t", re.re.String(), re.shouldMatch)
}

return generator
}

func (s *GeneratorSuite) getGenerator(
filepath, interfaceName string, inPackage bool, structName string,
) *Generator {
Expand Down Expand Up @@ -2424,6 +2456,38 @@ import mock "github.com/stretchr/testify/mock"
s.checkPrologueGeneration(generator, expected)
}

func (s *GeneratorSuite) TestReplaceTypePackagePrologue() {
expected := `package mocks

import baz "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz"
import mock "github.com/stretchr/testify/mock"

`
generator := NewGenerator(
s.ctx,
GeneratorConfig{InPackage: false, ReplaceType: []string{
"github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz",
}},
s.getInterfaceFromFile("example_project/baz/foo.go", "Foo"),
pkg,
)

s.checkPrologueGeneration(generator, expected)
}

func (s *GeneratorSuite) TestReplaceTypePackage() {
cfg := GeneratorConfig{InPackage: false, ReplaceType: []string{
"github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz",
}}

s.checkGenerationRegexWithConfig("example_project/baz/foo.go", "Foo", cfg, []regexpExpected{
// func (_m *Foo) GetBaz() (*baz.Baz, error)
{true, regexp.MustCompile(`func \([^\)]+\) GetBaz\(\) \(\*baz\.Baz`)},
// func (_m *Foo) GetBaz() (*foo.InternalBaz, error)
{false, regexp.MustCompile(`func \([^\)]+\) GetBaz\(\) \(\*foo\.InternalBaz`)},
})
}

func (s *GeneratorSuite) TestGenericGenerator() {
expected := `// RequesterGenerics is an autogenerated mock type for the RequesterGenerics type
type RequesterGenerics[TAny interface{}, TComparable comparable, TSigned constraints.Signed, TIntf test.GetInt, TExternalIntf io.Writer, TGenIntf test.GetGeneric[TSigned], TInlineType interface{ ~int | ~uint }, TInlineTypeGeneric interface {
Expand Down Expand Up @@ -2812,3 +2876,28 @@ func TestGeneratorSuite(t *testing.T) {
generatorSuite := new(GeneratorSuite)
suite.Run(t, generatorSuite)
}

func TestParseReplaceType(t *testing.T) {
tests := []struct {
value string
expected replaceType
}{
{
value: "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz",
expected: replaceType{alias: "", pkg: "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo", typ: "InternalBaz"},
},
{
value: "baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz",
expected: replaceType{alias: "baz", pkg: "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz", typ: "Baz"},
},
{
value: "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz",
expected: replaceType{alias: "", pkg: "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz", typ: ""},
},
}

for _, test := range tests {
actual := parseReplaceType(test.value)
assert.Equal(t, test.expected, *actual)
}
}
1 change: 1 addition & 0 deletions pkg/outputter.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ func (m *Outputter) Generate(ctx context.Context, iface *Interface) error {
StructName: interfaceConfig.StructName,
UnrollVariadic: interfaceConfig.UnrollVariadic,
WithExpecter: interfaceConfig.WithExpecter,
ReplaceType: interfaceConfig.ReplaceType,
}
generator := NewGenerator(ctx, g, iface, "")

Expand Down
Loading