| import array |
| import gc |
| import io |
| import pathlib |
| import random |
| import re |
| import os |
| import unittest |
| import tempfile |
| import threading |
| |
| from test.support.import_helper import import_module |
| from test.support import threading_helper |
| from test.support import _1M |
| |
| _zstd = import_module("_zstd") |
| zstd = import_module("compression.zstd") |
| |
| from compression.zstd import ( |
| open, |
| compress, |
| decompress, |
| ZstdCompressor, |
| ZstdDecompressor, |
| ZstdDict, |
| ZstdError, |
| zstd_version, |
| zstd_version_info, |
| COMPRESSION_LEVEL_DEFAULT, |
| get_frame_info, |
| get_frame_size, |
| finalize_dict, |
| train_dict, |
| CompressionParameter, |
| DecompressionParameter, |
| Strategy, |
| ZstdFile, |
| ) |
| |
# Size helpers used by the fixtures and tests in this file.
_1K = 1024
_130_1K = 130 * _1K
DICT_SIZE1 = 3 * _1K

# Placeholders for shared fixtures.  They start out as None here and are
# rebound through ``global`` statements in setUpModule().
DAT_130K_D = DAT_130K_C = None

DECOMPRESSED_DAT = COMPRESSED_DAT = None

DECOMPRESSED_100_PLUS_32KB = COMPRESSED_100_PLUS_32KB = None

SKIPPABLE_FRAME = None

THIS_FILE_BYTES = THIS_FILE_STR = None
COMPRESSED_THIS_FILE = None

COMPRESSED_BOGUS = None

SAMPLES = None

TRAINED_DICT = None

# Unlike the fixtures above, this flag has to be computed at import time:
# it is used to decide whether or not to skip tests.  Any failure while
# probing the bounds is treated as "multithreading not supported".
try:
    SUPPORT_MULTITHREADING = (
        CompressionParameter.nb_workers.bounds() != (0, 0)
    )
except Exception:
    SUPPORT_MULTITHREADING = False

# Limits of a 32-bit signed integer.
C_INT_MIN = -(1 << 31)
C_INT_MAX = (1 << 31) - 1
| |
| |
def setUpModule():
    """Build the module-level fixtures that the test cases share."""
    global DAT_130K_D, DAT_130K_C
    global DECOMPRESSED_DAT, COMPRESSED_DAT
    global DECOMPRESSED_100_PLUS_32KB, COMPRESSED_100_PLUS_32KB
    global SKIPPABLE_FRAME
    global THIS_FILE_BYTES, THIS_FILE_STR, COMPRESSED_THIS_FILE
    global COMPRESSED_BOGUS
    global SAMPLES, TRAINED_DICT

    # uncompressed size 130KB, more than a zstd block.
    # with a frame epilogue, 4 bytes checksum.
    random_bytes = [random.randint(0, 127) for _ in range(130*_1K)]
    DAT_130K_D = bytes(random_bytes)
    checksum_options = {CompressionParameter.checksum_flag: 1}
    DAT_130K_C = compress(DAT_130K_D, options=checksum_options)

    DECOMPRESSED_DAT = b'abcdefg123456' * 1000
    COMPRESSED_DAT = compress(DECOMPRESSED_DAT)

    DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K)
    COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB)

    # Two little-endian 4-byte fields (0x184D2A50, then the payload
    # length) followed by the payload itself.
    payload = b'a' * (32*_1K)
    SKIPPABLE_FRAME = b''.join([
        (0x184D2A50).to_bytes(4, byteorder='little'),
        (32*_1K).to_bytes(4, byteorder='little'),
        payload,
    ])

    # This source file, with line endings normalised to '\n'.
    this_path = os.path.abspath(__file__)
    with io.open(this_path, 'rb') as f:
        raw = f.read()
    THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', raw)
    THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8')
    COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES)

    # Deliberately *not* compressed data.
    COMPRESSED_BOGUS = DECOMPRESSED_DAT

    # dict data: 300 samples of 20 "<word> = <number>" lines each
    words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue',
             b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive',
             b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird']
    SAMPLES = [
        b'\n'.join(
            b'%s = %d' % (random.choice(words), random.randrange(100))
            for _ in range(20)
        )
        for _ in range(300)
    ]
    assert len(SAMPLES) > 10

    TRAINED_DICT = train_dict(SAMPLES, 3*_1K)
    assert len(TRAINED_DICT.dict_content) <= 3*_1K
| |
| |
class FunctionsTestCase(unittest.TestCase):
    """Tests for the module-level functions and constants of the zstd API."""

    def test_version(self):
        # zstd_version must be the dotted rendering of zstd_version_info.
        s = ".".join(map(str, zstd_version_info))
        self.assertEqual(s, zstd_version)

    def test_compressionLevel_values(self):
        # Named level_min/level_max (not min/max) so the builtins are not
        # shadowed inside this test.
        level_min, level_max = CompressionParameter.compression_level.bounds()
        self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int)
        self.assertIs(type(level_min), int)
        self.assertIs(type(level_max), int)
        self.assertLess(level_min, level_max)

    def test_roundtrip_default(self):
        # Only the first sixth of this file is used as input.
        raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
        dat1 = compress(raw_dat)
        dat2 = decompress(dat1)
        self.assertEqual(dat2, raw_dat)

    def test_roundtrip_level(self):
        raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
        level_min, level_max = CompressionParameter.compression_level.bounds()

        # The lower end is clamped at -20 so that a very negative level_min
        # does not make the loop excessively long.
        for level in range(max(-20, level_min), level_max + 1):
            # subTest reports which level failed and keeps testing the rest.
            with self.subTest(level=level):
                dat1 = compress(raw_dat, level)
                dat2 = decompress(dat1)
                self.assertEqual(dat2, raw_dat)

    def test_get_frame_info(self):
        # no dict; the first 20 bytes of the frame are enough here
        info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20])
        self.assertEqual(info.decompressed_size, 32 * _1K + 100)
        self.assertEqual(info.dictionary_id, 0)

        # use dict
        dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT)
        info = get_frame_info(dat)
        self.assertEqual(info.decompressed_size, 345)
        self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id)

        # data that is not a zstd frame header
        with self.assertRaisesRegex(ZstdError, "not less than the frame header"):
            get_frame_info(b"aaaaaaaaaaaaaa")

    def test_get_frame_size(self):
        size = get_frame_size(COMPRESSED_100_PLUS_32KB)
        self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB))

        # data that is not a complete zstd frame
        with self.assertRaisesRegex(ZstdError, "not less than this complete frame"):
            get_frame_size(b"aaaaaaaaaaaaaa")

    def test_decompress_2x130_1K(self):
        decompressed_size = get_frame_info(DAT_130K_C).decompressed_size
        self.assertEqual(decompressed_size, _130_1K)

        # Two concatenated frames decompress to the concatenated payloads.
        dat = decompress(DAT_130K_C + DAT_130K_C)
        self.assertEqual(len(dat), 2 * _130_1K)
