Merge remote-tracking branch 'origin/master' into allow-error-string-from-custom-matcher
diff --git a/mock/mock.go b/mock/mock.go
index d6694ed..ce446db 100644
--- a/mock/mock.go
+++ b/mock/mock.go
@@ -557,6 +557,10 @@
 	Anything = "mock.Anything"
 )
 
+// errorType is the reflect.Type of the built-in error interface, used to
+// recognise matcher functions that report mismatches by returning an error.
+var errorType = reflect.TypeOf((*error)(nil)).Elem()
+
 // AnythingOfTypeArgument is a string that contains the type of an argument
 // for use when type checking.  Used in Diff and Assert.
 type AnythingOfTypeArgument string
@@ -578,6 +582,15 @@
 }
 
+// Matches reports whether argument satisfies the matcher function.
 func (f argumentMatcher) Matches(argument interface{}) bool {
+	return f.match(argument) == nil
+}
+
+// match calls the matcher function with argument and returns nil on a match.
+// On a mismatch it returns either the matcher's own non-nil error, a
+// "not matched by" error when a bool matcher returns false, or an
+// "unexpected type for" error when argument's type is not accepted by fn.
+func (f argumentMatcher) match(argument interface{}) error {
 	expectType := f.fn.Type().In(0)
 	expectTypeNilSupported := false
 	switch expectType.Kind() {
@@ -598,25 +611,63 @@
 	}
 	if argType == nil || argType.AssignableTo(expectType) {
 		result := f.fn.Call([]reflect.Value{arg})
-		return result[0].Bool()
+
+		var matchError error
+		switch {
+		case result[0].Type().Kind() == reflect.Bool:
+			if !result[0].Bool() {
+				matchError = fmt.Errorf("not matched by %s", f)
+			}
+		case result[0].Type().Implements(errorType):
+			switch v := result[0]; v.Kind() {
+			case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
+				// A nil value of a nillable kind means the matcher accepted the argument.
+				if !v.IsNil() {
+					matchError = v.Interface().(error)
+				}
+			default:
+				// Non-nillable kinds (e.g. a struct implementing error by value) are
+				// always a mismatch; calling IsNil on them would panic.
+				matchError = v.Interface().(error)
+			}
+		default:
+			panic(fmt.Errorf("matcher function of unknown type: %s", result[0].Type().Kind()))
+		}
+
+		return matchError
 	}
-	return false
+	return fmt.Errorf("unexpected type for %s", f)
 }
 
+// String describes the matcher by its function signature, e.g. "func(int) bool".
 func (f argumentMatcher) String() string {
-	return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
+	return fmt.Sprintf("func(%s) %s", f.fn.Type().In(0).String(), f.fn.Type().Out(0).String())
+}
+
+// GoString implements fmt.GoStringer so that %#v renders the matcher as
+// MatchedBy(<signature>) instead of dumping its internal fields.
+func (f argumentMatcher) GoString() string {
+	return fmt.Sprintf("MatchedBy(%s)", f)
 }
 
 // MatchedBy can be used to match a mock call based on only certain properties
 // from a complex struct or some calculation. It takes a function that will be
-// evaluated with the called argument and will return true when there's a match
-// and false otherwise.
+// evaluated with the called argument and will return either a boolean (true
+// when there's a match and false otherwise) or an error (nil when there's a
+// match and a non-nil error describing the failure otherwise).
 //
-// Example:
-// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
+// Examples:
+//  m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
+//
+//  m.On("Do", MatchedBy(func(req *http.Request) (err error) {
+//  	if req.Host != "example.com" {
+//  		err = errors.New("host was not example.com")
+//  	}
+//  	return
+//  }))
 //
 // |fn|, must be a function accepting a single argument (of the expected type)
