| // Copyright 2011 Aaron Jacobs. All Rights Reserved. |
| // Author: [email protected] (Aaron Jacobs) |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package oglematchers |
| |
| import ( |
| "errors" |
| "fmt" |
| "math" |
| "reflect" |
| ) |
| |
| // Equals(x) returns a matcher that matches values v such that v and x are |
| // equivalent. This includes the case when the comparison v == x using Go's |
| // built-in comparison operator is legal (except for structs, which this |
| // matcher does not support), but for convenience the following rules also |
| // apply: |
| // |
| // * Type checking is done based on underlying types rather than actual |
| // types, so that e.g. two aliases for string can be compared: |
| // |
| // type stringAlias1 string |
| // type stringAlias2 string |
| // |
| // a := "taco" |
| // b := stringAlias1("taco") |
| // c := stringAlias2("taco") |
| // |
| // ExpectTrue(a == b) // Legal, passes |
| // ExpectTrue(b == c) // Illegal, doesn't compile |
| // |
| // ExpectThat(a, Equals(b)) // Passes |
| // ExpectThat(b, Equals(c)) // Passes |
| // |
| // * Values of numeric type are treated as if they were abstract numbers, and |
| // compared accordingly. Therefore Equals(17) will match int(17), |
| // int16(17), uint(17), float32(17), complex64(17), and so on. |
| // |
| // If you want a stricter matcher that contains no such cleverness, see |
| // IdenticalTo instead. |
| // |
| // Arrays are supported by this matcher, but do not participate in the |
| // exceptions above. Two arrays compared with this matcher must have identical |
| // types, and their element type must itself be comparable according to Go's == |
| // operator. |
| func Equals(x interface{}) Matcher { |
| v := reflect.ValueOf(x) |
| |
| // This matcher doesn't support structs. |
| if v.Kind() == reflect.Struct { |
| panic(fmt.Sprintf("oglematchers.Equals: unsupported kind %v", v.Kind())) |
| } |
| |
| // The == operator is not defined for non-nil slices. |
| if v.Kind() == reflect.Slice && v.Pointer() != uintptr(0) { |
| panic(fmt.Sprintf("oglematchers.Equals: non-nil slice")) |
| } |
| |
| return &equalsMatcher{v} |
| } |
| |
| type equalsMatcher struct { |
| expectedValue reflect.Value |
| } |
| |
| //////////////////////////////////////////////////////////////////////// |
| // Numeric types |
| //////////////////////////////////////////////////////////////////////// |
| |
| func isSignedInteger(v reflect.Value) bool { |
| k := v.Kind() |
| return k >= reflect.Int && k <= reflect.Int64 |
| } |
| |
| func isUnsignedInteger(v reflect.Value) bool { |
| k := v.Kind() |
| return k >= reflect.Uint && k <= reflect.Uintptr |
| } |
| |
| func isInteger(v reflect.Value) bool { |
| return isSignedInteger(v) || isUnsignedInteger(v) |
| } |
| |
| func isFloat(v reflect.Value) bool { |
| k := v.Kind() |
| return k == reflect.Float32 || k == reflect.Float64 |
| } |
| |
| func isComplex(v reflect.Value) bool { |
| k := v.Kind() |
| return k == reflect.Complex64 || k == reflect.Complex128 |
| } |
| |
| func checkAgainstInt64(e int64, c reflect.Value) (err error) { |
| err = errors.New("") |
| |
| switch { |
| case isSignedInteger(c): |
| if c.Int() == e { |
| err = nil |
| } |
| |
| case isUnsignedInteger(c): |
| u := c.Uint() |
| if u <= math.MaxInt64 && int64(u) == e { |
| err = nil |
| } |
| |
| // Turn around the various floating point types so that the checkAgainst* |
| // functions for them can deal with precision issues. |
| case isFloat(c), isComplex(c): |
| return Equals(c.Interface()).Matches(e) |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| func checkAgainstUint64(e uint64, c reflect.Value) (err error) { |
| err = errors.New("") |
| |
| switch { |
| case isSignedInteger(c): |
| i := c.Int() |
| if i >= 0 && uint64(i) == e { |
| err = nil |
| } |
| |
| case isUnsignedInteger(c): |
| if c.Uint() == e { |
| err = nil |
| } |
| |
| // Turn around the various floating point types so that the checkAgainst* |
| // functions for them can deal with precision issues. |
| case isFloat(c), isComplex(c): |
| return Equals(c.Interface()).Matches(e) |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| func checkAgainstFloat32(e float32, c reflect.Value) (err error) { |
| err = errors.New("") |
| |
| switch { |
| case isSignedInteger(c): |
| if float32(c.Int()) == e { |
| err = nil |
| } |
| |
| case isUnsignedInteger(c): |
| if float32(c.Uint()) == e { |
| err = nil |
| } |
| |
| case isFloat(c): |
| // Compare using float32 to avoid a false sense of precision; otherwise |
| // e.g. Equals(float32(0.1)) won't match float32(0.1). |
| if float32(c.Float()) == e { |
| err = nil |
| } |
| |
| case isComplex(c): |
| comp := c.Complex() |
| rl := real(comp) |
| im := imag(comp) |
| |
| // Compare using float32 to avoid a false sense of precision; otherwise |
| // e.g. Equals(float32(0.1)) won't match (0.1 + 0i). |
| if im == 0 && float32(rl) == e { |
| err = nil |
| } |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| func checkAgainstFloat64(e float64, c reflect.Value) (err error) { |
| err = errors.New("") |
| |
| ck := c.Kind() |
| |
| switch { |
| case isSignedInteger(c): |
| if float64(c.Int()) == e { |
| err = nil |
| } |
| |
| case isUnsignedInteger(c): |
| if float64(c.Uint()) == e { |
| err = nil |
| } |
| |
| // If the actual value is lower precision, turn the comparison around so we |
| // apply the low-precision rules. Otherwise, e.g. Equals(0.1) may not match |
| // float32(0.1). |
| case ck == reflect.Float32 || ck == reflect.Complex64: |
| return Equals(c.Interface()).Matches(e) |
| |
| // Otherwise, compare with double precision. |
| case isFloat(c): |
| if c.Float() == e { |
| err = nil |
| } |
| |
| case isComplex(c): |
| comp := c.Complex() |
| rl := real(comp) |
| im := imag(comp) |
| |
| if im == 0 && rl == e { |
| err = nil |
| } |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| func checkAgainstComplex64(e complex64, c reflect.Value) (err error) { |
| err = errors.New("") |
| realPart := real(e) |
| imaginaryPart := imag(e) |
| |
| switch { |
| case isInteger(c) || isFloat(c): |
| // If we have no imaginary part, then we should just compare against the |
| // real part. Otherwise, we can't be equal. |
| if imaginaryPart != 0 { |
| return |
| } |
| |
| return checkAgainstFloat32(realPart, c) |
| |
| case isComplex(c): |
| // Compare using complex64 to avoid a false sense of precision; otherwise |
| // e.g. Equals(0.1 + 0i) won't match float32(0.1). |
| if complex64(c.Complex()) == e { |
| err = nil |
| } |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| func checkAgainstComplex128(e complex128, c reflect.Value) (err error) { |
| err = errors.New("") |
| realPart := real(e) |
| imaginaryPart := imag(e) |
| |
| switch { |
| case isInteger(c) || isFloat(c): |
| // If we have no imaginary part, then we should just compare against the |
| // real part. Otherwise, we can't be equal. |
| if imaginaryPart != 0 { |
| return |
| } |
| |
| return checkAgainstFloat64(realPart, c) |
| |
| case isComplex(c): |
| if c.Complex() == e { |
| err = nil |
| } |
| |
| default: |
| err = NewFatalError("which is not numeric") |
| } |
| |
| return |
| } |
| |
| //////////////////////////////////////////////////////////////////////// |
| // Other types |
| //////////////////////////////////////////////////////////////////////// |
| |
| func checkAgainstBool(e bool, c reflect.Value) (err error) { |
| if c.Kind() != reflect.Bool { |
| err = NewFatalError("which is not a bool") |
| return |
| } |
| |
| err = errors.New("") |
| if c.Bool() == e { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstChan(e reflect.Value, c reflect.Value) (err error) { |
| // Create a description of e's type, e.g. "chan int". |
| typeStr := fmt.Sprintf("%s %s", e.Type().ChanDir(), e.Type().Elem()) |
| |
| // Make sure c is a chan of the correct type. |
| if c.Kind() != reflect.Chan || |
| c.Type().ChanDir() != e.Type().ChanDir() || |
| c.Type().Elem() != e.Type().Elem() { |
| err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstFunc(e reflect.Value, c reflect.Value) (err error) { |
| // Make sure c is a function. |
| if c.Kind() != reflect.Func { |
| err = NewFatalError("which is not a function") |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstMap(e reflect.Value, c reflect.Value) (err error) { |
| // Make sure c is a map. |
| if c.Kind() != reflect.Map { |
| err = NewFatalError("which is not a map") |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstPtr(e reflect.Value, c reflect.Value) (err error) { |
| // Create a description of e's type, e.g. "*int". |
| typeStr := fmt.Sprintf("*%v", e.Type().Elem()) |
| |
| // Make sure c is a pointer of the correct type. |
| if c.Kind() != reflect.Ptr || |
| c.Type().Elem() != e.Type().Elem() { |
| err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstSlice(e reflect.Value, c reflect.Value) (err error) { |
| // Create a description of e's type, e.g. "[]int". |
| typeStr := fmt.Sprintf("[]%v", e.Type().Elem()) |
| |
| // Make sure c is a slice of the correct type. |
| if c.Kind() != reflect.Slice || |
| c.Type().Elem() != e.Type().Elem() { |
| err = NewFatalError(fmt.Sprintf("which is not a %s", typeStr)) |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstString(e reflect.Value, c reflect.Value) (err error) { |
| // Make sure c is a string. |
| if c.Kind() != reflect.String { |
| err = NewFatalError("which is not a string") |
| return |
| } |
| |
| err = errors.New("") |
| if c.String() == e.String() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkAgainstArray(e reflect.Value, c reflect.Value) (err error) { |
| // Create a description of e's type, e.g. "[2]int". |
| typeStr := fmt.Sprintf("%v", e.Type()) |
| |
| // Make sure c is the correct type. |
| if c.Type() != e.Type() { |
| err = NewFatalError(fmt.Sprintf("which is not %s", typeStr)) |
| return |
| } |
| |
| // Check for equality. |
| if e.Interface() != c.Interface() { |
| err = errors.New("") |
| return |
| } |
| |
| return |
| } |
| |
| func checkAgainstUnsafePointer(e reflect.Value, c reflect.Value) (err error) { |
| // Make sure c is a pointer. |
| if c.Kind() != reflect.UnsafePointer { |
| err = NewFatalError("which is not a unsafe.Pointer") |
| return |
| } |
| |
| err = errors.New("") |
| if c.Pointer() == e.Pointer() { |
| err = nil |
| } |
| return |
| } |
| |
| func checkForNil(c reflect.Value) (err error) { |
| err = errors.New("") |
| |
| // Make sure it is legal to call IsNil. |
| switch c.Kind() { |
| case reflect.Invalid: |
| case reflect.Chan: |
| case reflect.Func: |
| case reflect.Interface: |
| case reflect.Map: |
| case reflect.Ptr: |
| case reflect.Slice: |
| |
| default: |
| err = NewFatalError("which cannot be compared to nil") |
| return |
| } |
| |
| // Ask whether the value is nil. Handle a nil literal (kind Invalid) |
| // specially, since it's not legal to call IsNil there. |
| if c.Kind() == reflect.Invalid || c.IsNil() { |
| err = nil |
| } |
| return |
| } |
| |
| //////////////////////////////////////////////////////////////////////// |
| // Public implementation |
| //////////////////////////////////////////////////////////////////////// |
| |
| func (m *equalsMatcher) Matches(candidate interface{}) error { |
| e := m.expectedValue |
| c := reflect.ValueOf(candidate) |
| ek := e.Kind() |
| |
| switch { |
| case ek == reflect.Bool: |
| return checkAgainstBool(e.Bool(), c) |
| |
| case isSignedInteger(e): |
| return checkAgainstInt64(e.Int(), c) |
| |
| case isUnsignedInteger(e): |
| return checkAgainstUint64(e.Uint(), c) |
| |
| case ek == reflect.Float32: |
| return checkAgainstFloat32(float32(e.Float()), c) |
| |
| case ek == reflect.Float64: |
| return checkAgainstFloat64(e.Float(), c) |
| |
| case ek == reflect.Complex64: |
| return checkAgainstComplex64(complex64(e.Complex()), c) |
| |
| case ek == reflect.Complex128: |
| return checkAgainstComplex128(complex128(e.Complex()), c) |
| |
| case ek == reflect.Chan: |
| return checkAgainstChan(e, c) |
| |
| case ek == reflect.Func: |
| return checkAgainstFunc(e, c) |
| |
| case ek == reflect.Map: |
| return checkAgainstMap(e, c) |
| |
| case ek == reflect.Ptr: |
| return checkAgainstPtr(e, c) |
| |
| case ek == reflect.Slice: |
| return checkAgainstSlice(e, c) |
| |
| case ek == reflect.String: |
| return checkAgainstString(e, c) |
| |
| case ek == reflect.Array: |
| return checkAgainstArray(e, c) |
| |
| case ek == reflect.UnsafePointer: |
| return checkAgainstUnsafePointer(e, c) |
| |
| case ek == reflect.Invalid: |
| return checkForNil(c) |
| } |
| |
| panic(fmt.Sprintf("equalsMatcher.Matches: unexpected kind: %v", ek)) |
| } |
| |
| func (m *equalsMatcher) Description() string { |
| // Special case: handle nil. |
| if !m.expectedValue.IsValid() { |
| return "is nil" |
| } |
| |
| return fmt.Sprintf("%v", m.expectedValue.Interface()) |
| } |