| |
| |
class CompressorTestCase(unittest.TestCase):
    """Tests for ZstdCompressor: arguments, parameters and flush modes."""

    def test_simple_compress_bad_args(self):
        # ZstdCompressor
        self.assertRaises(TypeError, ZstdCompressor, [])
        self.assertRaises(TypeError, ZstdCompressor, level=3.14)
        self.assertRaises(TypeError, ZstdCompressor, level="abc")
        self.assertRaises(TypeError, ZstdCompressor, options=b"abc")

        self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123)
        self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
        self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})

        # valid range for compression level is [-(1<<17), 22]
        msg = r'illegal compression level {}; the valid range is \[-?\d+, -?\d+\]'
        with self.assertRaisesRegex(ValueError, msg.format(C_INT_MAX)):
            ZstdCompressor(C_INT_MAX)
        with self.assertRaisesRegex(ValueError, msg.format(C_INT_MIN)):
            ZstdCompressor(C_INT_MIN)
        # Values that do not fit a C int: the message carries no number.
        msg = r'illegal compression level; the valid range is \[-?\d+, -?\d+\]'
        with self.assertRaisesRegex(ValueError, msg):
            ZstdCompressor(level=-(2**1000))
        with self.assertRaisesRegex(ValueError, msg):
            ZstdCompressor(level=2**1000)

        with self.assertRaises(ValueError):
            ZstdCompressor(options={CompressionParameter.window_log: 100})
        with self.assertRaises(ValueError):
            ZstdCompressor(options={3333: 100})

        # Method bad arguments
        zc = ZstdCompressor()
        self.assertRaises(TypeError, zc.compress)
        self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar")
        self.assertRaises(TypeError, zc.compress, "str")
        self.assertRaises((TypeError, ValueError), zc.flush, b"foo")
        self.assertRaises(TypeError, zc.flush, b"blah", 1)

        self.assertRaises(ValueError, zc.compress, b'', -1)
        self.assertRaises(ValueError, zc.compress, b'', 3)
        self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0
        self.assertRaises(ValueError, zc.flush, 3)

        # Every valid mode must be accepted; the results are not inspected.
        zc.compress(b'')
        zc.compress(b'', zc.CONTINUE)
        zc.compress(b'', zc.FLUSH_BLOCK)
        zc.compress(b'', zc.FLUSH_FRAME)
        zc.flush()
        zc.flush(zc.FLUSH_BLOCK)
        zc.flush(zc.FLUSH_FRAME)

    def test_compress_parameters(self):
        d = {CompressionParameter.compression_level : 10,

             CompressionParameter.window_log : 12,
             CompressionParameter.hash_log : 10,
             CompressionParameter.chain_log : 12,
             CompressionParameter.search_log : 12,
             CompressionParameter.min_match : 4,
             CompressionParameter.target_length : 12,
             CompressionParameter.strategy : Strategy.lazy,

             CompressionParameter.enable_long_distance_matching : 1,
             CompressionParameter.ldm_hash_log : 12,
             CompressionParameter.ldm_min_match : 11,
             CompressionParameter.ldm_bucket_size_log : 5,
             CompressionParameter.ldm_hash_rate_log : 12,

             CompressionParameter.content_size_flag : 1,
             CompressionParameter.checksum_flag : 1,
             CompressionParameter.dict_id_flag : 0,

             CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0,
             CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0,
             CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0,
             }
        ZstdCompressor(options=d)

        d1 = d.copy()
        # largest signed int
        d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MAX
        with self.assertRaises(ValueError):
            ZstdCompressor(options=d1)
        # smallest signed int
        d1[CompressionParameter.ldm_bucket_size_log] = C_INT_MIN
        with self.assertRaises(ValueError):
            ZstdCompressor(options=d1)

        # out of bounds compression level
        level_min, level_max = CompressionParameter.compression_level.bounds()
        with self.assertRaises(ValueError):
            compress(b'', level_max+1)
        with self.assertRaises(ValueError):
            compress(b'', level_min-1)
        with self.assertRaises(ValueError):
            compress(b'', 2**1000)
        with self.assertRaises(ValueError):
            compress(b'', -(2**1000))
        with self.assertRaises(ValueError):
            compress(b'', options={
                CompressionParameter.compression_level: level_max+1})
        with self.assertRaises(ValueError):
            compress(b'', options={
                CompressionParameter.compression_level: level_min-1})

        # zstd lib doesn't support MT compression
        if not SUPPORT_MULTITHREADING:
            with self.assertRaises(ValueError):
                ZstdCompressor(options={CompressionParameter.nb_workers:4})
            with self.assertRaises(ValueError):
                ZstdCompressor(options={CompressionParameter.job_size:4})
            with self.assertRaises(ValueError):
                ZstdCompressor(options={CompressionParameter.overlap_log:4})

        # out of bounds error msg
        option = {CompressionParameter.window_log:100}
        with self.assertRaisesRegex(
            ValueError,
            "compression parameter 'window_log' received an illegal value 100; "
            r'the valid range is \[-?\d+, -?\d+\]',
        ):
            compress(b'', options=option)

    def test_unknown_compression_parameter(self):
        KEY = 100001234
        option = {CompressionParameter.compression_level: 10,
                  KEY: 200000000}
        pattern = rf"invalid compression parameter 'unknown parameter \(key {KEY}\)'"
        with self.assertRaisesRegex(ValueError, pattern):
            ZstdCompressor(options=option)

    @unittest.skipIf(not SUPPORT_MULTITHREADING,
                     "zstd build doesn't support multi-threaded compression")
    def test_zstd_multithread_compress(self):
        size = 40*_1M
        b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES))

        options = {CompressionParameter.compression_level : 4,
                   CompressionParameter.nb_workers : 2}

        # compress()
        dat1 = compress(b, options=options)
        dat2 = decompress(dat1)
        self.assertEqual(dat2, b)

        # ZstdCompressor
        c = ZstdCompressor(options=options)
        dat1 = c.compress(b, c.CONTINUE)
        dat2 = c.compress(b, c.FLUSH_BLOCK)
        dat3 = c.compress(b, c.FLUSH_FRAME)
        dat4 = decompress(dat1+dat2+dat3)
        self.assertEqual(dat4, b * 3)

        # ZstdFile: only checks that writing with these options works
        with ZstdFile(io.BytesIO(), 'w', options=options) as f:
            f.write(b)

    def test_compress_flushblock(self):
        point = len(THIS_FILE_BYTES) // 2

        c = ZstdCompressor()
        self.assertEqual(c.last_mode, c.FLUSH_FRAME)
        dat1 = c.compress(THIS_FILE_BYTES[:point])
        self.assertEqual(c.last_mode, c.CONTINUE)
        dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK)
        self.assertEqual(c.last_mode, c.FLUSH_BLOCK)
        dat2 = c.flush()

        # A flushed block alone is not a complete frame ...
        pattern = "Compressed data ended before the end-of-stream marker"
        with self.assertRaisesRegex(ZstdError, pattern):
            decompress(dat1)

        # ... but together with the output of flush() it is.
        dat3 = decompress(dat1 + dat2)

        self.assertEqual(dat3, THIS_FILE_BYTES)

    def test_compress_flushframe(self):
        # test compress & decompress
        point = len(THIS_FILE_BYTES) // 2

        c = ZstdCompressor()

        dat1 = c.compress(THIS_FILE_BYTES[:point])
        self.assertEqual(c.last_mode, c.CONTINUE)

        dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME)
        self.assertEqual(c.last_mode, c.FLUSH_FRAME)

        nt = get_frame_info(dat1)
        self.assertEqual(nt.decompressed_size, None)  # no content size

        dat2 = decompress(dat1)

        self.assertEqual(dat2, THIS_FILE_BYTES)

        # single .FLUSH_FRAME mode has content size
        c = ZstdCompressor()
        dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME)
        self.assertEqual(c.last_mode, c.FLUSH_FRAME)

        nt = get_frame_info(dat)
        self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES))

    def test_compress_empty(self):
        # output empty content frame
        self.assertNotEqual(compress(b''), b'')

        c = ZstdCompressor()
        self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')

    def test_set_pledged_input_size(self):
        DAT = DECOMPRESSED_100_PLUS_32KB
        CHUNK_SIZE = len(DAT) // 3

        def feed_in_chunks(compressor):
            # Pass all of DAT to compressor.compress() in CHUNK_SIZE pieces
            # and return the list of outputs.  flush() is not called here.
            return [compressor.compress(DAT[start:start + CHUNK_SIZE])
                    for start in range(0, len(DAT), CHUNK_SIZE)]

        # wrong values
        c = ZstdCompressor()
        invalid_sizes = (
            -300,
            2**64,      # overflow
            2**64 - 2,  # ZSTD_CONTENTSIZE_ERROR is invalid
            2**64 - 1,  # ZSTD_CONTENTSIZE_UNKNOWN should use None
        )
        for size in invalid_sizes:
            with self.subTest(size=size):
                with self.assertRaisesRegex(
                        ValueError,
                        r'should be a positive int less than \d+'):
                    c.set_pledged_input_size(size)

        # check valid values are settable
        c.set_pledged_input_size(2**63)
        c.set_pledged_input_size(2**64-3)

        # check that zero means empty frame
        c = ZstdCompressor(level=1)
        c.set_pledged_input_size(0)
        c.compress(b'')
        dat = c.flush()
        ret = get_frame_info(dat)
        self.assertEqual(ret.decompressed_size, 0)

        # wrong mode
        c = ZstdCompressor(level=1)
        c.compress(b'123456')
        self.assertEqual(c.last_mode, c.CONTINUE)
        with self.assertRaisesRegex(ValueError,
                                    r'last_mode == FLUSH_FRAME'):
            c.set_pledged_input_size(300)

        # None value
        c = ZstdCompressor(level=1)
        c.set_pledged_input_size(None)
        dat = c.compress(DAT) + c.flush()

        ret = get_frame_info(dat)
        self.assertEqual(ret.decompressed_size, None)

        # correct value
        c = ZstdCompressor(level=1)
        c.set_pledged_input_size(len(DAT))

        chunks = feed_in_chunks(c)
        chunks.append(c.flush())
        chunks = b''.join(chunks)

        ret = get_frame_info(chunks)
        self.assertEqual(ret.decompressed_size, len(DAT))
        self.assertEqual(decompress(chunks), DAT)

        c.set_pledged_input_size(len(DAT))  # the second frame
        dat = c.compress(DAT) + c.flush()

        ret = get_frame_info(dat)
        self.assertEqual(ret.decompressed_size, len(DAT))
        self.assertEqual(decompress(dat), DAT)

        # not enough data
        c = ZstdCompressor(level=1)
        c.set_pledged_input_size(len(DAT)+1)
        feed_in_chunks(c)

        with self.assertRaises(ZstdError):
            c.flush()

        # too much data
        c = ZstdCompressor(level=1)
        c.set_pledged_input_size(len(DAT))
        feed_in_chunks(c)

        with self.assertRaises(ZstdError):
            c.compress(b'extra', ZstdCompressor.FLUSH_FRAME)

        # content size not set if content_size_flag == 0
        c = ZstdCompressor(options={CompressionParameter.content_size_flag: 0})
        c.set_pledged_input_size(10)
        dat1 = c.compress(b"hello")
        dat2 = c.compress(b"world")
        dat3 = c.flush()
        frame_data = get_frame_info(dat1 + dat2 + dat3)
        self.assertIsNone(frame_data.decompressed_size)
| |
| |
class DecompressorTestCase(unittest.TestCase):
    """Tests for ZstdDecompressor: arguments, parameters and streaming state."""

    def test_simple_decompress_bad_args(self):
        # ZstdDecompressor
        self.assertRaises(TypeError, ZstdDecompressor, ())
        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})

        self.assertRaises(TypeError, ZstdDecompressor, options=123)
        self.assertRaises(TypeError, ZstdDecompressor, options='abc')
        self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')

        # Keys/values that fit a C int but are invalid give ValueError ...
        with self.assertRaises(ValueError):
            ZstdDecompressor(options={C_INT_MAX: 100})
        with self.assertRaises(ValueError):
            ZstdDecompressor(options={C_INT_MIN: 100})
        with self.assertRaises(ValueError):
            ZstdDecompressor(options={0: C_INT_MAX})
        # ... while keys/values that do not fit give OverflowError.
        with self.assertRaises(OverflowError):
            ZstdDecompressor(options={2**1000: 100})
        with self.assertRaises(OverflowError):
            ZstdDecompressor(options={-(2**1000): 100})
        with self.assertRaises(OverflowError):
            ZstdDecompressor(options={0: -(2**1000)})

        with self.assertRaises(ValueError):
            ZstdDecompressor(options={DecompressionParameter.window_log_max: 100})
        with self.assertRaises(ValueError):
            ZstdDecompressor(options={3333: 100})

        # Method bad arguments
        empty = compress(b'')
        lzd = ZstdDecompressor()
        self.assertRaises(TypeError, lzd.decompress)
        self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
        self.assertRaises(TypeError, lzd.decompress, "str")
        lzd.decompress(empty)

    def test_decompress_parameters(self):
        d = {DecompressionParameter.window_log_max : 15}
        ZstdDecompressor(options=d)

        d1 = d.copy()
        # far larger than any signed int
        d1[DecompressionParameter.window_log_max] = 2**1000
        with self.assertRaises(OverflowError):
            ZstdDecompressor(None, d1)
        # far smaller than any signed int
        d1[DecompressionParameter.window_log_max] = -(2**1000)
        with self.assertRaises(OverflowError):
            ZstdDecompressor(None, d1)

        # fits a signed int, but out of bounds for the parameter
        d1[DecompressionParameter.window_log_max] = C_INT_MAX
        with self.assertRaises(ValueError):
            ZstdDecompressor(None, d1)
        d1[DecompressionParameter.window_log_max] = C_INT_MIN
        with self.assertRaises(ValueError):
            ZstdDecompressor(None, d1)

        # out of bounds error msg
        options = {DecompressionParameter.window_log_max:100}
        with self.assertRaisesRegex(
            ValueError,
            "decompression parameter 'window_log_max' received an illegal value 100; "
            r'the valid range is \[-?\d+, -?\d+\]',
        ):
            decompress(b'', options=options)

        # out of bounds decompression parameter
        options[DecompressionParameter.window_log_max] = C_INT_MAX
        with self.assertRaises(ValueError):
            decompress(b'', options=options)
        options[DecompressionParameter.window_log_max] = C_INT_MIN
        with self.assertRaises(ValueError):
            decompress(b'', options=options)
        options[DecompressionParameter.window_log_max] = 2**1000
        with self.assertRaises(OverflowError):
            decompress(b'', options=options)
        options[DecompressionParameter.window_log_max] = -(2**1000)
        with self.assertRaises(OverflowError):
            decompress(b'', options=options)

    def test_unknown_decompression_parameter(self):
        KEY = 100001234
        options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
                   KEY: 200000000}
        pattern = rf"invalid decompression parameter 'unknown parameter \(key {KEY}\)'"
        with self.assertRaisesRegex(ValueError, pattern):
            ZstdDecompressor(options=options)

    def test_decompress_epilogue_flags(self):
        # DAT_130K_C has a 4 bytes checksum at frame epilogue

        # full unlimited
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C)
        self.assertEqual(len(dat), _130_1K)
        self.assertFalse(d.needs_input)

        with self.assertRaises(EOFError):
            dat = d.decompress(b'')

        # full limited
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C, _130_1K)
        self.assertEqual(len(dat), _130_1K)
        self.assertFalse(d.needs_input)

        with self.assertRaises(EOFError):
            dat = d.decompress(b'', 0)

        # Frame with the last 4, 3 or 1 bytes cut off.  All the content is
        # still produced, but the frame never finishes: an unlimited
        # decompressor keeps asking for input, an output-limited one does not.
        for cut in (4, 3, 1):
            truncated = DAT_130K_C[:-cut]

            # [:-cut] unlimited
            with self.subTest(cut=cut, limited=False):
                d = ZstdDecompressor()
                dat = d.decompress(truncated)
                self.assertEqual(len(dat), _130_1K)
                self.assertTrue(d.needs_input)

                dat = d.decompress(b'')
                self.assertEqual(len(dat), 0)
                self.assertTrue(d.needs_input)

            # [:-cut] limited
            with self.subTest(cut=cut, limited=True):
                d = ZstdDecompressor()
                dat = d.decompress(truncated, _130_1K)
                self.assertEqual(len(dat), _130_1K)
                self.assertFalse(d.needs_input)

                dat = d.decompress(b'', 0)
                self.assertEqual(len(dat), 0)
                self.assertFalse(d.needs_input)

    def test_decompressor_arg(self):
        zd = ZstdDict(b'12345678', is_raw=True)

        with self.assertRaises(TypeError):
            ZstdDecompressor(zstd_dict={})

        with self.assertRaises(TypeError):
            ZstdDecompressor(options=zd)

        # Valid ways of passing the arguments.
        ZstdDecompressor()
        ZstdDecompressor(zd, {})
        ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25})

    def test_decompressor_1(self):
        # empty
        d = ZstdDecompressor()
        dat = d.decompress(b'')

        self.assertEqual(dat, b'')
        self.assertFalse(d.eof)

        # 130_1K full
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C)

        self.assertEqual(len(dat), _130_1K)
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)

        # 130_1K full, limit output
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C, _130_1K)

        self.assertEqual(len(dat), _130_1K)
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)

        # 130_1K, without 4 bytes checksum
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C[:-4])

        self.assertEqual(len(dat), _130_1K)
        self.assertFalse(d.eof)
        self.assertTrue(d.needs_input)

        # above, limit output
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C[:-4], _130_1K)

        self.assertEqual(len(dat), _130_1K)
        self.assertFalse(d.eof)
        self.assertFalse(d.needs_input)

        # full, unused_data
        TRAIL = b'89234893abcd'
        d = ZstdDecompressor()
        dat = d.decompress(DAT_130K_C + TRAIL, _130_1K)

        self.assertEqual(len(dat), _130_1K)
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, TRAIL)

    def test_decompressor_chunks_read_300(self):
        TRAIL = b'89234893abcd'
        DAT = DAT_130K_C + TRAIL
        d = ZstdDecompressor()

        bi = io.BytesIO(DAT)
        lst = []
        while True:
            # The decompressor is expected to ask for more input on every
            # iteration; report anything else as a test failure.
            if not d.needs_input:
                self.fail('should not get here')

            dat = bi.read(300)
            if not dat:
                break

            ret = d.decompress(dat)
            lst.append(ret)
            if d.eof:
                break

        ret = b''.join(lst)

        self.assertEqual(len(ret), _130_1K)
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        # Whatever was not consumed is either in unused_data or still unread.
        self.assertEqual(d.unused_data + bi.read(), TRAIL)

    def test_decompressor_chunks_read_3(self):
        TRAIL = b'89234893'
        DAT = DAT_130K_C + TRAIL
        d = ZstdDecompressor()

        bi = io.BytesIO(DAT)
        lst = []
        while True:
            if d.needs_input:
                dat = bi.read(3)
                if not dat:
                    break
            else:
                # Output is limited to 1 byte per call, so call again with
                # empty input until the decompressor asks for more.
                dat = b''

            ret = d.decompress(dat, 1)
            lst.append(ret)
            if d.eof:
                break

        ret = b''.join(lst)

        self.assertEqual(len(ret), _130_1K)
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data + bi.read(), TRAIL)

    def test_decompress_empty(self):
        # The one-shot function rejects empty input ...
        with self.assertRaises(ZstdError):
            decompress(b'')

        # ... while the streaming object returns nothing and has not
        # reached the end of a frame.
        d = ZstdDecompressor()
        self.assertEqual(d.decompress(b''), b'')
        self.assertFalse(d.eof)

    def test_decompress_empty_content_frame(self):
        DAT = compress(b'')
        # decompress
        self.assertGreaterEqual(len(DAT), 4)
        self.assertEqual(decompress(DAT), b'')

        with self.assertRaises(ZstdError):
            decompress(DAT[:-1])

        # ZstdDecompressor
        d = ZstdDecompressor()
        dat = d.decompress(DAT)
        self.assertEqual(dat, b'')
        self.assertTrue(d.eof)
        self.assertFalse(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'')  # twice

        # same frame with its last byte missing
        d = ZstdDecompressor()
        dat = d.decompress(DAT[:-1])
        self.assertEqual(dat, b'')
        self.assertFalse(d.eof)
        self.assertTrue(d.needs_input)
        self.assertEqual(d.unused_data, b'')
        self.assertEqual(d.unused_data, b'')  # twice