-// which returns a bool. If |fn| doesn't match the required signature,
+// which returns a bool or error. If |fn| doesn't match the required signature,
 // MatchedBy() panics.
 func MatchedBy(fn interface{}) argumentMatcher {
 	fnType := reflect.TypeOf(fn)
@@ -627,8 +678,8 @@
 	if fnType.NumIn() != 1 {
 		panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
 	}
-	if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
-		panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
+	if fnType.NumOut() != 1 || (fnType.Out(0).Kind() != reflect.Bool && !fnType.Out(0).Implements(errorType)) {
+		panic(fmt.Sprintf("assert: arguments: %s does not return a bool or an error", fn))
 	}
 
 	return argumentMatcher{fn: reflect.ValueOf(fn)}
@@ -688,11 +739,11 @@
 		}
 
 		if matcher, ok := expected.(argumentMatcher); ok {
-			if matcher.Matches(actual) {
+			if matchError := matcher.match(actual); matchError == nil {
 				output = fmt.Sprintf("%s\t%d: PASS:  %s matched by %s\n", output, i, actualFmt, matcher)
 			} else {
 				differences++
-				output = fmt.Sprintf("%s\t%d: FAIL:  %s not matched by %s\n", output, i, actualFmt, matcher)
+				output = fmt.Sprintf("%s\t%d: FAIL:  %s %s\n", output, i, actualFmt, matchError)
 			}
 		} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
 
diff --git a/mock/mock_test.go b/mock/mock_test.go
index 2608f5a..9fb14a3 100644
--- a/mock/mock_test.go
+++ b/mock/mock_test.go
@@ -1259,7 +1259,7 @@
 
 	diff, count = args.Diff([]interface{}{"string", false, true})
 	assert.Equal(t, 1, count)
-	assert.Contains(t, diff, `(bool=false) not matched by func(int) bool`)
+	assert.Contains(t, diff, `(bool=false) unexpected type for func(int) bool`)
 
 	diff, count = args.Diff([]interface{}{"string", 123, false})
 	assert.Contains(t, diff, `(int=123) matched by func(int) bool`)
@@ -1269,6 +1269,49 @@
 	assert.Contains(t, diff, `No differences.`)
 }
 
+func Test_Arguments_Diff_WithArgMatcherReturningError(t *testing.T) {
+	matchFn := func(a int) (err error) {
+		if a != 123 {
+			err = errors.New("did not match")
+		}
+		return
+	}
+	var args = Arguments([]interface{}{"string", MatchedBy(matchFn), true})
+
+	diff, count := args.Diff([]interface{}{"string", 124, true})
+	assert.Equal(t, 1, count)
+	assert.Contains(t, diff, `(int=124) did not match`)
+
+	diff, count = args.Diff([]interface{}{"string", false, true})
+	assert.Equal(t, 1, count)
+	assert.Contains(t, diff, `(bool=false) unexpected type for func(int) error`)
+
+	diff, count = args.Diff([]interface{}{"string", 123, false})
+	assert.Equal(t, 1, count)
+	assert.Contains(t, diff, `(int=123) matched by func(int) error`)
+
+	diff, count = args.Diff([]interface{}{"string", 123, true})
+	assert.Equal(t, 0, count)
+	assert.Contains(t, diff, `No differences.`)
+}
+
+// structValueMatchError implements error on a non-pointer struct value, so a
+// matcher returning it can never signal a match via a nil result.
+type structValueMatchError struct{ msg string }
+
+func (e structValueMatchError) Error() string { return e.msg }
+
+func Test_Arguments_Diff_WithArgMatcherReturningStructError(t *testing.T) {
+	matchFn := func(int) structValueMatchError {
+		return structValueMatchError{msg: "struct error"}
+	}
+	var args = Arguments([]interface{}{MatchedBy(matchFn)})
+
+	diff, count := args.Diff([]interface{}{123})
+	assert.Equal(t, 1, count)
+	assert.Contains(t, diff, `(int=123) struct error`)
+}
+
 func Test_Arguments_Assert(t *testing.T) {
 
 	var args = Arguments([]interface{}{"string", 123, true})
@@ -1445,7 +1488,7 @@
 	defer func() {
 		if r := recover(); r != nil {
 			matchingExp := regexp.MustCompile(
-				`\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`)
+				`\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: MatchedBy\(func\(int\) bool\)\s+Diff:.*\(int=1\) not matched by func\(int\) bool`)
 			assert.Regexp(t, matchingExp, r)
 		}
 	}()