Make Memory really set and get the LeakyBucket (store pointers instead of copies) - also add some tests
diff --git a/ratelimiter/memory.go b/ratelimiter/memory.go index 8f6a578..8ccac98 100644 --- a/ratelimiter/memory.go +++ b/ratelimiter/memory.go
@@ -5,16 +5,19 @@ "time" ) -const GC_SIZE int = 100 +const ( + GC_SIZE int = 100 + GC_RATE = 60 * time.Second +) type Memory struct { - store map[string]LeakyBucket + store map[string]*LeakyBucket lastGCCollected time.Time } func NewMemory() *Memory { m := new(Memory) - m.store = make(map[string]LeakyBucket) + m.store = make(map[string]*LeakyBucket) m.lastGCCollected = time.Now() return m } @@ -26,10 +29,10 @@ return nil, errors.New("miss") } - return &bucket, nil + return bucket, nil } -func (m *Memory) SetBucketFor(key string, bucket LeakyBucket) error { +func (m *Memory) SetBucketFor(key string, bucket *LeakyBucket) error { if len(m.store) > GC_SIZE { m.GarbageCollect()
diff --git a/ratelimiter/memory_test.go b/ratelimiter/memory_test.go new file mode 100644 index 0000000..26428e3 --- /dev/null +++ b/ratelimiter/memory_test.go
@@ -0,0 +1,97 @@
+package ratelimiter
+
+import (
+	"os"
+	"reflect"
+	"testing"
+	"time"
+)
+
+// TestMain is the package's test entry point; it runs all tests and exits
+// with the status code reported by m.Run.
+func TestMain(m *testing.M) {
+	os.Exit(m.Run())
+}
+
+// TestSetBucketFor checks that SetBucketFor accepts freshly created buckets
+// without returning an error.
+func TestSetBucketFor(t *testing.T) {
+	buckets := getTestBuckets("bucketOne", "bucketTwo")
+	mem := NewMemory()
+	if err := setTestBuckets(mem, buckets); err != nil {
+		t.Errorf("SetBucketFor failed: %v", err)
+	}
+}
+
+// TestGetBucketFor stores several buckets and checks that GetBucketFor
+// returns, for each name, a bucket deeply equal to the one that was stored.
+func TestGetBucketFor(t *testing.T) {
+	buckets := getTestBuckets("bucketOne", "bucketTwo", "bucketThree")
+	mem := NewMemory()
+	// Abort on a failed setup: the lookups below would only produce
+	// follow-up failures that hide the real cause.
+	if err := setTestBuckets(mem, buckets); err != nil {
+		t.Fatalf("SetBucketFor failed: %v", err)
+	}
+
+	for name, want := range buckets {
+		got, err := mem.GetBucketFor(name)
+		if err != nil {
+			t.Errorf("GetBucketFor(%q) failed: %v", name, err)
+			continue
+		}
+		if got == nil {
+			t.Errorf("Received a nil bucket for %q", name)
+			continue
+		}
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("Unexpected bucket for %q", name)
+		}
+	}
+}
+
+// TestGarbageCollect stores one bucket, pours into it, runs GarbageCollect
+// and expects the following lookup to report a "miss" with a nil bucket.
+func TestGarbageCollect(t *testing.T) {
+	bucketName := "GC-test-bucket"
+	mem := NewMemory()
+	bucket := NewLeakyBucket(1, time.Second)
+
+	if err := mem.SetBucketFor(bucketName, bucket); err != nil {
+		t.Fatalf("SetBucketFor failed: %v", err)
+	}
+
+	bucket.Pour(1)
+	mem.GarbageCollect()
+
+	// Keep the lookup result in its own variable so it cannot be confused
+	// with the bucket created above.
+	stored, err := mem.GetBucketFor(bucketName)
+	if err == nil || err.Error() != "miss" {
+		t.Errorf("Expected a miss error from GetBucketFor, got: %v", err)
+	}
+	if stored != nil {
+		t.Errorf("GarbageCollect did not clear bucket for: %s", bucketName)
+	}
+}
+
+// setTestBuckets stores every bucket of the map in mem under its key and
+// returns the first error SetBucketFor reports, or nil if all calls succeed.
+func setTestBuckets(mem *Memory, buckets map[string]*LeakyBucket) error {
+	for name, bucket := range buckets {
+		if err := mem.SetBucketFor(name, bucket); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// getTestBuckets returns a map holding one new LeakyBucket per given name,
+// each created with NewLeakyBucket(1, time.Second).
+func getTestBuckets(names ...string) map[string]*LeakyBucket {
+	buckets := make(map[string]*LeakyBucket, len(names))
+	for _, name := range names {
+		buckets[name] = NewLeakyBucket(1, time.Second)
+	}
+	return buckets
+}