| |
| class DecompressorFlagsTestCase(unittest.TestCase): |
| |
@classmethod
def setUpClass(cls):
    """Create the frames and payloads shared by the tests of this class."""
    payload_42 = b'a'*42
    payload_60 = b'a'*60

    # Frames written with FLUSH_FRAME and the checksum flag enabled; both
    # come from the same compressor object, one after the other.
    checksum_compressor = ZstdCompressor(
        options={CompressionParameter.checksum_flag:1})
    cls.DECOMPRESSED_42 = payload_42
    cls.FRAME_42 = checksum_compressor.compress(
        payload_42, checksum_compressor.FLUSH_FRAME)

    cls.DECOMPRESSED_60 = payload_60
    cls.FRAME_60 = checksum_compressor.compress(
        payload_60, checksum_compressor.FLUSH_FRAME)

    cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60
    cls.DECOMPRESSED_42_60 = payload_42 + payload_60

    cls._130_1K = 130*_1K

    # Frames written with compress() followed by flush(), using a default
    # compressor.
    plain_compressor = ZstdCompressor()
    cls.UNKNOWN_FRAME_42 = (plain_compressor.compress(payload_42)
                            + plain_compressor.flush())
    cls.UNKNOWN_FRAME_60 = (plain_compressor.compress(payload_60)
                            + plain_compressor.flush())
    cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60

    cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'
| |
def test_function_decompress(self):
    # Exercise the one-shot decompress() function on complete frames,
    # truncated frames, garbage input and a corrupted checksum.
    eos_pattern = r"Compressed data ended before the end-of-stream marker"
    garbage = b'aaaaaaaaa'

    self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K)

    # 1 frame
    for frame in (self.FRAME_42, self.UNKNOWN_FRAME_42):
        self.assertEqual(decompress(frame), self.DECOMPRESSED_42)

    for truncated in (self.FRAME_42[:1],
                      self.FRAME_42[:-4],
                      self.FRAME_42[:-1]):
        with self.assertRaisesRegex(ZstdError, eos_pattern):
            decompress(truncated)

    # 2 frames
    for frames in (self.FRAME_42_60,
                   self.UNKNOWN_FRAME_42_60,
                   self.FRAME_42 + self.UNKNOWN_FRAME_60,
                   self.UNKNOWN_FRAME_42 + self.FRAME_60):
        self.assertEqual(decompress(frames), self.DECOMPRESSED_42_60)

    for truncated in (self.FRAME_42_60[:-4],
                      self.UNKNOWN_FRAME_42_60[:-1]):
        with self.assertRaisesRegex(ZstdError, eos_pattern):
            decompress(truncated)

    # 130_1K
    self.assertEqual(decompress(DAT_130K_C), DAT_130K_D)

    for truncated in (DAT_130K_C[:-4], DAT_130K_C[:-1]):
        with self.assertRaisesRegex(ZstdError, eos_pattern):
            decompress(truncated)

    # Unknown frame descriptor
    for bogus in (garbage,
                  self.FRAME_42 + garbage,
                  self.UNKNOWN_FRAME_42_60 + garbage):
        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
            decompress(bogus)

    # doesn't match checksum: alter the first of the 4 trailing bytes
    checksum = DAT_130K_C[-4:]
    first = checksum[0]
    altered = 254 if first == 255 else first + 1
    wrong_checksum = bytes([altered]) + checksum[1:]

    dat = DAT_130K_C[:-4] + wrong_checksum

    with self.assertRaisesRegex(ZstdError, "doesn't match checksum"):
        decompress(dat)
| |
def test_function_skippable(self):
    # decompress() on inputs that mix SKIPPABLE_FRAME with regular frames.
    skip = SKIPPABLE_FRAME
    garbage = b'aaaaaaaaa'

    # skippable frames alone produce no output
    for data in (skip, skip + skip):
        self.assertEqual(decompress(data), b'')

    # 1 frame + 2 skippable
    for data in (skip + skip + DAT_130K_C,
                 DAT_130K_C + skip + skip,
                 skip + DAT_130K_C + skip):
        self.assertEqual(len(decompress(data)), self._130_1K)

    # unknown size
    for data in (skip + self.UNKNOWN_FRAME_60,
                 self.UNKNOWN_FRAME_60 + skip):
        self.assertEqual(decompress(data), self.DECOMPRESSED_60)

    # 2 frames + 1 skippable
    for data in (self.FRAME_42 + skip + self.FRAME_60,
                 skip + self.FRAME_42_60,
                 self.UNKNOWN_FRAME_42_60 + skip):
        self.assertEqual(decompress(data), self.DECOMPRESSED_42_60)

    # incomplete
    for data in (skip[:1],
                 skip[:-1],
                 self.FRAME_42 + skip[:-1]):
        with self.assertRaises(ZstdError):
            decompress(data)

    # Unknown frame descriptor
    for data in (garbage + skip,
                 skip + garbage,
                 skip + skip + garbage):
        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
            decompress(data)
| |
def test_decompressor_1(self):
    # Drive ZstdDecompressor.decompress() through empty input, complete
    # frames (with and without trailing bytes), output-limited calls and
    # truncated frames, checking eof / needs_input / unused_data after
    # every step.

    def check_state(dec, *, eof, needs_input, unused=b'', reads=2):
        (self.assertTrue if eof else self.assertFalse)(dec.eof)
        (self.assertTrue if needs_input else self.assertFalse)(dec.needs_input)
        # unused_data is normally read twice so that repeated attribute
        # access is exercised as well.
        for _ in range(reads):
            self.assertEqual(dec.unused_data, unused)

    # empty 1 / empty 2: the same two empty-input calls in both orders,
    # followed by one complete frame plus a single trailing byte.
    unlimited = (b'',)
    zero_length = (b'', 0)
    for call_order in ((unlimited, zero_length), (zero_length, unlimited)):
        dec = ZstdDecompressor()
        for call_args in call_order:
            self.assertEqual(dec.decompress(*call_args), b'')
            check_state(dec, eof=False,
                        needs_input=call_args is unlimited)

        out = dec.decompress(COMPRESSED_100_PLUS_32KB + b'a')
        self.assertEqual(out, DECOMPRESSED_100_PLUS_32KB)
        check_state(dec, eof=True, needs_input=False, unused=b'a')

    # 1 frame
    dec = ZstdDecompressor()
    out = dec.decompress(self.FRAME_42)
    self.assertEqual(out, self.DECOMPRESSED_42)
    check_state(dec, eof=True, needs_input=False)
    self.assertRaises(EOFError, dec.decompress, b'')

    # 1 frame, trail
    dec = ZstdDecompressor()
    out = dec.decompress(self.FRAME_42 + self.TRAIL)
    self.assertEqual(out, self.DECOMPRESSED_42)
    check_state(dec, eof=True, needs_input=False, unused=self.TRAIL)

    # 1 frame, 32_1K
    size_32k = 32*_1K
    plain_32k = b'a' * size_32k
    frame_32k = compress(plain_32k)
    dec = ZstdDecompressor()
    out = dec.decompress(frame_32k, size_32k)
    self.assertEqual(out, plain_32k)
    check_state(dec, eof=True, needs_input=False)
    self.assertRaises(EOFError, dec.decompress, b'')

    # 1 frame, 32_1K+100, trail
    dec = ZstdDecompressor()
    out = dec.decompress(COMPRESSED_100_PLUS_32KB + self.TRAIL, 100)  # 100 bytes
    self.assertEqual(len(out), 100)
    check_state(dec, eof=False, needs_input=False, reads=1)

    out = dec.decompress(b'')  # 32_1K
    self.assertEqual(len(out), size_32k)
    check_state(dec, eof=True, needs_input=False, unused=self.TRAIL)
    self.assertRaises(EOFError, dec.decompress, b'')

    # incomplete 1 (the returned data is not checked here)
    dec = ZstdDecompressor()
    dec.decompress(self.FRAME_60[:1])
    check_state(dec, eof=False, needs_input=True)

    # incomplete 2
    dec = ZstdDecompressor()
    out = dec.decompress(self.FRAME_60[:-4])
    self.assertEqual(out, self.DECOMPRESSED_60)
    check_state(dec, eof=False, needs_input=True)

    # incomplete 3
    dec = ZstdDecompressor()
    out = dec.decompress(self.FRAME_60[:-1])
    self.assertEqual(out, self.DECOMPRESSED_60)
    check_state(dec, eof=False, needs_input=True, reads=1)

    # incomplete 4
    dec = ZstdDecompressor()
    out = dec.decompress(self.FRAME_60[:-4], 60)
    self.assertEqual(out, self.DECOMPRESSED_60)
    check_state(dec, eof=False, needs_input=False)

    self.assertEqual(dec.decompress(b''), b'')
    check_state(dec, eof=False, needs_input=True)

    # Unknown frame descriptor
    dec = ZstdDecompressor()
    with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
        dec.decompress(b'aaaaaaaaa')
