1
// Copyright 2018 Klaus Post. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
4
// Based on work Copyright (c) 2013, Yann Collet, released under BSD License.
13
// Compress the input bytes. Input must be < 2GB.
14
// Provide a Scratch buffer to avoid memory allocations.
15
// Note that the output is also kept in the scratch buffer.
16
// If input is too hard to compress, ErrIncompressible is returned.
17
// If input is a single byte value repeated ErrUseRLE is returned.
18
func Compress(in []byte, s *Scratch) ([]byte, error) {
20
return nil, ErrIncompressible
22
if len(in) > (2<<30)-1 {
23
return nil, errors.New("input too big, must be < 2GB")
25
s, err := s.prepare(in)
30
// Create histogram, if none was provided.
31
maxCount := s.maxCount
33
maxCount = s.countSimple(in)
35
// Reset for next run.
38
if maxCount == len(in) {
39
// One symbol, use RLE
42
if maxCount == 1 || maxCount < (len(in)>>7) {
43
// Each symbol present maximum once or too well distributed.
44
return nil, ErrIncompressible
47
err = s.normalizeCount()
57
err = s.validateNorm()
72
// Check if we compressed.
73
if len(s.Out) >= len(in) {
74
return nil, ErrIncompressible
79
// cState contains the compression state of a stream.
86
// init will initialize the compression state to the first symbol of the stream.
87
func (c *cState) init(bw *bitWriter, ct *cTable, tableLog uint8, first symbolTransform) {
89
c.stateTable = ct.stateTable
91
nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
92
im := int32((nbBitsOut << 16) - first.deltaNbBits)
93
lu := (im >> nbBitsOut) + first.deltaFindState
94
c.state = c.stateTable[lu]
97
// encode the output symbol provided and write it to the bitstream.
98
func (c *cState) encode(symbolTT symbolTransform) {
99
nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
100
dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
101
c.bw.addBits16NC(c.state, uint8(nbBitsOut))
102
c.state = c.stateTable[dstState]
105
// encode the output symbol provided and write it to the bitstream.
106
func (c *cState) encodeZero(symbolTT symbolTransform) {
107
nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
108
dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
109
c.bw.addBits16ZeroNC(c.state, uint8(nbBitsOut))
110
c.state = c.stateTable[dstState]
113
// flush will write the tablelog to the output and flush the remaining full bytes.
114
func (c *cState) flush(tableLog uint8) {
116
c.bw.addBits16NC(c.state, tableLog)
120
// compress is the main compression loop that will encode the input from the last byte to the first.
121
func (s *Scratch) compress(src []byte) error {
123
return errors.New("compress: src too small")
125
tt := s.ct.symbolTT[:256]
128
// Our two states each encodes every second byte.
129
// Last byte encoded (first byte decoded) will always be encoded by c1.
132
// Encode so remaining size is divisible by 4.
135
c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
136
c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
137
c1.encodeZero(tt[src[ip-3]])
140
c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
141
c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
145
c2.encodeZero(tt[src[ip-1]])
146
c1.encodeZero(tt[src[ip-2]])
150
// Main compression loop.
152
case !s.zeroBits && s.actualTableLog <= 8:
153
// We can encode 4 symbols without requiring a flush.
154
// We do not need to check if any output is 0 bits.
157
v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
165
// We do not need to check if any output is 0 bits.
168
v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
176
case s.actualTableLog <= 8:
177
// We can encode 4 symbols without requiring a flush
180
v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
181
c2.encodeZero(tt[v0])
182
c1.encodeZero(tt[v1])
183
c2.encodeZero(tt[v2])
184
c1.encodeZero(tt[v3])
190
v3, v2, v1, v0 := src[ip-4], src[ip-3], src[ip-2], src[ip-1]
191
c2.encodeZero(tt[v0])
192
c1.encodeZero(tt[v1])
194
c2.encodeZero(tt[v2])
195
c1.encodeZero(tt[v3])
200
// Flush final state.
201
// Used to initialize state when decoding.
202
c2.flush(s.actualTableLog)
203
c1.flush(s.actualTableLog)
208
// writeCount will write the normalized histogram count to header.
209
// This is read back by readNCount.
210
func (s *Scratch) writeCount() error {
212
tableLog = s.actualTableLog
213
tableSize = 1 << tableLog
217
maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3
220
bitStream = uint32(tableLog - minTablelog)
222
remaining = int16(tableSize + 1) /* +1 for extra accuracy */
223
threshold = int16(tableSize)
224
nbBits = uint(tableLog + 1)
226
if cap(s.Out) < maxHeaderSize {
227
s.Out = make([]byte, 0, s.br.remain()+maxHeaderSize)
230
out := s.Out[:maxHeaderSize]
236
for s.norm[charnum] == 0 {
239
for charnum >= start+24 {
241
bitStream += uint32(0xFFFF) << bitCount
242
out[outP] = byte(bitStream)
243
out[outP+1] = byte(bitStream >> 8)
247
for charnum >= start+3 {
249
bitStream += 3 << bitCount
252
bitStream += uint32(charnum-start) << bitCount
255
out[outP] = byte(bitStream)
256
out[outP+1] = byte(bitStream >> 8)
263
count := s.norm[charnum]
265
max := (2*threshold - 1) - remaining
271
count++ // +1 for extra accuracy
272
if count >= threshold {
273
count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
275
bitStream += uint32(count) << bitCount
281
previous0 = count == 1
283
return errors.New("internal error: remaining<1")
285
for remaining < threshold {
291
out[outP] = byte(bitStream)
292
out[outP+1] = byte(bitStream >> 8)
299
out[outP] = byte(bitStream)
300
out[outP+1] = byte(bitStream >> 8)
301
outP += (bitCount + 7) / 8
303
if charnum > s.symbolLen {
304
return errors.New("internal error: charnum > s.symbolLen")
310
// symbolTransform contains the state transform for a symbol.
311
type symbolTransform struct {
316
// String prints values as a human readable string.
317
func (s symbolTransform) String() string {
318
return fmt.Sprintf("dnbits: %08x, fs:%d", s.deltaNbBits, s.deltaFindState)
321
// cTable contains tables used for compression.
325
symbolTT []symbolTransform
328
// allocCtable will allocate tables needed for compression.
329
// If existing tables a re big enough, they are simply re-used.
330
func (s *Scratch) allocCtable() {
331
tableSize := 1 << s.actualTableLog
332
// get tableSymbol that is big enough.
333
if cap(s.ct.tableSymbol) < tableSize {
334
s.ct.tableSymbol = make([]byte, tableSize)
336
s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
339
if cap(s.ct.stateTable) < ctSize {
340
s.ct.stateTable = make([]uint16, ctSize)
342
s.ct.stateTable = s.ct.stateTable[:ctSize]
344
if cap(s.ct.symbolTT) < 256 {
345
s.ct.symbolTT = make([]symbolTransform, 256)
347
s.ct.symbolTT = s.ct.symbolTT[:256]
350
// buildCTable will populate the compression table so it is ready to be used.
351
func (s *Scratch) buildCTable() error {
352
tableSize := uint32(1 << s.actualTableLog)
353
highThreshold := tableSize - 1
354
var cumul [maxSymbolValue + 2]int16
357
tableSymbol := s.ct.tableSymbol[:tableSize]
358
// symbol start positions
361
for ui, v := range s.norm[:s.symbolLen-1] {
362
u := byte(ui) // one less than reference
365
cumul[u+1] = cumul[u] + 1
366
tableSymbol[highThreshold] = u
369
cumul[u+1] = cumul[u] + v
372
// Encode last symbol separately to avoid overflowing u
373
u := int(s.symbolLen - 1)
374
v := s.norm[s.symbolLen-1]
377
cumul[u+1] = cumul[u] + 1
378
tableSymbol[highThreshold] = byte(u)
381
cumul[u+1] = cumul[u] + v
383
if uint32(cumul[s.symbolLen]) != tableSize {
384
return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
386
cumul[s.symbolLen] = int16(tableSize) + 1
391
step := tableStep(tableSize)
392
tableMask := tableSize - 1
394
// if any symbol > largeLimit, we may have 0 bits output.
395
largeLimit := int16(1 << (s.actualTableLog - 1))
396
for ui, v := range s.norm[:s.symbolLen] {
401
for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
402
tableSymbol[position] = symbol
403
position = (position + step) & tableMask
404
for position > highThreshold {
405
position = (position + step) & tableMask
406
} /* Low proba area */
410
// Check if we have gone through all positions
412
return errors.New("position!=0")
417
table := s.ct.stateTable
419
tsi := int(tableSize)
420
for u, v := range tableSymbol {
421
// TableU16 : sorted by symbol order; gives next state value
422
table[cumul[v]] = uint16(tsi + u)
427
// Build Symbol Transformation Table
430
symbolTT := s.ct.symbolTT[:s.symbolLen]
431
tableLog := s.actualTableLog
432
tl := (uint32(tableLog) << 16) - (1 << tableLog)
433
for i, v := range s.norm[:s.symbolLen] {
437
symbolTT[i].deltaNbBits = tl
438
symbolTT[i].deltaFindState = int32(total - 1)
441
maxBitsOut := uint32(tableLog) - highBits(uint32(v-1))
442
minStatePlus := uint32(v) << maxBitsOut
443
symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
444
symbolTT[i].deltaFindState = int32(total - v)
448
if total != int16(tableSize) {
449
return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
455
// countSimple will create a simple histogram in s.count.
456
// Returns the biggest count.
457
// Does not update s.clearCount.
458
func (s *Scratch) countSimple(in []byte) (max int) {
459
for _, v := range in {
463
for i, v := range s.count[:] {
468
s.symbolLen = uint16(i) + 1
474
// minTableLog provides the minimum logSize to safely represent a distribution.
475
func (s *Scratch) minTableLog() uint8 {
476
minBitsSrc := highBits(uint32(s.br.remain()-1)) + 1
477
minBitsSymbols := highBits(uint32(s.symbolLen-1)) + 2
478
if minBitsSrc < minBitsSymbols {
479
return uint8(minBitsSrc)
481
return uint8(minBitsSymbols)
484
// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
485
func (s *Scratch) optimalTableLog() {
486
tableLog := s.TableLog
487
minBits := s.minTableLog()
488
maxBitsSrc := uint8(highBits(uint32(s.br.remain()-1))) - 2
489
if maxBitsSrc < tableLog {
490
// Accuracy can be reduced
491
tableLog = maxBitsSrc
493
if minBits > tableLog {
496
// Need a minimum to safely represent all symbol values
497
if tableLog < minTablelog {
498
tableLog = minTablelog
500
if tableLog > maxTableLog {
501
tableLog = maxTableLog
503
s.actualTableLog = tableLog
506
var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
508
// normalizeCount will normalize the count of the symbols so
509
// the total is equal to the table size.
510
func (s *Scratch) normalizeCount() error {
512
tableLog = s.actualTableLog
513
scale = 62 - uint64(tableLog)
514
step = (1 << 62) / uint64(s.br.remain())
515
vStep = uint64(1) << (scale - 20)
516
stillToDistribute = int16(1 << tableLog)
519
lowThreshold = (uint32)(s.br.remain() >> tableLog)
522
for i, cnt := range s.count[:s.symbolLen] {
524
// if (count[s] == s.length) return 0; /* rle special case */
530
if cnt <= lowThreshold {
534
proba := (int16)((uint64(cnt) * step) >> scale)
536
restToBeat := vStep * uint64(rtbTable[proba])
537
v := uint64(cnt)*step - (uint64(proba) << scale)
542
if proba > largestP {
547
stillToDistribute -= proba
551
if -stillToDistribute >= (s.norm[largest] >> 1) {
552
// corner case, need another normalization method
553
return s.normalizeCount2()
555
s.norm[largest] += stillToDistribute
559
// Secondary normalization method.
560
// To be used when primary method fails.
561
func (s *Scratch) normalizeCount2() error {
562
const notYetAssigned = -2
565
total = uint32(s.br.remain())
566
tableLog = s.actualTableLog
567
lowThreshold = total >> tableLog
568
lowOne = (total * 3) >> (tableLog + 1)
570
for i, cnt := range s.count[:s.symbolLen] {
575
if cnt <= lowThreshold {
587
s.norm[i] = notYetAssigned
589
toDistribute := (1 << tableLog) - distributed
591
if (total / toDistribute) > lowOne {
592
// risk of rounding to zero
593
lowOne = (total * 3) / (toDistribute * 2)
594
for i, cnt := range s.count[:s.symbolLen] {
595
if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
602
toDistribute = (1 << tableLog) - distributed
604
if distributed == uint32(s.symbolLen)+1 {
605
// all values are pretty poor;
606
// probably incompressible data (should have already been detected);
607
// find max, then give all remaining points to max
610
for i, cnt := range s.count[:s.symbolLen] {
616
s.norm[maxV] += int16(toDistribute)
621
// all of the symbols were low enough for the lowOne or lowThreshold
622
for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
632
vStepLog = 62 - uint64(tableLog)
633
mid = uint64((1 << (vStepLog - 1)) - 1)
634
rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
637
for i, cnt := range s.count[:s.symbolLen] {
638
if s.norm[i] == notYetAssigned {
640
end = tmpTotal + uint64(cnt)*rStep
641
sStart = uint32(tmpTotal >> vStepLog)
642
sEnd = uint32(end >> vStepLog)
643
weight = sEnd - sStart
646
return errors.New("weight < 1")
648
s.norm[i] = int16(weight)
655
// validateNorm validates the normalized histogram table.
656
func (s *Scratch) validateNorm() (err error) {
658
for _, v := range s.norm[:s.symbolLen] {
669
fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
670
for i, v := range s.norm[:s.symbolLen] {
671
fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
674
if total != (1 << s.actualTableLog) {
675
return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
677
for i, v := range s.count[s.symbolLen:] {
679
return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)