blob: 26428e34acb4446bd255adeaeb6a37857528c33f [file] [log] [blame] [edit]
package ratelimiter
import (
"os"
"reflect"
"testing"
"time"
)
// TestMain is the entry point for this package's tests: it runs them via
// m.Run and exits the process with the status code m.Run returns.
func TestMain(m *testing.M) {
	status := m.Run()
	os.Exit(status)
}
// TestSetBucketFor verifies that two freshly built buckets can be stored in
// a new Memory through setTestBuckets without an error being reported.
func TestSetBucketFor(t *testing.T) {
	fixtures := getTestBuckets("bucketOne", "bucketTwo")
	store := NewMemory()

	err := setTestBuckets(store, fixtures)
	if err != nil {
		t.Errorf("SetBucketFor failed: %v", err)
	}
}
// TestGetBucketFor stores three buckets in a new Memory and verifies that
// GetBucketFor returns, for every stored name, a non-nil bucket that is
// deeply equal to the one that was stored.
func TestGetBucketFor(t *testing.T) {
	buckets := getTestBuckets("bucketOne", "bucketTwo", "bucketThree")
	mem := NewMemory()

	// The lookups below are meaningless if the fixtures were not stored, so
	// a setup failure aborts the test instead of letting it continue.
	if err := setTestBuckets(mem, buckets); err != nil {
		t.Fatalf("SetBucketFor failed: %v", err)
	}

	for name, want := range buckets {
		got, err := mem.GetBucketFor(name)
		if err != nil {
			// Skip the remaining checks for this name: they would only
			// repeat the same failure in a less informative form.
			t.Errorf("GetBucketFor(%q) failed: %v", name, err)
			continue
		}
		if got == nil {
			t.Errorf("GetBucketFor(%q) returned a nil bucket", name)
			continue
		}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("GetBucketFor(%q) returned an unexpected bucket: got %+v, want %+v", name, got, want)
		}
	}
}
// TestGarbageCollect stores a single bucket, pours into it, runs
// GarbageCollect, and then expects GetBucketFor to report a "miss" error
// and a nil bucket for that name.
func TestGarbageCollect(t *testing.T) {
	bucketName := "GC-test-bucket"
	mem := NewMemory()
	bucket := NewLeakyBucket(1, time.Second)

	// Nothing below is meaningful if the bucket was never stored, so a
	// setup failure aborts the test.
	if err := mem.SetBucketFor(bucketName, bucket); err != nil {
		t.Fatalf("SetBucketFor failed: %v", err)
	}

	bucket.Pour(1)
	mem.GarbageCollect()

	// A separate variable holds the lookup result so the fixture bucket is
	// not silently overwritten.
	stored, err := mem.GetBucketFor(bucketName)
	if err == nil || err.Error() != "miss" {
		t.Errorf("Expected a \"miss\" error from GetBucketFor, got: %v", err)
	}
	if stored != nil {
		t.Errorf("GarbageCollect did not clear bucket for: %s", bucketName)
	}
}
// setTestBuckets stores every entry of buckets in mem via SetBucketFor.
// It stops at the first failure and returns that error; it returns nil when
// every bucket was stored.
func setTestBuckets(mem *Memory, buckets map[string]*LeakyBucket) error {
	for key := range buckets {
		err := mem.SetBucketFor(key, buckets[key])
		if err != nil {
			return err
		}
	}
	return nil
}
// getTestBuckets returns a map holding one bucket per given name, each
// created with NewLeakyBucket(1, time.Second). A name passed more than once
// ends up with a single entry.
func getTestBuckets(names ...string) map[string]*LeakyBucket {
	byName := make(map[string]*LeakyBucket, len(names))
	for i := range names {
		byName[names[i]] = NewLeakyBucket(1, time.Second)
	}
	return byName
}