| |
def test_decompressor_skippable(self):
    # Feed skippable frames (whole, followed by trailing bytes, and
    # truncated) to ZstdDecompressor; no output is ever produced, only
    # the decompressor state differs between scenarios.

    def expect(dec, produced, *, eof, needs_input, unused=b''):
        self.assertEqual(produced, b'')
        (self.assertTrue if eof else self.assertFalse)(dec.eof)
        (self.assertTrue if needs_input else self.assertFalse)(dec.needs_input)
        self.assertEqual(dec.unused_data, unused)
        self.assertEqual(dec.unused_data, unused)  # twice

    incomplete = SKIPPABLE_FRAME[:-1]
    scenarios = (
        # decompress() arguments, eof, needs_input, unused_data
        ((SKIPPABLE_FRAME,), True, False, b''),      # 1 skippable
        ((SKIPPABLE_FRAME, 0), True, False, b''),    # 1 skippable, max_length=0
        ((SKIPPABLE_FRAME + self.TRAIL,),
         True, False, self.TRAIL),                   # 1 skippable, trail
        ((incomplete,), False, True, b''),           # incomplete
        ((incomplete, 0), False, False, b''),        # incomplete, max_length=0
    )
    for call_args, eof, needs_input, unused in scenarios:
        dec = ZstdDecompressor()
        expect(dec, dec.decompress(*call_args),
               eof=eof, needs_input=needs_input, unused=unused)

    # The decompressor from the last scenario (incomplete frame fed with
    # max_length=0) is called once more, this time without a limit.
    expect(dec, dec.decompress(b''), eof=False, needs_input=True)
| |
| |
| |
class ZstdDictTestCase(unittest.TestCase):
    # Tests for ZstdDict and for the dictionary helpers train_dict() and
    # finalize_dict(), including the argument checks of the private
    # _zstd.train_dict() / _zstd.finalize_dict() functions.

    def test_is_raw(self):
        # must be passed as a keyword argument
        with self.assertRaises(TypeError):
            ZstdDict(bytes(8), True)

        # content < 8
        b = b'1234567'
        with self.assertRaises(ValueError):
            ZstdDict(b)

        # content == 8
        b = b'12345678'
        zd = ZstdDict(b, is_raw=True)
        self.assertEqual(zd.dict_id, 0)

        temp = compress(b'aaa12345678', level=3, zstd_dict=zd)
        self.assertEqual(b'aaa12345678', decompress(temp, zd))

        # is_raw == False
        b = b'12345678abcd'
        with self.assertRaises(ValueError):
            ZstdDict(b)

        # read only attributes
        with self.assertRaises(AttributeError):
            zd.dict_content = b

        with self.assertRaises(AttributeError):
            zd.dict_id = 10000

        # ZstdDict arguments
        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
        self.assertNotEqual(zd.dict_id, 0)

        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True)
        self.assertNotEqual(zd.dict_id, 0)  # note this assertion

        with self.assertRaises(TypeError):
            ZstdDict("12345678abcdef", is_raw=True)
        with self.assertRaises(TypeError):
            ZstdDict(TRAINED_DICT)

        # invalid parameter
        with self.assertRaises(TypeError):
            ZstdDict(desk333=345)

    def test_invalid_dict(self):
        DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little')
        dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz'

        # corrupted
        zd = ZstdDict(dict_content, is_raw=False)
        with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?content\.$'):
            ZstdCompressor(zstd_dict=zd.as_digested_dict)
        with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?content\.$'):
            ZstdDecompressor(zd)

        # wrong type
        not_a_dict = r'should be a ZstdDict object'
        for bad in ([zd, 1], (zd, 1.0), (zd,), (zd, 1, 2), (zd, -1), (zd, 3)):
            with self.assertRaisesRegex(TypeError, not_a_dict):
                ZstdCompressor(zstd_dict=bad)
        for bad in ((zd, 2**1000), (zd, -2**1000)):
            with self.assertRaises(OverflowError):
                ZstdCompressor(zstd_dict=bad)

        # The decompressor is given the bad value both by keyword and
        # positionally.
        for bad in ([zd, 1], (zd, 1.0)):
            with self.assertRaisesRegex(TypeError, not_a_dict):
                ZstdDecompressor(zstd_dict=bad)
        for bad in ((zd,), (zd, 1, 2), (zd, -1), (zd, 3)):
            with self.assertRaisesRegex(TypeError, not_a_dict):
                ZstdDecompressor(bad)
        for bad in ((zd, 2**1000), (zd, -2**1000)):
            with self.assertRaises(OverflowError):
                ZstdDecompressor(bad)

    def test_train_dict(self):
        # A local name is used on purpose so the module-level TRAINED_DICT
        # fixture is not shadowed.
        trained = train_dict(SAMPLES, DICT_SIZE1)
        ZstdDict(trained.dict_content, is_raw=False)

        self.assertNotEqual(trained.dict_id, 0)
        self.assertGreater(len(trained.dict_content), 0)
        self.assertLessEqual(len(trained.dict_content), DICT_SIZE1)
        # assertRegex reports the offending string when the check fails.
        self.assertRegex(str(trained),
                         r'^<ZstdDict dict_id=\d+ dict_size=\d+>$')

        # compress/decompress
        c = ZstdCompressor(zstd_dict=trained)
        for sample in SAMPLES:
            dat1 = compress(sample, zstd_dict=trained)
            dat2 = decompress(dat1, trained)
            self.assertEqual(sample, dat2)

            dat1 = c.compress(sample)
            dat1 += c.flush()
            dat2 = decompress(dat1, trained)
            self.assertEqual(sample, dat2)

    def test_finalize_dict(self):
        DICT_SIZE2 = 200*_1K
        C_LEVEL = 6

        try:
            dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
        except NotImplementedError:
            # < v1.4.5 at compile-time, >= v.1.4.5 at run-time
            # Report the test as skipped instead of silently passing.
            self.skipTest("finalize_dict() raised NotImplementedError")

        self.assertNotEqual(dic2.dict_id, 0)
        self.assertGreater(len(dic2.dict_content), 0)
        self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)

        # compress/decompress
        c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
        for sample in SAMPLES:
            dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
            dat2 = decompress(dat1, dic2)
            self.assertEqual(sample, dat2)

            dat1 = c.compress(sample)
            dat1 += c.flush()
            dat2 = decompress(dat1, dic2)
            self.assertEqual(sample, dat2)

        # dict mismatch
        self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)

        dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
        with self.assertRaises(ZstdError):
            decompress(dat1, dic2)

    def test_train_dict_arguments(self):
        with self.assertRaises(ValueError):
            train_dict([], 100*_1K)

        with self.assertRaises(ValueError):
            train_dict(SAMPLES, -100)

        with self.assertRaises(ValueError):
            train_dict(SAMPLES, 0)

    def test_finalize_dict_arguments(self):
        with self.assertRaises(TypeError):
            finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)

        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, [], 100*_1K, 2)

        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, SAMPLES, -100, 2)

        with self.assertRaises(ValueError):
            finalize_dict(TRAINED_DICT, SAMPLES, 0, 2)

    def test_train_dict_c(self):
        # Each tuple is the full positional argument list passed to
        # _zstd.train_dict().

        # argument wrong type
        for args in (
            ({}, (), 100),
            (bytearray(), (), 100),
            (b'', 99, 100),
            (b'', [], 100),
            (b'', (), 100.1),
            (b'', (99.1,), 100),
        ):
            with self.assertRaises(TypeError):
                _zstd.train_dict(*args)

        for args in (
            (b'abc', (4, -1), 100),
            (b'abc', (2,), 100),
            (b'', (99,), 100),
        ):
            with self.assertRaises(ValueError):
                _zstd.train_dict(*args)

        # size > size_t
        for args in (
            (b'', (2**1000,), 100),
            (b'', (-2**1000,), 100),
        ):
            with self.assertRaises(ValueError):
                _zstd.train_dict(*args)

        # dict_size <= 0
        for args in (
            (b'', (), 0),
            (b'', (), -1),
        ):
            with self.assertRaises(ValueError):
                _zstd.train_dict(*args)

        with self.assertRaises(ZstdError):
            _zstd.train_dict(b'', (), 1)

    def test_finalize_dict_c(self):
        # Each tuple is the full positional argument list passed to
        # _zstd.finalize_dict().
        content = TRAINED_DICT.dict_content

        with self.assertRaises(TypeError):
            _zstd.finalize_dict(1, 2, 3, 4, 5)

        # argument wrong type
        for args in (
            ({}, b'', (), 100, 5),
            (bytearray(content), b'', (), 100, 5),
            (content, {}, (), 100, 5),
            (content, bytearray(), (), 100, 5),
            (content, b'', 99, 100, 5),
            (content, b'', [], 100, 5),
            (content, b'', (), 100.1, 5),
            (content, b'', (), 100, 5.1),
        ):
            with self.assertRaises(TypeError):
                _zstd.finalize_dict(*args)

        for args in (
            (content, b'abc', (4, -1), 100, 5),
            (content, b'abc', (2,), 100, 5),
            (content, b'', (99,), 100, 5),
        ):
            with self.assertRaises(ValueError):
                _zstd.finalize_dict(*args)

        # size > size_t
        for args in (
            (content, b'', (2**1000,), 100, 5),
            (content, b'', (-2**1000,), 100, 5),
        ):
            with self.assertRaises(ValueError):
                _zstd.finalize_dict(*args)

        # dict_size <= 0
        for args in (
            (content, b'', (), 0, 5),
            (content, b'', (), -1, 5),
        ):
            with self.assertRaises(ValueError):
                _zstd.finalize_dict(*args)

        # dict_size / level that do not fit the C integer type
        for args in (
            (content, b'', (), 2**1000, 5),
            (content, b'', (), -2**1000, 5),
            (content, b'', (), 100, 2**1000),
            (content, b'', (), 100, -2**1000),
        ):
            with self.assertRaises(OverflowError):
                _zstd.finalize_dict(*args)

        with self.assertRaises(ZstdError):
            _zstd.finalize_dict(content, b'', (), 100, 5)

    def test_train_buffer_protocol_samples(self):
        def _nbytes(dat):
            if isinstance(dat, (bytes, bytearray)):
                return len(dat)
            return memoryview(dat).nbytes

        # prepare samples
        chunk_lst = []
        wrong_size_lst = []
        correct_size_lst = []
        for _ in range(300):
            arr = array.array('Q', [random.randint(0, 20) for _ in range(20)])
            chunk_lst.append(arr)
            correct_size_lst.append(_nbytes(arr))
            # len() counts items, not bytes, so these sizes are wrong.
            wrong_size_lst.append(len(arr))
        concatenation = b''.join(chunk_lst)

        # wrong size list
        with self.assertRaisesRegex(ValueError,
                "The samples size tuple doesn't match the concatenation's size"):
            _zstd.train_dict(concatenation, tuple(wrong_size_lst), 100*_1K)

        # correct size list
        _zstd.train_dict(concatenation, tuple(correct_size_lst), 3*_1K)

        # wrong size list
        with self.assertRaisesRegex(ValueError,
                "The samples size tuple doesn't match the concatenation's size"):
            _zstd.finalize_dict(TRAINED_DICT.dict_content,
                                concatenation, tuple(wrong_size_lst), 300*_1K, 5)

        # correct size list
        _zstd.finalize_dict(TRAINED_DICT.dict_content,
                            concatenation, tuple(correct_size_lst), 300*_1K, 5)

    def test_as_prefix(self):
        # V1
        V1 = THIS_FILE_BYTES
        zd = ZstdDict(V1, is_raw=True)

        # V2: V1 with the middle byte replaced by a different one
        mid = len(V1) // 2
        V2 = V1[:mid] + \
             (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \
             V1[mid+1:]

        # compress
        dat = compress(V2, zstd_dict=zd.as_prefix)
        self.assertEqual(get_frame_info(dat).dictionary_id, 0)

        # decompress
        self.assertEqual(decompress(dat, zd.as_prefix), V2)

        # use wrong prefix
        zd2 = ZstdDict(SAMPLES[0], is_raw=True)
        try:
            decompressed = decompress(dat, zd2.as_prefix)
        except ZstdError:  # expected
            pass
        else:
            # If no error is raised, the output must at least be wrong.
            self.assertNotEqual(decompressed, V2)

        # read only attribute
        with self.assertRaises(AttributeError):
            zd.as_prefix = b'1234'

    def test_as_digested_dict(self):
        zd = TRAINED_DICT

        # test .as_digested_dict
        dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict)
        self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0])
        with self.assertRaises(AttributeError):
            zd.as_digested_dict = b'1234'

        # test .as_undigested_dict
        dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict)
        self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0])
        with self.assertRaises(AttributeError):
            zd.as_undigested_dict = b'1234'

    def test_advanced_compression_parameters(self):
        options = {CompressionParameter.compression_level: 6,
                   CompressionParameter.window_log: 20,
                   CompressionParameter.enable_long_distance_matching: 1}

        # automatically select
        dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT)
        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])

        # explicitly select
        dat = compress(SAMPLES[0], options=options,
                       zstd_dict=TRAINED_DICT.as_digested_dict)
        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])

    def test_len(self):
        self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content))
        self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT))
