Add ReadByte method, satisfies the io.ByteReader interface
diff --git a/decode.go b/decode.go
index f1e04b1..23c6e26 100644
--- a/decode.go
+++ b/decode.go
@@ -118,32 +118,23 @@
return true
}
-// Read satisfies the io.Reader interface.
-func (r *Reader) Read(p []byte) (int, error) {
- if r.err != nil {
- return 0, r.err
- }
- for {
- if r.i < r.j {
- n := copy(p, r.decoded[r.i:r.j])
- r.i += n
- return n, nil
- }
+func (r *Reader) fill() error {
+ for r.i >= r.j {
if !r.readFull(r.buf[:4], true) {
- return 0, r.err
+ return r.err
}
chunkType := r.buf[0]
if !r.readHeader {
if chunkType != chunkTypeStreamIdentifier {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
r.readHeader = true
}
chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
if chunkLen > len(r.buf) {
r.err = ErrUnsupported
- return 0, r.err
+ return r.err
}
// The chunk types are specified at
@@ -153,11 +144,11 @@
// Section 4.2. Compressed data (chunk type 0x00).
if chunkLen < checksumSize {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
buf := r.buf[:chunkLen]
if !r.readFull(buf, false) {
- return 0, r.err
+ return r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
buf = buf[checksumSize:]
@@ -165,19 +156,19 @@
n, err := DecodedLen(buf)
if err != nil {
r.err = err
- return 0, r.err
+ return r.err
}
if n > len(r.decoded) {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
if _, err := Decode(r.decoded, buf); err != nil {
r.err = err
- return 0, r.err
+ return r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
r.i, r.j = 0, n
continue
@@ -186,25 +177,25 @@
// Section 4.3. Uncompressed data (chunk type 0x01).
if chunkLen < checksumSize {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
buf := r.buf[:checksumSize]
if !r.readFull(buf, false) {
- return 0, r.err
+ return r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
// Read directly into r.decoded instead of via r.buf.
n := chunkLen - checksumSize
if n > len(r.decoded) {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
if !r.readFull(r.decoded[:n], false) {
- return 0, r.err
+ return r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
r.i, r.j = 0, n
continue
@@ -213,15 +204,15 @@
// Section 4.1. Stream identifier (chunk type 0xff).
if chunkLen != len(magicBody) {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
if !r.readFull(r.buf[:len(magicBody)], false) {
- return 0, r.err
+ return r.err
}
for i := 0; i < len(magicBody); i++ {
if r.buf[i] != magicBody[i] {
r.err = ErrCorrupt
- return 0, r.err
+ return r.err
}
}
continue
@@ -230,12 +221,44 @@
if chunkType <= 0x7f {
// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
r.err = ErrUnsupported
- return 0, r.err
+ return r.err
}
// Section 4.4 Padding (chunk type 0xfe).
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
if !r.readFull(r.buf[:chunkLen], false) {
- return 0, r.err
+ return r.err
}
}
+
+ return nil
+}
+
+// Read satisfies the io.Reader interface.
+func (r *Reader) Read(p []byte) (int, error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+
+ if err := r.fill(); err != nil {
+ return 0, err
+ }
+
+ n := copy(p, r.decoded[r.i:r.j])
+ r.i += n
+ return n, nil
+}
+
+// ReadByte satisfies the io.ByteReader interface.
+func (r *Reader) ReadByte() (byte, error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+
+ if err := r.fill(); err != nil {
+ return 0, err
+ }
+
+ c := r.decoded[r.i]
+ r.i++
+ return c, nil
}
diff --git a/snappy_test.go b/snappy_test.go
index 2712710..90fab80 100644
--- a/snappy_test.go
+++ b/snappy_test.go
@@ -1032,6 +1032,71 @@
}
}
+func TestReaderReadByte(t *testing.T) {
+ // Test all 32 possible sub-sequences of these 5 input slices prefixed by
+ // their size encoded as a uvarint.
+ //
+ // Their lengths sum to 400,000, which is over 6 times the Writer ibuf
+ // capacity: 6 * maxBlockSize is 393,216.
+ inputs := [][]byte{
+ bytes.Repeat([]byte{'a'}, 40000),
+ bytes.Repeat([]byte{'b'}, 150000),
+ bytes.Repeat([]byte{'c'}, 60000),
+ bytes.Repeat([]byte{'d'}, 120000),
+ bytes.Repeat([]byte{'e'}, 30000),
+ }
+loop:
+ for i := 0; i < 1<<uint(len(inputs)); i++ {
+ var want []int
+ buf := new(bytes.Buffer)
+ w := NewBufferedWriter(buf)
+ p := make([]byte, binary.MaxVarintLen64)
+ for j, input := range inputs {
+ if i&(1<<uint(j)) == 0 {
+ continue
+ }
+ n := binary.PutUvarint(p, uint64(len(input)))
+ if _, err := w.Write(p[:n]); err != nil {
+ t.Errorf("i=%#02x: j=%d: Write Uvarint: %v", i, j, err)
+ continue loop
+ }
+ if _, err := w.Write(input); err != nil {
+ t.Errorf("i=%#02x: j=%d: Write: %v", i, j, err)
+ continue loop
+ }
+ want = append(want, j)
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("i=%#02x: Close: %v", i, err)
+ continue
+ }
+ r := NewReader(buf)
+ for _, j := range want {
+ size, err := binary.ReadUvarint(r)
+ if err != nil {
+ t.Errorf("i=%#02x: ReadUvarint: %v", i, err)
+ continue loop
+ }
+ if wantedSize := uint64(len(inputs[j])); size != wantedSize {
+ t.Errorf("i=%#02x: expected size %d, got %d", i, wantedSize, size)
+ continue loop
+ }
+ got := make([]byte, size)
+ if _, err := io.ReadFull(r, got); err != nil {
+ t.Errorf("i=%#02x: ReadFull: %v", i, err)
+ continue loop
+ }
+ if err := cmp(got, inputs[j]); err != nil {
+ t.Errorf("i=%#02x: %v", i, err)
+ continue
+ }
+ }
+ if _, err := r.ReadByte(); err != io.EOF {
+ t.Errorf("i=%#02x: expected size EOF, got %v", i, err)
+ }
+ }
+}
+
func TestWriterReset(t *testing.T) {
gold := bytes.Repeat([]byte("Not all those who wander are lost;\n"), 10000)
const n = 20