Merge remote-tracking branch 'origin/master' into allow-error-string-from-custom-matcher
diff --git a/mock/mock.go b/mock/mock.go
index d6694ed..ce446db 100644
--- a/mock/mock.go
+++ b/mock/mock.go
@@ -557,6 +557,10 @@
Anything = "mock.Anything"
)
+var (
+ errorType = reflect.TypeOf(new(error)).Elem() // reflect.Type of the built-in error interface; MatchedBy accepts matchers whose result implements it
+)
+
// AnythingOfTypeArgument is a string that contains the type of an argument
// for use when type checking. Used in Diff and Assert.
type AnythingOfTypeArgument string
@@ -578,6 +582,10 @@
}
func (f argumentMatcher) Matches(argument interface{}) bool {
+ return f.match(argument) == nil // a nil error from match means the argument was accepted
+}
+// match applies the wrapped matcher to argument, returning nil on a match or an error describing why it did not match.
+func (f argumentMatcher) match(argument interface{}) error {
expectType := f.fn.Type().In(0)
expectTypeNilSupported := false
switch expectType.Kind() {
@@ -598,25 +606,52 @@
}
if argType == nil || argType.AssignableTo(expectType) {
result := f.fn.Call([]reflect.Value{arg})
- return result[0].Bool()
+ // MatchedBy only admits matchers returning a bool-kinded type or a type implementing error.
+ var matchError error
+ switch {
+ case result[0].Type().Kind() == reflect.Bool:
+ if !result[0].Bool() {
+ matchError = fmt.Errorf("not matched by %s", f)
+ }
+ case result[0].Type().Implements(errorType): // IsNil panics on non-nillable kinds (e.g. a struct error type), so only nil-check kinds that can hold nil
+ if k := result[0].Kind(); !(k == reflect.Ptr || k == reflect.Interface || k == reflect.Map || k == reflect.Slice || k == reflect.Func || k == reflect.Chan || k == reflect.UnsafePointer) || !result[0].IsNil() {
+ matchError = result[0].Interface().(error)
+ }
+ default:
+ panic(fmt.Errorf("matcher function of unknown type: %s", result[0].Type().Kind()))
+ }
+
+ return matchError
}
- return false
+ return fmt.Errorf("unexpected type for %s", f)
}
func (f argumentMatcher) String() string {
- return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
+ return fmt.Sprintf("func(%s) %s", f.fn.Type().In(0).String(), f.fn.Type().Out(0).String())
+}
+// GoString renders the matcher as MatchedBy(func(T) R); fmt uses it for the %#v verb.
+func (f argumentMatcher) GoString() string {
+ return fmt.Sprintf("MatchedBy(%s)", f)
}
// MatchedBy can be used to match a mock call based on only certain properties
// from a complex struct or some calculation. It takes a function that will be
-// evaluated with the called argument and will return true when there's a match
-// and false otherwise.
+// evaluated with the called argument and will return either a boolean (true
+// when there's a match and false otherwise) or an error (nil when there's a
+// match and a non-nil error holding the failure message otherwise).
//
-// Example:
-// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
+// Examples:
+// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
+//
+// m.On("Do", MatchedBy(func(req *http.Request) (err error) {
+// if req.Host != "example.com" {
+// err = errors.New("host was not example.com")
+// }
+// return
+// }))
//
// |fn|, must be a function accepting a single argument (of the expected type)
-// which returns a bool. If |fn| doesn't match the required signature,
+// which returns a bool or an error. If |fn| doesn't match the required signature,
// MatchedBy() panics.
func MatchedBy(fn interface{}) argumentMatcher {
fnType := reflect.TypeOf(fn)
@@ -627,8 +662,9 @@
if fnType.NumIn() != 1 {
panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
}
- if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
- panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
+ // Exactly one result is allowed: a bool-kinded type or any type implementing error.
+ if fnType.NumOut() != 1 || (fnType.Out(0).Kind() != reflect.Bool && !fnType.Out(0).Implements(errorType)) {
+ panic(fmt.Sprintf("assert: arguments: %s does not return a bool or an error", fn))
}
return argumentMatcher{fn: reflect.ValueOf(fn)}
@@ -688,11 +724,11 @@
}
if matcher, ok := expected.(argumentMatcher); ok {
- if matcher.Matches(actual) {
+ if matchError := matcher.match(actual); matchError == nil { // call match (not Matches) so the failure reason can be reported below
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
- output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
+ output = fmt.Sprintf("%s\t%d: FAIL: %s %s\n", output, i, actualFmt, matchError) // matchError is "not matched by ...", "unexpected type for ...", or the matcher's own error text
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
diff --git a/mock/mock_test.go b/mock/mock_test.go
index 2608f5a..9fb14a3 100644
--- a/mock/mock_test.go
+++ b/mock/mock_test.go
@@ -1259,7 +1259,7 @@
diff, count = args.Diff([]interface{}{"string", false, true})
assert.Equal(t, 1, count)
- assert.Contains(t, diff, `(bool=false) not matched by func(int) bool`)
+ assert.Contains(t, diff, `(bool=false) unexpected type for func(int) bool`) // a bool is not assignable to the matcher's int parameter, so match reports a type error without calling it
diff, count = args.Diff([]interface{}{"string", 123, false})
assert.Contains(t, diff, `(int=123) matched by func(int) bool`)
@@ -1269,6 +1269,31 @@
assert.Contains(t, diff, `No differences.`)
}
+func Test_Arguments_Diff_WithArgMatcherReturningError(t *testing.T) {
+ matchFn := func(a int) (err error) {
+ if a != 123 {
+ err = errors.New("did not match")
+ }
+ return
+ }
+ var args = Arguments([]interface{}{"string", MatchedBy(matchFn), true})
+ // A mismatch reported by an error-returning matcher surfaces the matcher's own message.
+ diff, count := args.Diff([]interface{}{"string", 124, true})
+ assert.Equal(t, 1, count)
+ assert.Contains(t, diff, `(int=124) did not match`)
+ // A wrongly-typed argument is rejected before the matcher runs.
+ diff, count = args.Diff([]interface{}{"string", false, true})
+ assert.Equal(t, 1, count)
+ assert.Contains(t, diff, `(bool=false) unexpected type for func(int) error`)
+ // The matcher passes; only the trailing bool differs, so exactly one difference is counted.
+ diff, count = args.Diff([]interface{}{"string", 123, false})
+ assert.Equal(t, 1, count)
+ assert.Contains(t, diff, `(int=123) matched by func(int) error`)
+ diff, count = args.Diff([]interface{}{"string", 123, true})
+ assert.Equal(t, 0, count)
+ assert.Contains(t, diff, `No differences.`)
+}
+
func Test_Arguments_Assert(t *testing.T) {
var args = Arguments([]interface{}{"string", 123, true})
@@ -1445,7 +1470,7 @@
defer func() {
if r := recover(); r != nil {
matchingExp := regexp.MustCompile(
- `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`)
+ `\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: MatchedBy\(func\(int\) bool\)\s+Diff:.*\(int=1\) not matched by func\(int\) bool`) // the expected argument rendering matches argumentMatcher.GoString's MatchedBy(...) form
assert.Regexp(t, matchingExp, r)
}
}()