| |
| class FileTestCase(unittest.TestCase): |
def setUp(self):
    # A 42-byte payload and its compressed frame, available to every test.
    payload = b'a' * 42
    self.DECOMPRESSED_42 = payload
    self.FRAME_42 = compress(payload)
| |
def test_init(self):
    # Constructing (and closing) a ZstdFile must succeed for each of the
    # mode / keyword combinations below.
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)):
        pass
    for mode in ("w", "x", "a"):
        with ZstdFile(io.BytesIO(), mode):
            pass

    writer_kwargs = (
        {"level": 12},
        {"options": {CompressionParameter.checksum_flag: 1}},
        {"options": {}},
        {"level": 20, "zstd_dict": TRAINED_DICT},
    )
    for kwargs in writer_kwargs:
        with ZstdFile(io.BytesIO(), "w", **kwargs):
            pass

    reader_kwargs = (
        {"options": {DecompressionParameter.window_log_max: 25}},
        {"options": {}, "zstd_dict": TRAINED_DICT},
    )
    for kwargs in reader_kwargs:
        with ZstdFile(io.BytesIO(), "r", **kwargs):
            pass
| |
def test_init_with_PathLike_filename(self):
    # Open a file given as a pathlib.Path. Each "a" (append) pass writes
    # the payload once more, and a subsequent read must return everything
    # written so far.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        filename = pathlib.Path(tmp_f.name)
    # Registered up front so the temporary file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, filename)

    with ZstdFile(filename, "a") as f:
        f.write(DECOMPRESSED_100_PLUS_32KB)
    with ZstdFile(filename) as f:
        self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)

    with ZstdFile(filename, "a") as f:
        f.write(DECOMPRESSED_100_PLUS_32KB)
    with ZstdFile(filename) as f:
        self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)
| |
def test_init_with_filename(self):
    # Open an existing file by its plain str path in the default (read),
    # "w" and "a" modes.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        # Keep the name as a str: pathlib.Path arguments are already
        # covered by test_init_with_PathLike_filename.
        filename = tmp_f.name
    # Registered up front so the temporary file is removed even when one
    # of the opens below fails.
    self.addCleanup(os.remove, filename)

    with ZstdFile(filename):
        pass
    with ZstdFile(filename, "w"):
        pass
    with ZstdFile(filename, "a"):
        pass
| |
def test_init_mode(self):
    # Every accepted mode string must open against the same BytesIO
    # object, which is reused across all of them.
    shared = io.BytesIO()
    for mode in ("r", "rb", "w", "wb", "a", "ab"):
        with ZstdFile(shared, mode):
            pass
| |
def test_init_with_x_mode(self):
    # "x" / "xb" must create a new file and refuse to open an existing one.
    # The NamedTemporaryFile is only used to obtain a unique path; with
    # the default delete=True it is removed when the with block exits.
    with tempfile.NamedTemporaryFile() as tmp_f:
        filename = pathlib.Path(tmp_f.name)

    for mode in ("x", "xb"):
        try:
            with ZstdFile(filename, mode):
                pass
            with self.assertRaises(FileExistsError):
                with ZstdFile(filename, mode):
                    pass
        finally:
            # Always remove the file so a failure neither leaves it behind
            # nor hides the real error behind a missing-file error.
            filename.unlink(missing_ok=True)
| |
def test_init_bad_mode(self):
    # Mode values that ZstdFile must reject with ValueError.
    rejected_modes = ((3, "x"), "", "xt", "x+", "rx", "wx",
                      "rt", "r+", "wt", "w+", "rw")
    for mode in rejected_modes:
        with self.assertRaises(ValueError):
            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), mode)

    # Options keyed by the parameter enum that does not belong to the mode.
    self.assertRaisesRegex(
        TypeError, r"not be a CompressionParameter",
        ZstdFile, io.BytesIO(), 'rb',
        options={CompressionParameter.compression_level:5})
    self.assertRaisesRegex(
        TypeError, r"not be a DecompressionParameter",
        ZstdFile, io.BytesIO(), 'wb',
        options={DecompressionParameter.window_log_max:21})

    # level together with read mode
    self.assertRaises(
        TypeError,
        ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
| |
def test_init_bad_check(self):
    # Invalid level / options / zstd_dict arguments must make the
    # ZstdFile constructor raise.

    # level of the wrong type
    self.assertRaises(TypeError, ZstdFile, io.BytesIO(), "w", level='asd')

    # NOTE(review): 999 is presumably an unknown parameter key and 99 an
    # out-of-range window_log value -- confirm against the zstd bounds.
    self.assertRaises(ValueError,
                      ZstdFile, io.BytesIO(), "w", options={999:9999})
    self.assertRaises(ValueError,
                      ZstdFile, io.BytesIO(), "w",
                      options={CompressionParameter.window_log:99})

    # options that is not a mapping
    self.assertRaises(TypeError,
                      ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB), "r",
                      options=33)

    # option value of 2**31
    self.assertRaises(OverflowError,
                      ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB),
                      options={DecompressionParameter.window_log_max:2**31})

    self.assertRaises(ValueError,
                      ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB),
                      options={444:333})

    # zstd_dict values that are not ZstdDict objects
    self.assertRaises(TypeError,
                      ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB),
                      zstd_dict={1:2})
    self.assertRaises(TypeError,
                      ZstdFile, io.BytesIO(COMPRESSED_100_PLUS_32KB),
                      zstd_dict=b'dict123456')
| |
def test_init_close_fp(self):
    # Constructing a ZstdFile from a path with invalid options must raise
    # TypeError, and the file must still be removable afterwards.
    # NOTE(review): presumably this guards against the file opened by
    # ZstdFile staying open when __init__ fails -- confirm.

    # get a temp file name
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        tmp_f.write(DAT_130K_C)
        filename = tmp_f.name

    try:
        with self.assertRaises(TypeError):
            ZstdFile(filename, options={'a':'b'})
    finally:
        # Run the cleanup even when the assertion fails, so the temporary
        # file is not left behind.

        # for PyPy
        gc.collect()

        os.remove(filename)
| |
def test_close(self):
    # close() must leave a caller-supplied file object open, must close a
    # file that ZstdFile opened itself from a path, and must be callable
    # a second time.
    with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src:
        f = ZstdFile(src)
        f.close()
        # ZstdFile.close() should not close the underlying file object.
        self.assertFalse(src.closed)
        # Try closing an already-closed ZstdFile.
        f.close()
        self.assertFalse(src.closed)

    # Test with a real file on disk, opened directly by ZstdFile.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        filename = pathlib.Path(tmp_f.name)
    # Registered up front so the temporary file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, filename)

    f = ZstdFile(filename)
    fp = f._fp
    f.close()
    # Here, ZstdFile.close() *should* close the underlying file object.
    self.assertTrue(fp.closed)
    # Try closing an already-closed ZstdFile.
    f.close()
| |
def test_closed(self):
    # The closed attribute is false while the file is open (also after a
    # read) and true once close() has been called.

    def check(fobj, exercise=None):
        try:
            self.assertFalse(fobj.closed)
            if exercise is not None:
                exercise(fobj)
                self.assertFalse(fobj.closed)
        finally:
            fobj.close()
        self.assertTrue(fobj.closed)

    # reader: still open after reading everything
    check(ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)),
          exercise=lambda fobj: fobj.read())
    # writer
    check(ZstdFile(io.BytesIO(), "w"))
| |
def test_fileno(self):
    # fileno() is forwarded to the underlying file object and raises
    # ValueError once the ZstdFile is closed.

    # 1, underlying BytesIO
    f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    try:
        self.assertRaises(io.UnsupportedOperation, f.fileno)
    finally:
        f.close()
    self.assertRaises(ValueError, f.fileno)

    # 2, real file on disk
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        filename = pathlib.Path(tmp_f.name)
    # Registered up front so the temporary file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, filename)

    f = ZstdFile(filename)
    try:
        self.assertEqual(f.fileno(), f._fp.fileno())
        self.assertIsInstance(f.fileno(), int)
    finally:
        f.close()
    self.assertRaises(ValueError, f.fileno)

    # 3, no .fileno() method
    class C:
        def read(self, size=-1):
            return b'123'
    with ZstdFile(C(), 'rb') as f:
        with self.assertRaisesRegex(AttributeError, r'fileno'):
            f.fileno()
| |
def test_name(self):
    # The name attribute is forwarded to the underlying file object and
    # raises ValueError once the ZstdFile is closed.

    # 1, underlying BytesIO
    f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    try:
        with self.assertRaises(AttributeError):
            f.name
    finally:
        f.close()
    with self.assertRaises(ValueError):
        f.name

    # 2, real file on disk
    with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
        filename = pathlib.Path(tmp_f.name)
    # Registered up front so the temporary file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, filename)

    f = ZstdFile(filename)
    try:
        self.assertEqual(f.name, f._fp.name)
        self.assertIsInstance(f.name, str)
    finally:
        f.close()
    with self.assertRaises(ValueError):
        f.name

    # 3, no .name attribute
    class C:
        def read(self, size=-1):
            return b'123'
    with ZstdFile(C(), 'rb') as f:
        with self.assertRaisesRegex(AttributeError, r'name'):
            f.name
| |
def test_seekable(self):
    # seekable() is true for a reader over a seekable source, false for a
    # writer and for a reader whose source reports it is not seekable;
    # after close() it raises ValueError.
    reader = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    try:
        self.assertTrue(reader.seekable())
        reader.read()
        self.assertTrue(reader.seekable())
    finally:
        reader.close()
    with self.assertRaises(ValueError):
        reader.seekable()

    writer = ZstdFile(io.BytesIO(), "w")
    try:
        self.assertFalse(writer.seekable())
    finally:
        writer.close()
    with self.assertRaises(ValueError):
        writer.seekable()

    # Source object whose seekable() is overridden to return False.
    unseekable_src = io.BytesIO(COMPRESSED_100_PLUS_32KB)
    unseekable_src.seekable = lambda: False
    reader = ZstdFile(unseekable_src)
    try:
        self.assertFalse(reader.seekable())
    finally:
        reader.close()
    with self.assertRaises(ValueError):
        reader.seekable()
