blob: a0fa746d4645da8f13b5e3b29f02edb354091484 [file] [edit]
/*
Copyright 2011 The gomemcache AUTHORS
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package memcache provides a client for the memcached cache server.
package memcache
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"flag"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
// localhostTCPAddr is the address TestLocalhost dials, expecting an
// already-running memcached server to be listening there.
const localhostTCPAddr = "localhost:11211"

// debug is set with the -debug test flag; TestTLS reads it and, when true,
// connects the memcached child process's stdout and stderr to the test's own.
var debug = flag.Bool("debug", false, "be more verbose")
// TestLocalhost runs the client test suite against a memcached server that is
// expected to already be listening on localhostTCPAddr. It is skipped when
// nothing accepts connections there.
func TestLocalhost(t *testing.T) {
	t.Parallel()
	c, err := net.Dial("tcp", localhostTCPAddr)
	if err != nil {
		t.Skipf("skipping test; no server running at %s", localhostTCPAddr)
	}
	// Flush state left over from earlier runs and wait for the server to
	// acknowledge it. The suite talks to the server over fresh connections, so
	// returning before the reply arrives could (presumably, since memcached may
	// serve each connection on its own worker thread) let the flush race with
	// the suite's first writes. A deadline keeps a wedged server from hanging
	// the test.
	flushErr := func() error {
		defer c.Close()
		if err := c.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			return err
		}
		if _, err := io.WriteString(c, "flush_all\r\n"); err != nil {
			return err
		}
		line, err := bufio.NewReader(c).ReadString('\n')
		if err != nil {
			return err
		}
		if line != "OK\r\n" {
			return fmt.Errorf("unexpected reply %q", line)
		}
		return nil
	}()
	if flushErr != nil {
		t.Fatalf("flush_all on %s: %v", localhostTCPAddr, flushErr)
	}
	testWithClient(t, New(localhostTCPAddr))
}
// TestUnixSocket runs the memcached binary as a child process listening on a
// unix domain socket and runs the client test suite against that socket. It is
// skipped when the memcached binary cannot be started.
func TestUnixSocket(t *testing.T) {
	t.Parallel()
	sock := fmt.Sprintf("/tmp/test-gomemcache-%d.sock", os.Getpid())
	cmd := exec.Command("memcached", "-s", sock)
	if err := cmd.Start(); err != nil {
		t.Skipf("skipping test; couldn't find memcached: %v", err)
	}
	// Deferred calls run last-in first-out: kill memcached, reap it, and only
	// then remove the socket file, which a killed server leaves behind.
	defer os.Remove(sock)
	defer cmd.Wait()
	defer cmd.Process.Kill()

	// Wait a bit for the socket to appear.
	ready := false
	for i := 0; i < 10; i++ {
		if _, err := os.Stat(sock); err == nil {
			ready = true
			break
		}
		time.Sleep(time.Duration(25*i) * time.Millisecond)
	}
	if !ready {
		t.Fatalf("memcached did not create unix socket %s", sock)
	}
	testWithClient(t, New(sock))
}
// TestFakeServer runs the client test suite against the in-process testServer
// fake, which serves on an ephemeral localhost TCP port for the duration of
// the test.
func TestFakeServer(t *testing.T) {
	t.Parallel()
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	defer listener.Close()

	serverAddr := listener.Addr()
	t.Logf("running test server on %s", serverAddr)
	go (&testServer{}).Serve(listener)

	testWithClient(t, New(serverAddr.String()))
}
// TestTLS starts a TLS-enabled memcached child process and runs the client
// test suite against it over TLS. It is skipped when memcached is missing or
// was built without TLS support.
func TestTLS(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()

	// Test whether our memcached binary has TLS support. We --enable-ssl first,
	// before --version, as memcached evaluates the flags in the order provided
	// and we want it to fail if it's built without TLS support (as it is in
	// Debian, but not Ubuntu or Homebrew).
	out, err := exec.Command("memcached", "--enable-ssl", "--version").CombinedOutput()
	if err != nil {
		t.Skipf("skipping test; couldn't find memcached or no TLS support in binary: %v, %s", err, out)
	}
	t.Logf("version: %s", bytes.TrimSpace(out))

	// memcached runs with dir as its working directory (cmd.Dir below), so it
	// finds these files by the relative names passed in its -o options.
	if err := os.WriteFile(filepath.Join(dir, "cert.pem"), LocalhostCert, 0644); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(filepath.Join(dir, "key.pem"), LocalhostKey, 0644); err != nil {
		t.Fatal(err)
	}

	// Find some unused port. This is racy but we hope for the best and hope the kernel
	// doesn't reassign our ephemeral port to somebody in the tiny race window.
	// Probe on 127.0.0.1, the same interface memcached is told to listen on,
	// so the port is known to be free there (not just on ::1).
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	port := strconv.Itoa(ln.Addr().(*net.TCPAddr).Port)
	ln.Close()
	addr := net.JoinHostPort("127.0.0.1", port)

	cmd := exec.Command("memcached",
		"--port="+port,
		"--listen=127.0.0.1",
		"--enable-ssl",
		"-o", "ssl_chain_cert=cert.pem",
		"-o", "ssl_key=key.pem")
	cmd.Dir = dir
	if *debug {
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
	}
	if err := cmd.Start(); err != nil {
		t.Fatalf("failed to start memcached: %v", err)
	}
	defer cmd.Wait()
	defer cmd.Process.Kill()

	// Wait a bit for the server to be running, probing the exact address the
	// client will dial.
	up := false
	for i := 0; i < 10; i++ {
		nc, err := net.Dial("tcp", addr)
		if err == nil {
			nc.Close()
			up = true
			break
		}
		t.Logf("waiting for %s to be up...", addr)
		time.Sleep(time.Duration(25*i) * time.Millisecond)
	}
	if !up {
		t.Fatalf("memcached did not start listening on %s", addr)
	}
	t.Logf("%s is up.", addr)

	c := New(addr)
	c.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
		// Verification is skipped because the server presents the test-only
		// LocalhostCert (presumably not chained to a CA in the system pool).
		dialer := tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}
		return dialer.DialContext(ctx, network, address)
	}
	testWithClient(t, c)
}
// mustSetF returns a function that stores an item with c.Set and aborts the
// test via t.Fatalf if the store fails.
func mustSetF(t *testing.T, c *Client) func(*Item) {
	return func(it *Item) {
		// Report failures at the caller's line rather than inside this closure.
		t.Helper()
		if err := c.Set(it); err != nil {
			t.Fatalf("failed to Set %#v: %v", *it, err)
		}
	}
}
// testWithClient runs the shared client test suite against c. The suite
// writes the keys foo, bar, baz, append, prepend, num and Hello_世界, and
// finishes with DeleteAll, so c must point at a server it is free to modify.
// It exercises Set, Get, CompareAndSwap, Add, Append, Prepend, Replace,
// GetMulti, Delete, Increment/Decrement, Touch (via testTouchWithClient),
// DeleteAll and Ping.
func testWithClient(t *testing.T, c *Client) {
	checkErr := func(err error, format string, args ...interface{}) {
		t.Helper()
		if err != nil {
			t.Fatalf(format, args...)
		}
	}
	mustSet := mustSetF(t, c)

	// Set
	foo := &Item{Key: "foo", Value: []byte("fooval-fromset"), Flags: 123}
	err := c.Set(foo)
	checkErr(err, "first set(foo): %v", err)
	err = c.Set(foo)
	checkErr(err, "second set(foo): %v", err)

	// CompareAndSwap
	it, err := c.Get("foo")
	checkErr(err, "get(foo): %v", err)
	if string(it.Value) != "fooval-fromset" {
		t.Errorf("get(foo) Value = %q, want fooval-fromset", it.Value)
	}
	it0, err := c.Get("foo") // another get, to fail our CAS later
	checkErr(err, "get(foo): %v", err)
	it.Value = []byte("fooval")
	err = c.CompareAndSwap(it)
	checkErr(err, "cas(foo): %v", err)
	it0.Value = []byte("should-fail")
	if err := c.CompareAndSwap(it0); err != ErrCASConflict {
		t.Fatalf("cas(foo) error = %v; want ErrCASConflict", err)
	}

	// Get
	it, err = c.Get("foo")
	checkErr(err, "get(foo): %v", err)
	if it.Key != "foo" {
		t.Errorf("get(foo) Key = %q, want foo", it.Key)
	}
	if string(it.Value) != "fooval" {
		t.Errorf("get(foo) Value = %q, want fooval", it.Value)
	}
	if it.Flags != 123 {
		t.Errorf("get(foo) Flags = %v, want 123", it.Flags)
	}

	// Get and set a unicode key
	quxKey := "Hello_世界"
	qux := &Item{Key: quxKey, Value: []byte("hello world")}
	err = c.Set(qux)
	checkErr(err, "first set(Hello_世界): %v", err)
	it, err = c.Get(quxKey)
	checkErr(err, "get(Hello_世界): %v", err)
	if it.Key != quxKey {
		t.Errorf("get(Hello_世界) Key = %q, want Hello_世界", it.Key)
	}
	if string(it.Value) != "hello world" {
		t.Errorf("get(Hello_世界) Value = %q, want hello world", string(it.Value))
	}

	// Set malformed keys: a space, and the DEL control character.
	malFormed := &Item{Key: "foo bar", Value: []byte("foobarval")}
	err = c.Set(malFormed)
	if err != ErrMalformedKey {
		t.Errorf("set(foo bar) should return ErrMalformedKey instead of %v", err)
	}
	malFormed = &Item{Key: "foo" + string(rune(0x7f)), Value: []byte("foobarval")}
	err = c.Set(malFormed)
	if err != ErrMalformedKey {
		t.Errorf("set(foo<0x7f>) should return ErrMalformedKey instead of %v", err)
	}

	// Add
	bar := &Item{Key: "bar", Value: []byte("barval")}
	err = c.Add(bar)
	checkErr(err, "first add(bar): %v", err)
	if err := c.Add(bar); err != ErrNotStored {
		t.Fatalf("second add(bar) want ErrNotStored, got %v", err)
	}

	// Append: fails on a missing key, then concatenates after a Set.
	// (Named appendItem so the append builtin is not shadowed.)
	appendItem := &Item{Key: "append", Value: []byte("appendval")}
	if err := c.Append(appendItem); err != ErrNotStored {
		t.Fatalf("first append(append) want ErrNotStored, got %v", err)
	}
	mustSet(appendItem)
	err = c.Append(&Item{Key: "append", Value: []byte("1")})
	checkErr(err, "second append(append): %v", err)
	appended, err := c.Get("append")
	checkErr(err, "get(append) after append: %v", err)
	if want := string(appendItem.Value) + "1"; string(appended.Value) != want {
		t.Fatalf("Append: want=%s, got=%s", want, string(appended.Value))
	}

	// Prepend: fails on a missing key, then concatenates before a Set.
	prepend := &Item{Key: "prepend", Value: []byte("prependval")}
	if err := c.Prepend(prepend); err != ErrNotStored {
		t.Fatalf("first prepend(prepend) want ErrNotStored, got %v", err)
	}
	mustSet(prepend)
	err = c.Prepend(&Item{Key: "prepend", Value: []byte("1")})
	checkErr(err, "second prepend(prepend): %v", err)
	prepended, err := c.Get("prepend")
	checkErr(err, "get(prepend) after prepend: %v", err)
	if want := "1" + string(prepend.Value); string(prepended.Value) != want {
		t.Fatalf("Prepend: want=%s, got=%s", want, string(prepended.Value))
	}

	// Replace: fails on a missing key, succeeds on an existing one.
	baz := &Item{Key: "baz", Value: []byte("bazvalue")}
	if err := c.Replace(baz); err != ErrNotStored {
		t.Fatalf("expected replace(baz) to return ErrNotStored, got %v", err)
	}
	err = c.Replace(bar)
	checkErr(err, "replace(bar): %v", err)

	// GetMulti
	m, err := c.GetMulti([]string{"foo", "bar"})
	checkErr(err, "GetMulti: %v", err)
	if g, e := len(m), 2; g != e {
		t.Errorf("GetMulti: got len(map) = %d, want = %d", g, e)
	}
	if _, ok := m["foo"]; !ok {
		t.Fatalf("GetMulti: didn't get key 'foo'")
	}
	if _, ok := m["bar"]; !ok {
		t.Fatalf("GetMulti: didn't get key 'bar'")
	}
	if g, e := string(m["foo"].Value), "fooval"; g != e {
		t.Errorf("GetMulti: foo: got %q, want %q", g, e)
	}
	if g, e := string(m["bar"].Value), "barval"; g != e {
		t.Errorf("GetMulti: bar: got %q, want %q", g, e)
	}

	// Delete
	err = c.Delete("foo")
	checkErr(err, "Delete: %v", err)
	if _, err := c.Get("foo"); err != ErrCacheMiss {
		t.Errorf("post-Delete want ErrCacheMiss, got %v", err)
	}

	// Incr/Decr
	mustSet(&Item{Key: "num", Value: []byte("42")})
	n, err := c.Increment("num", 8)
	checkErr(err, "Increment num + 8: %v", err)
	if n != 50 {
		t.Fatalf("Increment num + 8: want=50, got=%d", n)
	}
	n, err = c.Decrement("num", 49)
	checkErr(err, "Decrement: %v", err)
	if n != 1 {
		t.Fatalf("Decrement 49: want=1, got=%d", n)
	}
	err = c.Delete("num")
	checkErr(err, "delete num: %v", err)
	if _, err := c.Increment("num", 1); err != ErrCacheMiss {
		t.Fatalf("increment post-delete: want ErrCacheMiss, got %v", err)
	}
	mustSet(&Item{Key: "num", Value: []byte("not-numeric")})
	if _, err := c.Increment("num", 1); err == nil || !strings.Contains(err.Error(), "client error") {
		t.Fatalf("increment non-number: want client error, got %v", err)
	}

	testTouchWithClient(t, c)

	// Test Delete All
	err = c.DeleteAll()
	checkErr(err, "DeleteAll: %v", err)
	if _, err := c.Get("bar"); err != ErrCacheMiss {
		t.Errorf("post-DeleteAll want ErrCacheMiss, got %v", err)
	}

	// Test Ping
	err = c.Ping()
	checkErr(err, "Ping: %v", err)
}
// testTouchWithClient checks that Touch extends an item's expiration. It
// stores foo and bar with a 2-second expiration, then touches foo once per
// second for 3 seconds. Afterwards foo must still be present and bar must be a
// cache miss. It sleeps for several seconds, so it is skipped in -short mode.
func testTouchWithClient(t *testing.T, c *Client) {
	if testing.Short() {
		t.Log("Skipping testing memcache Touch with testing in Short mode")
		return
	}
	mustSet := mustSetF(t, c)

	const secondsToExpiry = int32(2)
	foo := &Item{Key: "foo", Value: []byte("fooval"), Expiration: secondsToExpiry}
	bar := &Item{Key: "bar", Value: []byte("barval"), Expiration: secondsToExpiry}
	setTime := time.Now()
	mustSet(foo)
	mustSet(bar)

	// Each touch pushes foo's expiry secondsToExpiry seconds into the future,
	// so it outlives bar, which is never touched.
	for s := 0; s < 3; s++ {
		time.Sleep(time.Second)
		if err := c.Touch(foo.Key, secondsToExpiry); err != nil {
			t.Errorf("error touching foo: %v", err)
		}
	}

	switch _, err := c.Get("foo"); {
	case err == ErrCacheMiss:
		t.Fatalf("touching failed to keep item foo alive")
	case err != nil:
		t.Fatalf("unexpected error retrieving foo after touching: %v", err)
	}

	switch _, err := c.Get("bar"); {
	case err == nil:
		t.Fatalf("item bar did not expire within %v seconds", time.Since(setTime).Seconds())
	case err != ErrCacheMiss:
		t.Fatalf("unexpected error retrieving bar: %v", err)
	}
}
// BenchmarkOnItem measures Client.onItem with a no-op item callback. The fake
// server accepts connections and discards everything written to them.
func BenchmarkOnItem(b *testing.B) {
	fakeServer, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		b.Fatal("Could not open fake server: ", err)
	}
	defer fakeServer.Close()
	go func() {
		for {
			conn, err := fakeServer.Accept()
			if err != nil {
				// Accept fails once the deferred Close above runs; stop then.
				return
			}
			go func() {
				// Release the server side once the client hangs up.
				defer conn.Close()
				io.Copy(ioutil.Discard, conn)
			}()
		}
	}()

	addr := fakeServer.Addr()
	c := New(addr.String())
	if _, err := c.getConn(addr); err != nil {
		b.Fatalf("failed to initialize connection to fake server: %v", err)
	}
	item := Item{Key: "foo"}
	dummyFn := func(_ *Client, _ *bufio.ReadWriter, _ *Item) error { return nil }
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// A failing call would make the timings meaningless, so stop loudly.
		if err := c.onItem(&item, dummyFn); err != nil {
			b.Fatalf("onItem: %v", err)
		}
	}
}
// BenchmarkScanGetResponseLine measures scanGetResponseLine on a single
// "VALUE" line that includes a CAS ID, reusing one Item across iterations.
func BenchmarkScanGetResponseLine(b *testing.B) {
	var item Item
	input := []byte("VALUE foobar1234 0 4096 1234\r\n")
	for n := 0; n < b.N; n++ {
		if _, err := scanGetResponseLine(input, &item); err != nil {
			b.Fatal(err)
		}
	}
}
// TestScanGetResponseLine checks scanGetResponseLine on "VALUE" response
// lines. It covers extraction of key, flags, size and the optional CAS ID. It
// also covers rejection of malformed lines, lines missing the trailing CRLF,
// and numbers that overflow their field.
func TestScanGetResponseLine(t *testing.T) {
	tests := []struct {
		name      string
		line      string
		wantKey   string
		wantFlags uint32
		wantCasid uint64
		wantSize  int
		wantErr   bool
	}{
		{name: "blank", line: "",
			wantErr: true},
		{name: "malformed1", line: "VALU foobar1234 1 4096\r\n",
			wantErr: true},
		{name: "malformed2", line: "VALUEfoobar1234 1 4096\r\n",
			wantErr: true},
		{name: "malformed3", line: "VALUE foobar1234 14096\r\n",
			wantErr: true},
		{name: "malformed4", line: "VALUE foobar123414096\r\n",
			wantErr: true},
		{name: "no-eol", line: "VALUE foobar1234 1 4096",
			wantErr: true},
		{name: "basic", line: "VALUE foobar1234 1 4096\r\n",
			wantKey: "foobar1234", wantFlags: 1, wantSize: 4096},
		{name: "casid", line: "VALUE foobar1234 1 4096 1234\r\n",
			wantKey: "foobar1234", wantFlags: 1, wantSize: 4096, wantCasid: 1234},
		{name: "flags-max-uint32", line: "VALUE key 4294967295 1\r\n",
			wantKey: "key", wantFlags: 4294967295, wantSize: 1},
		{name: "flags-overflow", line: "VALUE key 4294967296 1\r\n",
			wantErr: true},
		{name: "size-max-int32", line: "VALUE key 1 2147483647\r\n",
			wantKey: "key", wantFlags: 1, wantSize: 2147483647},
		{name: "size-overflow", line: "VALUE key 1 4294967296\r\n",
			wantErr: true},
		{name: "casid-overflow", line: "VALUE key 1 4096 18446744073709551616\r\n",
			wantErr: true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var got Item
			gotSize, err := scanGetResponseLine([]byte(tt.line), &got)
			if tt.wantErr {
				if err == nil {
					t.Errorf("scanGetResponseLine(%q) should have returned error", tt.line)
				}
				return
			}
			if err != nil {
				t.Errorf("scanGetResponseLine(%q) returned error %v", tt.line, err)
				return
			}
			if got.Key != tt.wantKey {
				t.Errorf("key = %v, want %v", got.Key, tt.wantKey)
			}
			if got.Flags != tt.wantFlags {
				t.Errorf("flags = %v, want %v", got.Flags, tt.wantFlags)
			}
			if got.CasID != tt.wantCasid {
				t.Errorf("casid = %v, want %v", got.CasID, tt.wantCasid)
			}
			if gotSize != tt.wantSize {
				t.Errorf("size = %v, want %v", gotSize, tt.wantSize)
			}
		})
	}
}