| |
def test_readable(self):
    # readable() is true for a reader (before and after reading), false
    # for a writer, and raises ValueError after close().
    reader = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    try:
        self.assertTrue(reader.readable())
        reader.read()
        self.assertTrue(reader.readable())
    finally:
        reader.close()
    with self.assertRaises(ValueError):
        reader.readable()

    writer = ZstdFile(io.BytesIO(), "w")
    try:
        self.assertFalse(writer.readable())
    finally:
        writer.close()
    with self.assertRaises(ValueError):
        writer.readable()
| |
def test_writable(self):
    # writable() is false for a reader (before and after reading), true
    # for a writer, and raises ValueError after close().
    reader = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    try:
        self.assertFalse(reader.writable())
        reader.read()
        self.assertFalse(reader.writable())
    finally:
        reader.close()
    with self.assertRaises(ValueError):
        reader.writable()

    writer = ZstdFile(io.BytesIO(), "w")
    try:
        self.assertTrue(writer.writable())
    finally:
        writer.close()
    with self.assertRaises(ValueError):
        writer.writable()
| |
def test_read_0(self):
    # read(0) returns b"" and does not disturb a following full read.
    empty = b""
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        self.assertEqual(f.read(0), empty)
        self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)

    limited_window = {DecompressionParameter.window_log_max:20}
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
                  options=limited_window) as f:
        self.assertEqual(f.read(0), empty)

    # empty file: read(0) works, reading real data raises EOFError,
    # whether or not read(0) was called first
    for read_zero_first in (True, False):
        with ZstdFile(io.BytesIO(b'')) as f:
            if read_zero_first:
                self.assertEqual(f.read(0), empty)
            self.assertRaises(EOFError, f.read, 10)
| |
def test_read_10(self):
    # Read the stream in pieces of at most 10 bytes until read() returns
    # an empty result, then compare the reassembled data.
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        pieces = []
        while piece := f.read(10):
            self.assertLessEqual(len(piece), 10)
            pieces.append(piece)
        self.assertEqual(b"".join(pieces), DECOMPRESSED_100_PLUS_32KB)
| |
def test_read_multistream(self):
    # Several frames back to back are read as one stream; a trailing
    # skippable frame contributes no data.
    streams = (
        (COMPRESSED_100_PLUS_32KB * 5,
         DECOMPRESSED_100_PLUS_32KB * 5),
        (COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME,
         DECOMPRESSED_100_PLUS_32KB),
        (COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT,
         DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT),
    )
    for compressed, expected in streams:
        with ZstdFile(io.BytesIO(compressed)) as f:
            self.assertEqual(f.read(), expected)
| |
def test_read_incomplete(self):
    # A frame cut short by 200 bytes ends in EOFError.
    with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f:
        with self.assertRaises(EOFError):
            f.read()

    # Trailing data isn't a valid compressed stream
    for valid_prefix in (self.FRAME_42, SKIPPABLE_FRAME):
        with ZstdFile(io.BytesIO(valid_prefix + b'12345')) as f:
            with self.assertRaises(ZstdError):
                f.read()
| |
def test_read_truncated(self):
    # Drop stream epilogue: 4 bytes checksum
    without_checksum = DAT_130K_C[:-4]
    with ZstdFile(io.BytesIO(without_checksum)) as f:
        with self.assertRaises(EOFError):
            f.read()

    with ZstdFile(io.BytesIO(without_checksum)) as f:
        # this is an important test, make sure it doesn't raise EOFError.
        self.assertEqual(f.read(130*_1K), DAT_130K_D)
        self.assertRaises(EOFError, f.read, 1)

    # Incomplete header
    for prefix_len in range(1, 20):
        with ZstdFile(io.BytesIO(without_checksum[:prefix_len])) as f:
            with self.assertRaises(EOFError):
                f.read(1)
| |
def test_read_bad_args(self):
    # read() on a closed file
    closed_file = ZstdFile(io.BytesIO(COMPRESSED_DAT))
    closed_file.close()
    with self.assertRaises(ValueError):
        closed_file.read()

    # read() on a file opened for writing
    with ZstdFile(io.BytesIO(), "w") as f:
        with self.assertRaises(ValueError):
            f.read()

    # size given as a float
    with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
        with self.assertRaises(TypeError):
            f.read(float())
| |
def test_read_bad_data(self):
    # Reading the COMPRESSED_BOGUS fixture must raise ZstdError.
    with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f:
        with self.assertRaises(ZstdError):
            f.read()
| |
def test_read_exception(self):
    # An OSError raised by the underlying object's read() must propagate
    # out of ZstdFile.read().
    class FailingReader:
        def read(self, size=-1):
            raise OSError

    with ZstdFile(FailingReader()) as f:
        self.assertRaises(OSError, f.read, 10)
| |
def test_read1(self):
    # Collect read1() results until it returns an empty result; the
    # pieces must add up to the original data and further calls must
    # keep returning b"".
    with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
        pieces = []
        while piece := f.read1():
            pieces.append(piece)
        self.assertEqual(b"".join(pieces), DAT_130K_D)
        self.assertEqual(f.read1(), b"")
| |
def test_read1_0(self):
    # read1(0) returns an empty bytes object.
    source = io.BytesIO(COMPRESSED_DAT)
    with ZstdFile(source) as f:
        nothing = f.read1(0)
        self.assertEqual(nothing, b"")
| |
def test_read1_10(self):
    # Same as test_read1, but asking for 10 bytes per read1() call.
    with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
        pieces = []
        while piece := f.read1(10):
            pieces.append(piece)
        self.assertEqual(b"".join(pieces), DECOMPRESSED_DAT)
        self.assertEqual(f.read1(), b"")
| |
def test_read1_multistream(self):
    # read1() continues across five concatenated frames.
    repeat = 5
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * repeat)) as f:
        pieces = []
        while piece := f.read1():
            pieces.append(piece)
        self.assertEqual(b"".join(pieces),
                         DECOMPRESSED_100_PLUS_32KB * repeat)
        self.assertEqual(f.read1(), b"")
| |
def test_read1_bad_args(self):
    # read1() on a closed file
    closed_file = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    closed_file.close()
    with self.assertRaises(ValueError):
        closed_file.read1()

    # read1() on a file opened for writing
    with ZstdFile(io.BytesIO(), "w") as f:
        with self.assertRaises(ValueError):
            f.read1()

    # size given as None
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        with self.assertRaises(TypeError):
            f.read1(None)
| |
def test_readinto(self):
    # readinto() must fill buffers by their size in bytes, not by their
    # item count.
    int_buf = array.array("I", range(100))
    self.assertEqual(len(int_buf), 100)
    # 100 items taking 400 bytes in total
    self.assertEqual(len(int_buf) * int_buf.itemsize, 400)
    byte_buf = bytearray(300)

    expected = DECOMPRESSED_100_PLUS_32KB
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        # 0 length output buffer
        zero_len = byte_buf[0:0]
        self.assertEqual(f.readinto(zero_len), 0)

        # use correct length for buffer protocol object
        self.assertEqual(f.readinto(int_buf), 400)
        self.assertEqual(int_buf.tobytes(), expected[:400])

        # normal readinto
        self.assertEqual(f.readinto(byte_buf), 300)
        self.assertEqual(byte_buf, expected[400:700])
| |
def test_peek(self):
    # peek(), with or without a size argument, returns a non-empty prefix
    # of the data without consuming it: the following read() still
    # returns everything.
    for peek_args in ((), (10,)):
        with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
            peeked = f.peek(*peek_args)
            self.assertGreater(len(peeked), 0)
            self.assertTrue(DAT_130K_D.startswith(peeked))
            self.assertEqual(f.read(), DAT_130K_D)
| |
def test_peek_bad_args(self):
    # peek() on a file opened for writing raises ValueError.
    with ZstdFile(io.BytesIO(), "w") as f:
        with self.assertRaises(ValueError):
            f.peek()
| |
def test_iterator(self):
    # Line-oriented access (iteration, readline(), readlines()) must
    # split the data exactly as io.BytesIO.readlines() does.
    with io.BytesIO(THIS_FILE_BYTES) as plain:
        expected_lines = plain.readlines()
    compressed = compress(THIS_FILE_BYTES)

    def reopen():
        return ZstdFile(io.BytesIO(compressed))

    # iter
    with reopen() as f:
        self.assertListEqual(list(iter(f)), expected_lines)

    # readline
    with reopen() as f:
        for expected in expected_lines:
            self.assertEqual(f.readline(), expected)
        # once exhausted, readline() keeps returning b''
        for _ in range(2):
            self.assertEqual(f.readline(), b'')

    # readlines
    with reopen() as f:
        self.assertListEqual(f.readlines(), expected_lines)
| |
def test_decompress_limited(self):
    # Reading a single byte from a highly compressible stream must not
    # decompress more than one buffer's worth of data.
    # NOTE(review): _ZSTD_DStreamInSize presumably mirrors zstd's
    # recommended streaming input size -- confirm against the library.
    _ZSTD_DStreamInSize = 128*_1K + 3

    bomb = compress(b'\0' * int(2e6), level=10)
    self.assertLess(len(bomb), _ZSTD_DStreamInSize)

    # BufferedReader uses 128 KiB buffer in __init__.py
    max_decomp = 128*_1K

    # Use a context manager so the ZstdFile is closed even when one of
    # the assertions fails (the file object used to be left open).
    with ZstdFile(io.BytesIO(bomb)) as decomp:
        self.assertEqual(decomp.read(1), b'\0')
        self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
                             "Excessive amount of data was decompressed")
| |
def test_write(self):
    # Data written through ZstdFile must be byte-identical to the output
    # of a ZstdCompressor configured the same way.
    raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]

    def check(reference, **file_kwargs):
        # Write raw_data with the given ZstdFile arguments and compare
        # against what the reference compressor produces.
        with io.BytesIO() as dst:
            with ZstdFile(dst, "w", **file_kwargs) as f:
                f.write(raw_data)

            expected = reference.compress(raw_data) + reference.flush()
            self.assertEqual(dst.getvalue(), expected)

    # default settings
    check(ZstdCompressor())

    # explicit level
    check(ZstdCompressor(12), level=12)

    # options: checksum only
    check(ZstdCompressor(options={CompressionParameter.checksum_flag:1}),
          options={CompressionParameter.checksum_flag:1})

    # options: negative level plus checksum
    options = {CompressionParameter.compression_level:-5,
               CompressionParameter.checksum_flag:1}
    check(ZstdCompressor(options=options), options=options)
| |
def test_write_empty_frame(self):
    # .FLUSH_FRAME generates an empty content frame
    compressor = ZstdCompressor()
    for _ in range(2):
        self.assertNotEqual(compressor.flush(compressor.FLUSH_FRAME), b'')

    def output_of(action):
        # Run *action* on a fresh write-mode ZstdFile and return what
        # reached the underlying BytesIO once the file is closed.
        sink = io.BytesIO()
        with ZstdFile(sink, 'w') as f:
            action(f)
        return sink.getvalue()

    # don't generate empty content frame
    self.assertEqual(output_of(lambda f: None), b'')
    self.assertEqual(output_of(lambda f: f.flush(f.FLUSH_FRAME)), b'')

    # if .write(b''), generate empty content frame
    self.assertNotEqual(output_of(lambda f: f.write(b'')), b'')

    # has an empty content frame
    self.assertNotEqual(output_of(lambda f: f.flush(f.FLUSH_BLOCK)), b'')
| |
def test_write_empty_block(self):
    # If no internal data, .FLUSH_BLOCK return b''.
    compressor = ZstdCompressor()
    flush_block = compressor.FLUSH_BLOCK
    self.assertEqual(compressor.flush(flush_block), b'')
    self.assertNotEqual(compressor.compress(b'123', flush_block), b'')
    self.assertEqual(compressor.flush(flush_block), b'')
    for _ in range(2):
        self.assertEqual(compressor.compress(b''), b'')
    self.assertEqual(compressor.flush(flush_block), b'')

    # mode = .last_mode
    sink = io.BytesIO()
    with ZstdFile(sink, 'w') as f:
        f.write(b'123')
        f.flush(f.FLUSH_BLOCK)
        mark = f._fp.tell()
        self.assertNotEqual(mark, 0)
        # a second identical flush must not write anything more
        f.flush(f.FLUSH_BLOCK)
        self.assertEqual(f._fp.tell(), mark)

    # mode != .last_mode
    sink = io.BytesIO()
    with ZstdFile(sink, 'w') as f:
        f.flush(f.FLUSH_BLOCK)
        self.assertEqual(f._fp.tell(), 0)
        f.write(b'')
        f.flush(f.FLUSH_BLOCK)
        self.assertEqual(f._fp.tell(), 0)
| |
def test_write_101(self):
    # Feeding the data in 101-byte slices must give the same output as
    # compressing it in one call.
    step = 101
    total = len(THIS_FILE_BYTES)
    with io.BytesIO() as dst:
        with ZstdFile(dst, "w") as f:
            offset = 0
            while offset < total:
                f.write(THIS_FILE_BYTES[offset:offset+step])
                offset += step

        reference = ZstdCompressor()
        expected = reference.compress(THIS_FILE_BYTES) + reference.flush()
        self.assertEqual(dst.getvalue(), expected)
| |
def test_write_append(self):
    # Writing three parts through three successive ZstdFile objects on
    # the same stream ("w", then "a" twice) must equal the concatenation
    # of the three parts compressed independently.
    def one_frame(data):
        compressor = ZstdCompressor()
        return compressor.compress(data) + compressor.flush()

    parts = (
        THIS_FILE_BYTES[:_1K],
        THIS_FILE_BYTES[_1K:1536],
        THIS_FILE_BYTES[1536:],
    )
    expected = b"".join(map(one_frame, parts))
    with io.BytesIO() as dst:
        for mode, part in zip(("w", "a", "a"), parts):
            with ZstdFile(dst, mode) as f:
                f.write(part)
        self.assertEqual(dst.getvalue(), expected)
| |
def test_write_bad_args(self):
    # write() raises ValueError on a closed file and on a file opened for
    # reading, and TypeError for None, str and int arguments.
    closed = ZstdFile(io.BytesIO(), "w")
    closed.close()
    with self.assertRaises(ValueError):
        closed.write(b"foo")
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as reader:
        with self.assertRaises(ValueError):
            reader.write(b"bar")
    with ZstdFile(io.BytesIO(), "w") as writer:
        for not_bytes in (None, "text", 789):
            with self.assertRaises(TypeError):
                writer.write(not_bytes)
| |
def test_writelines(self):
    # writelines() with the file's lines must produce the same output as
    # compressing the whole data in one call.
    with io.BytesIO(THIS_FILE_BYTES) as plain:
        lines = plain.readlines()
    with io.BytesIO() as dst:
        with ZstdFile(dst, "w") as f:
            f.writelines(lines)
        reference = ZstdCompressor()
        expected = reference.compress(THIS_FILE_BYTES) + reference.flush()
        self.assertEqual(dst.getvalue(), expected)
| |
def test_seek_forward(self):
    # After an absolute forward seek, read() returns the data from that
    # offset onwards.
    offset = 555
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.seek(offset)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[offset:])
| |
def test_seek_forward_across_streams(self):
    # Seek 123 bytes into the second of two concatenated payloads.
    one_stream = len(DECOMPRESSED_100_PLUS_32KB)
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
        f.seek(one_stream + 123)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[123:])
| |
def test_seek_forward_relative_to_current(self):
    # Read 100 bytes, then skip 1236 more relative to the current
    # position: the next read() starts at offset 1336.
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.read(100)
        f.seek(1236, io.SEEK_CUR)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[1336:])
| |
def test_seek_forward_relative_to_end(self):
    # Seeking to 555 bytes before the end leaves the last 555 bytes to
    # read.
    back = 555
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.seek(-back, io.SEEK_END)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[-back:])
| |
def test_seek_backward(self):
    # Read past offset 211, seek back to it, and read the rest.
    target = 211
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.read(1001)
        f.seek(target)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[target:])
| |
def test_seek_backward_across_streams(self):
    # Read into the second of two concatenated payloads, seek back into
    # the first one, and read everything that follows.
    plain = DECOMPRESSED_100_PLUS_32KB
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
        f.read(len(plain) + 333)
        f.seek(737)
        remainder = f.read()
    self.assertEqual(remainder, plain[737:] + plain)
| |
def test_seek_backward_relative_to_end(self):
    # Seeking to 150 bytes before the end leaves the last 150 bytes to
    # read.
    back = 150
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.seek(-back, io.SEEK_END)
        remainder = f.read()
    self.assertEqual(remainder, DECOMPRESSED_100_PLUS_32KB[-back:])
| |
def test_seek_past_end(self):
    # A seek beyond the end of the data stops at the end: tell() reports
    # the full size and read() returns nothing.
    full_size = len(DECOMPRESSED_100_PLUS_32KB)
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
        f.seek(full_size + 9001)
        self.assertEqual(f.tell(), full_size)
        self.assertEqual(f.read(), b"")
| |
def test_seek_past_start(self):
    # A seek to a negative absolute offset ends up at position 0, so the
    # whole data can still be read.
    source = io.BytesIO(COMPRESSED_100_PLUS_32KB)
    with ZstdFile(source) as f:
        f.seek(-88)
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
| |
def test_seek_bad_args(self):
    # seek() is rejected on a closed file, on a file opened for writing,
    # and with invalid whence/offset arguments.
    closed = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    closed.close()
    with self.assertRaises(ValueError):
        closed.seek(0)
    with ZstdFile(io.BytesIO(), "w") as writer:
        with self.assertRaises(ValueError):
            writer.seek(0)
    with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as reader:
        with self.assertRaises(ValueError):
            reader.seek(0, 3)
        # io.BufferedReader raises TypeError instead of ValueError
        with self.assertRaises((TypeError, ValueError)):
            reader.seek(9, ())
        for bad_offset in (None, b"derp"):
            with self.assertRaises(TypeError):
                reader.seek(bad_offset)
| |
def test_seek_not_seekable(self):
    # When the underlying object reports seekable() == False, ZstdFile
    # is not seekable either and seek() raises UnsupportedOperation;
    # sequential reading keeps working afterwards.
    class NonSeekable(io.BytesIO):
        def seekable(self):
            return False

    with ZstdFile(NonSeekable(COMPRESSED_100_PLUS_32KB), 'r') as f:
        first = f.read(1)
        self.assertFalse(f.seekable())
        self.assertRaisesRegex(io.UnsupportedOperation,
                               'File or stream is not seekable',
                               f.seek, 0)
        rest = f.read()
    self.assertEqual(first + rest, DECOMPRESSED_100_PLUS_32KB)
| |
def test_tell(self):
    # tell() tracks the uncompressed position, both while reading in
    # randomly sized steps and while writing in 143-byte slices.
    total = len(DAT_130K_D)
    with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
        consumed = 0
        self.assertEqual(f.tell(), consumed)
        while chunk := f.read(random.randint(171, 189)):
            consumed += len(chunk)
            self.assertEqual(f.tell(), consumed)
        self.assertEqual(f.tell(), total)
    with ZstdFile(io.BytesIO(), "w") as f:
        for start in range(0, total, 143):
            self.assertEqual(f.tell(), start)
            f.write(DAT_130K_D[start:start+143])
        self.assertEqual(f.tell(), total)
| |
def test_tell_bad_args(self):
    # tell() on a closed file raises ValueError.
    closed = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
    closed.close()
    with self.assertRaises(ValueError):
        closed.tell()
| |
def test_file_dict(self):
    # Round-trip a sample through ZstdFile with a dictionary, once using
    # the dictionary object itself and once using its digested form for
    # writing and its undigested form for reading.
    sample = SAMPLES[0]
    dict_pairs = (
        # default
        (TRAINED_DICT, TRAINED_DICT),
        # .as_(un)digested_dict
        (TRAINED_DICT.as_digested_dict, TRAINED_DICT.as_undigested_dict),
    )
    for write_dict, read_dict in dict_pairs:
        stream = io.BytesIO()
        with ZstdFile(stream, 'w', zstd_dict=write_dict) as f:
            f.write(sample)
        stream.seek(0)
        with ZstdFile(stream, zstd_dict=read_dict) as f:
            restored = f.read()
        self.assertEqual(restored, sample)
| |
def test_file_prefix(self):
    # Round-trip a sample through ZstdFile passing TRAINED_DICT.as_prefix
    # as zstd_dict on both sides.
    sample = SAMPLES[0]
    stream = io.BytesIO()
    with ZstdFile(stream, 'w', zstd_dict=TRAINED_DICT.as_prefix) as writer:
        writer.write(sample)
    stream.seek(0)
    with ZstdFile(stream, zstd_dict=TRAINED_DICT.as_prefix) as reader:
        restored = reader.read()
    self.assertEqual(restored, sample)
| |
def test_UnsupportedOperation(self):
    # 1: write() on a file opened for reading
    with ZstdFile(io.BytesIO(), 'r') as f:
        self.assertRaises(io.UnsupportedOperation, f.write, b'1234')

    # 2: an object that only provides read(), opened for writing
    class ReadOnly:
        def read(self, size):
            return b'a' * size

    with self.assertRaises(TypeError): # on creation
        with ZstdFile(ReadOnly(), 'w') as f:
            pass

    # 3: read()/seek() on a file opened for writing, then the same kind
    # of calls once the file is closed
    with ZstdFile(io.BytesIO(), 'w') as f:
        self.assertRaises(io.UnsupportedOperation, f.read, 100)
        self.assertRaises(io.UnsupportedOperation, f.seek, 100)
    self.assertEqual(f.closed, True)
    self.assertRaises(ValueError, f.readable)
    self.assertRaises(ValueError, f.tell)
    self.assertRaises(ValueError, f.read, 100)
| |
def test_read_readinto_readinto1(self):
    # Consume the stream with a random mix of read(), readinto() and
    # readinto1() calls of random sizes; the concatenated pieces must
    # equal the original data.
    pieces = []
    with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f:
        while True:
            choice = random.randint(0, 2)
            size = random.randint(0, 300)

            if choice == 0:
                piece = f.read(size)
            else:
                fill = f.readinto if choice == 1 else f.readinto1
                buf = bytearray(size)
                filled = fill(buf)
                piece = bytes(buf[:filled])

            # an empty result only signals the end of the data when a
            # non-zero size was requested
            if size and not piece:
                break
            pieces.append(piece)
    self.assertEqual(b''.join(pieces), THIS_FILE_BYTES*5)
| |
def test_zstdfile_flush(self):
    # closed
    closed = ZstdFile(io.BytesIO(), 'w')
    closed.close()
    self.assertRaises(ValueError, closed.flush)

    # read
    with ZstdFile(io.BytesIO(), 'r') as reader:
        # does nothing for read-only stream
        reader.flush()

    # write
    payload = b'abcd'
    size = len(payload)
    sink = io.BytesIO()
    with ZstdFile(sink, 'w') as writer:
        self.assertEqual(writer.write(payload), size)
        self.assertEqual(writer.tell(), size)
        self.assertEqual(sink.tell(), 0) # not enough for a block

        self.assertEqual(writer.flush(), None)
        self.assertEqual(writer.tell(), size)
        self.assertGreater(sink.tell(), 0) # flushed

    # write, no .flush() method
    class WriteOnly:
        def write(self, b):
            return len(b)

    with ZstdFile(WriteOnly(), 'w') as writer:
        self.assertEqual(writer.write(payload), size)
        self.assertEqual(writer.tell(), size)

        self.assertEqual(writer.flush(), None)
        self.assertEqual(writer.tell(), size)
| |
def test_zstdfile_flush_mode(self):
    # ZstdFile exposes the same FLUSH_BLOCK / FLUSH_FRAME values as
    # ZstdCompressor, but has no CONTINUE attribute.
    for name in ('FLUSH_BLOCK', 'FLUSH_FRAME'):
        self.assertEqual(getattr(ZstdFile, name),
                         getattr(ZstdCompressor, name))
    self.assertRaises(AttributeError, getattr, ZstdFile, 'CONTINUE')

    sink = io.BytesIO()
    with ZstdFile(sink, 'w') as f:
        # flush block
        self.assertEqual(f.write(b'123'), 3)
        self.assertIsNone(f.flush(f.FLUSH_BLOCK))
        before = sink.tell()
        # mode == .last_mode, should return
        self.assertIsNone(f.flush())
        after = sink.tell()
        self.assertEqual(before, after)
        # flush frame
        self.assertEqual(f.write(b'456'), 3)
        self.assertIsNone(f.flush(mode=f.FLUSH_FRAME))
        # flush frame
        self.assertEqual(f.write(b'789'), 3)
        self.assertIsNone(f.flush(f.FLUSH_FRAME))
        before = sink.tell()
        # mode == .last_mode, should return
        self.assertIsNone(f.flush(f.FLUSH_FRAME))
        after = sink.tell()
        self.assertEqual(before, after)
    self.assertEqual(decompress(sink.getvalue()), b'123456789')

    # invalid flush() arguments
    sink = io.BytesIO()
    with ZstdFile(sink, 'w') as f:
        f.write(b'123')
        self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_',
                               f.flush, ZstdCompressor.CONTINUE)
        for bad_mode in (-1, 123456):
            self.assertRaises(ValueError, f.flush, bad_mode)
        # `node` is not the `mode` keyword accepted above
        self.assertRaises(TypeError, f.flush, node=ZstdCompressor.CONTINUE)
        self.assertRaises((TypeError, ValueError), f.flush, 'FLUSH_FRAME')
        # two positional arguments
        self.assertRaises(TypeError, f.flush, b'456', f.FLUSH_BLOCK)
| |
def test_zstdfile_truncate(self):
    # truncate() raises io.UnsupportedOperation on a write-mode file.
    with ZstdFile(io.BytesIO(), 'w') as f:
        self.assertRaises(io.UnsupportedOperation, f.truncate, 200)
| |
def test_zstdfile_iter_issue45475(self):
    # Iterate directly over a ZstdFile that is not bound to any name.
    # Keep the comprehension form: the point of the test is that no
    # other reference to the file object is held during the iteration.
    lines = [line for line in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))]
    self.assertGreater(len(lines), 0)
| |
def test_append_new_file(self):
    # Obtain a unique path from a temporary file that is deleted again
    # when its `with` block ends, then check that opening that path in
    # append mode creates the file.
    with tempfile.NamedTemporaryFile(delete=True) as placeholder:
        path = placeholder.name

    with ZstdFile(path, 'a') as f:
        pass
    created = os.path.isfile(path)
    self.assertTrue(created)

    os.remove(path)
| |
class OpenTestCase(unittest.TestCase):
    # Tests for the module-level open() imported from compression.zstd
    # (it shadows the builtin open() throughout this file).

    def test_binary_modes(self):
        # "rb" reads, "wb" writes, and "ab" adds a second frame to the
        # same underlying stream.
        with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
        with io.BytesIO() as bio:
            with open(bio, "wb") as f:
                f.write(DECOMPRESSED_100_PLUS_32KB)
            file_data = decompress(bio.getvalue())
            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
            with open(bio, "ab") as f:
                f.write(DECOMPRESSED_100_PLUS_32KB)
            file_data = decompress(bio.getvalue())
            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2)

    def test_text_modes(self):
        # empty input
        with self.assertRaises(EOFError):
            with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
                for _ in reader:
                    pass

        # read
        uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
        with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f:
            self.assertEqual(f.read(), uncompressed)

        with io.BytesIO() as bio:
            # write
            with open(bio, "wt", encoding="utf-8") as f:
                f.write(uncompressed)
            file_data = decompress(bio.getvalue()).decode("utf-8")
            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
            # append
            with open(bio, "at", encoding="utf-8") as f:
                f.write(uncompressed)
            file_data = decompress(bio.getvalue()).decode("utf-8")
            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2)

    def test_bad_params(self):
        # Invalid mode strings, and text-only arguments combined with a
        # binary mode, must raise ValueError.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            TESTFN = pathlib.Path(tmp_f.name)
        # Registered up front so the temporary file is removed even when
        # one of the assertions below fails.
        self.addCleanup(TESTFN.unlink, missing_ok=True)

        with self.assertRaises(ValueError):
            open(TESTFN, "")
        with self.assertRaises(ValueError):
            open(TESTFN, "rbt")
        with self.assertRaises(ValueError):
            open(TESTFN, "rb", encoding="utf-8")
        with self.assertRaises(ValueError):
            open(TESTFN, "rb", errors="ignore")
        with self.assertRaises(ValueError):
            open(TESTFN, "rb", newline="\n")

    def test_option(self):
        # `options` is forwarded for both decompression and compression.
        options = {DecompressionParameter.window_log_max:25}
        with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f:
            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)

        options = {CompressionParameter.compression_level:12}
        with io.BytesIO() as bio:
            with open(bio, "wb", options=options) as f:
                f.write(DECOMPRESSED_100_PLUS_32KB)
            file_data = decompress(bio.getvalue())
            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)

    def test_encoding(self):
        # Text written with a non-default encoding round-trips.
        uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")

        with io.BytesIO() as bio:
            with open(bio, "wt", encoding="utf-16-le") as f:
                f.write(uncompressed)
            file_data = decompress(bio.getvalue()).decode("utf-16-le")
            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
            bio.seek(0)
            with open(bio, "rt", encoding="utf-16-le") as f:
                self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed)

    def test_encoding_error_handler(self):
        # errors="ignore" drops the byte that is not valid ASCII.
        with io.BytesIO(compress(b"foo\xffbar")) as bio:
            with open(bio, "rt", encoding="ascii", errors="ignore") as f:
                self.assertEqual(f.read(), "foobar")

    def test_newline(self):
        # Test with explicit newline (universal newline mode disabled).
        text = THIS_FILE_STR.replace(os.linesep, "\n")
        with io.BytesIO() as bio:
            with open(bio, "wt", encoding="utf-8", newline="\n") as f:
                f.write(text)
            bio.seek(0)
            with open(bio, "rt", encoding="utf-8", newline="\r") as f:
                self.assertEqual(f.readlines(), [text])

    def test_x_mode(self):
        # Exclusive-creation modes succeed when the path does not exist
        # and raise FileExistsError when it does.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
            TESTFN = pathlib.Path(tmp_f.name)
        # The file may or may not exist when the test stops (a failure
        # can happen right after os.remove() below), hence missing_ok.
        self.addCleanup(TESTFN.unlink, missing_ok=True)

        for mode in ("x", "xb", "xt"):
            os.remove(TESTFN)

            encoding = "utf-8" if mode == "xt" else None
            with open(TESTFN, mode, encoding=encoding):
                pass
            with self.assertRaises(FileExistsError):
                with open(TESTFN, mode):
                    pass

    def test_open_dict(self):
        # default
        bi = io.BytesIO()
        with open(bi, 'w', zstd_dict=TRAINED_DICT) as f:
            f.write(SAMPLES[0])
        bi.seek(0)
        with open(bi, zstd_dict=TRAINED_DICT) as f:
            dat = f.read()
        self.assertEqual(dat, SAMPLES[0])

        # .as_(un)digested_dict
        bi = io.BytesIO()
        with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
            f.write(SAMPLES[0])
        bi.seek(0)
        with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
            dat = f.read()
        self.assertEqual(dat, SAMPLES[0])

        # invalid dictionary
        bi = io.BytesIO()
        with self.assertRaisesRegex(TypeError, 'zstd_dict'):
            open(bi, 'w', zstd_dict={1:2, 2:3})

        with self.assertRaisesRegex(TypeError, 'zstd_dict'):
            open(bi, 'w', zstd_dict=b'1234567890')

    def test_open_prefix(self):
        # Round-trip a sample with TRAINED_DICT.as_prefix on both sides.
        bi = io.BytesIO()
        with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
            f.write(SAMPLES[0])
        bi.seek(0)
        with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
            dat = f.read()
        self.assertEqual(dat, SAMPLES[0])

    def test_buffer_protocol(self):
        # don't use len() for buffer protocol objects
        arr = array.array("i", range(1000))
        LENGTH = len(arr) * arr.itemsize

        with open(io.BytesIO(), "wb") as f:
            self.assertEqual(f.write(arr), LENGTH)
            self.assertEqual(f.tell(), LENGTH)
| |
class FreeThreadingMethodTests(unittest.TestCase):
    # Call compressor/decompressor methods and constructors from several
    # threads at once, sharing a single object or dictionary.

    @threading_helper.reap_threads
    @threading_helper.requires_working_threading()
    def test_compress_locking(self):
        payload = b'a' * (16*_1K)
        worker_count = 8
        flush_block = ZstdCompressor.FLUSH_BLOCK

        # gh-136394: the first output of .compress() includes the frame header
        # we run the first .compress() call outside of the threaded portion
        # to make the test order-independent

        # Reference run: every call made from this thread.
        serial = ZstdCompressor()
        serial_parts = [serial.compress(payload, flush_block)]
        for _ in range(worker_count):
            if piece := serial.compress(payload, flush_block):
                serial_parts.append(piece)
        serial_tail = serial.flush()
        expected = b''.join(serial_parts) + serial_tail

        # Threaded run: one shared compressor, one call per thread.
        shared = ZstdCompressor()
        collected = [shared.compress(payload, flush_block)]

        def worker():
            if piece := shared.compress(payload, flush_block):
                collected.append(piece)

        threads = [threading.Thread(target=worker)
                   for _ in range(worker_count)]
        with threading_helper.start_threads(threads):
            pass

        shared_tail = shared.flush()
        self.assertEqual(serial_tail, shared_tail)
        actual = b''.join(collected) + shared_tail
        self.assertEqual(expected, actual)

    @threading_helper.reap_threads
    @threading_helper.requires_working_threading()
    def test_decompress_locking(self):
        compressed = compress(b'a' * (16*_1K))
        worker_count = 8
        # to ensure we decompress over multiple calls, set maxsize
        max_length = _1K * 16//worker_count

        # Reference run: every call made from this thread.
        serial = ZstdDecompressor()
        serial_parts = []
        for _ in range(worker_count):
            if piece := serial.decompress(compressed, max_length):
                serial_parts.append(piece)
        expected = b''.join(serial_parts)

        # Threaded run: one shared decompressor, one call per thread.
        shared = ZstdDecompressor()
        collected = []

        def worker():
            if piece := shared.decompress(compressed, max_length):
                collected.append(piece)

        threads = [threading.Thread(target=worker)
                   for _ in range(worker_count)]
        with threading_helper.start_threads(threads):
            pass

        actual = b''.join(collected)
        self.assertEqual(expected, actual)

    @threading_helper.reap_threads
    @threading_helper.requires_working_threading()
    def test_compress_shared_dict(self):
        worker_count = 8
        dict_forms = ("as_digested_dict", "as_undigested_dict", "as_prefix")
        barrier = threading.Barrier(worker_count)

        def worker():
            level = threading.get_ident() % 4
            for form in dict_forms:
                # sync threads to increase chance of contention on
                # capsule storing dictionary levels
                barrier.wait()
                ZstdCompressor(level=level,
                               zstd_dict=getattr(TRAINED_DICT, form))

        threads = [threading.Thread(target=worker)
                   for _ in range(worker_count)]
        with threading_helper.start_threads(threads):
            pass

    @threading_helper.reap_threads
    @threading_helper.requires_working_threading()
    def test_decompress_shared_dict(self):
        worker_count = 8
        dict_forms = ("as_digested_dict", "as_undigested_dict", "as_prefix")
        barrier = threading.Barrier(worker_count)

        def worker():
            for form in dict_forms:
                # sync threads to increase chance of contention on
                # decompression dictionary
                barrier.wait()
                ZstdDecompressor(zstd_dict=getattr(TRAINED_DICT, form))

        threads = [threading.Thread(target=worker)
                   for _ in range(worker_count)]
        with threading_helper.start_threads(threads):
            pass
| |
| |
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()