lzma_rs/encode/
rangecoder.rs

1use byteorder::WriteBytesExt;
2use std::io;
3
/// Range encoder that writes its output bytes into a borrowed writer.
pub struct RangeEncoder<'a, W: 'a + io::Write> {
    /// Destination for the encoded bytes.
    stream: &'a mut W,
    /// Width of the current coding interval.
    range: u32,
    /// Lower end of the current coding interval. It is wider than 32 bits so
    /// that an addition overflowing into bit 32 is kept as a carry.
    low: u64,
    /// Most recent output byte, held back because a later carry may still
    /// have to be added to it.
    cache: u8,
    /// Number of bytes currently held back: the cached byte plus any `0xFF`
    /// bytes queued behind it.
    cachesz: u32,
}
14
15impl<'a, W> RangeEncoder<'a, W>
16where
17    W: io::Write,
18{
19    #[allow(clippy::let_and_return)]
20    pub fn new(stream: &'a mut W) -> Self {
21        let enc = Self {
22            stream,
23            range: 0xFFFF_FFFF,
24            low: 0,
25            cache: 0,
26            cachesz: 1,
27        };
28        lzma_debug!("0 {{ range: {:08x}, low: {:010x} }}", enc.range, enc.low);
29        enc
30    }
31
32    fn write_low(&mut self) -> io::Result<()> {
33        if self.low < 0xFF00_0000 || self.low > 0xFFFF_FFFF {
34            let mut tmp = self.cache;
35            loop {
36                let byte = tmp.wrapping_add((self.low >> 32) as u8);
37                self.stream.write_u8(byte)?;
38                lzma_debug!("> byte: {:02x}", byte);
39                tmp = 0xFF;
40                self.cachesz -= 1;
41                if self.cachesz == 0 {
42                    break;
43                }
44            }
45            self.cache = (self.low >> 24) as u8;
46        }
47
48        self.cachesz += 1;
49        self.low = (self.low << 8) & 0xFFFF_FFFF;
50        Ok(())
51    }
52
53    pub fn finish(&mut self) -> io::Result<()> {
54        for _ in 0..5 {
55            self.write_low()?;
56
57            lzma_debug!("$ {{ range: {:08x}, low: {:010x} }}", self.range, self.low);
58        }
59        Ok(())
60    }
61
62    fn normalize(&mut self) -> io::Result<()> {
63        while self.range < 0x0100_0000 {
64            lzma_debug!(
65                "+ {{ range: {:08x}, low: {:010x}, cache: {:02x}, {} }}",
66                self.range,
67                self.low,
68                self.cache,
69                self.cachesz
70            );
71            self.range <<= 8;
72            self.write_low()?;
73            lzma_debug!(
74                "* {{ range: {:08x}, low: {:010x}, cache: {:02x}, {} }}",
75                self.range,
76                self.low,
77                self.cache,
78                self.cachesz
79            );
80        }
81        lzma_trace!("  {{ range: {:08x}, low: {:010x} }}", self.range, self.low);
82        Ok(())
83    }
84
85    pub fn encode_bit(&mut self, prob: &mut u16, bit: bool) -> io::Result<()> {
86        let bound: u32 = (self.range >> 11) * (*prob as u32);
87        lzma_trace!(
88            "  bound: {:08x}, prob: {:04x}, bit: {}",
89            bound,
90            prob,
91            bit as u8
92        );
93
94        if bit {
95            *prob -= *prob >> 5;
96            self.low += bound as u64;
97            self.range -= bound;
98        } else {
99            *prob += (0x800_u16 - *prob) >> 5;
100            self.range = bound;
101        }
102
103        self.normalize()
104    }
105
106    #[cfg(test)]
107    fn encode_bit_tree(
108        &mut self,
109        num_bits: usize,
110        probs: &mut [u16],
111        value: u32,
112    ) -> io::Result<()> {
113        debug_assert!(value.leading_zeros() as usize + num_bits >= 32);
114        let mut tmp: usize = 1;
115        for i in 0..num_bits {
116            let bit = ((value >> (num_bits - i - 1)) & 1) != 0;
117            self.encode_bit(&mut probs[tmp], bit)?;
118            tmp = (tmp << 1) ^ (bit as usize);
119        }
120        Ok(())
121    }
122
123    #[cfg(test)]
124    pub fn encode_reverse_bit_tree(
125        &mut self,
126        num_bits: usize,
127        probs: &mut [u16],
128        offset: usize,
129        mut value: u32,
130    ) -> io::Result<()> {
131        debug_assert!(value.leading_zeros() as usize + num_bits >= 32);
132        let mut tmp: usize = 1;
133        for _ in 0..num_bits {
134            let bit = (value & 1) != 0;
135            value >>= 1;
136            self.encode_bit(&mut probs[offset + tmp], bit)?;
137            tmp = (tmp << 1) ^ (bit as usize);
138        }
139        Ok(())
140    }
141}
142
// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this
/// Test-only set of adaptive probabilities for encoding fixed-width values
/// bit by bit through a [`RangeEncoder`].
#[cfg(test)]
#[derive(Clone)]
pub struct BitTree {
    /// Number of bits encoded per value.
    num_bits: usize,
    /// One probability per tree node; holds `1 << num_bits` entries.
    probs: Vec<u16>,
}

#[cfg(test)]
impl BitTree {
    /// Builds a tree for `num_bits`-bit values with every probability set to
    /// `0x400`.
    pub fn new(num_bits: usize) -> Self {
        let node_count: usize = 1 << num_bits;
        Self {
            num_bits,
            probs: vec![0x400; node_count],
        }
    }

    /// Encodes `value` with its most significant bit first.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `rangecoder`.
    pub fn encode<W: io::Write>(
        &mut self,
        rangecoder: &mut RangeEncoder<W>,
        value: u32,
    ) -> io::Result<()> {
        let width = self.num_bits;
        rangecoder.encode_bit_tree(width, &mut self.probs[..], value)
    }

    /// Encodes `value` with its least significant bit first.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `rangecoder`.
    pub fn encode_reverse<W: io::Write>(
        &mut self,
        rangecoder: &mut RangeEncoder<W>,
        value: u32,
    ) -> io::Result<()> {
        let width = self.num_bits;
        // The tree occupies the whole slice, hence the zero offset.
        rangecoder.encode_reverse_bit_tree(width, &mut self.probs[..], 0, value)
    }
}
176
/// Test-only encoder for length values, split into three ranges that are
/// selected by up to two leading choice bits.
#[cfg(test)]
pub struct LenEncoder {
    /// Probability of the first selector bit (set when `value >= 8`).
    choice: u16,
    /// Probability of the second selector bit (set when `value >= 16`).
    choice2: u16,
    /// 3-bit trees for values `0..8`, indexed by `pos_state`.
    low_coder: Vec<BitTree>,
    /// 3-bit trees for values `8..16`, indexed by `pos_state`.
    mid_coder: Vec<BitTree>,
    /// Single 8-bit tree for values from 16 upwards (stored as `value - 16`).
    high_coder: BitTree,
}

#[cfg(test)]
impl LenEncoder {
    /// Creates an encoder with all probabilities at `0x400` and 16 trees in
    /// each of the low and middle banks.
    pub fn new() -> Self {
        let pos_state_bank = || vec![BitTree::new(3); 16];
        Self {
            choice: 0x400,
            choice2: 0x400,
            low_coder: pos_state_bank(),
            mid_coder: pos_state_bank(),
            high_coder: BitTree::new(8),
        }
    }

    /// Encodes `value` using the trees selected by `pos_state`.
    ///
    /// Values below 8 use the low bank, values below 16 the middle bank, and
    /// everything else the shared high tree.
    ///
    /// # Panics
    ///
    /// Panics if a low or middle value is encoded with `pos_state >= 16`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `rangecoder`.
    pub fn encode<W: io::Write>(
        &mut self,
        rangecoder: &mut RangeEncoder<W>,
        pos_state: usize,
        value: u32,
    ) -> io::Result<()> {
        match value {
            0..=7 => {
                rangecoder.encode_bit(&mut self.choice, false)?;
                self.low_coder[pos_state].encode(rangecoder, value)
            }
            8..=15 => {
                rangecoder.encode_bit(&mut self.choice, true)?;
                rangecoder.encode_bit(&mut self.choice2, false)?;
                self.mid_coder[pos_state].encode(rangecoder, value - 8)
            }
            _ => {
                rangecoder.encode_bit(&mut self.choice, true)?;
                rangecoder.encode_bit(&mut self.choice2, true)?;
                self.high_coder.encode(rangecoder, value - 16)
            }
        }
    }
}
219
#[cfg(test)]
mod test {
    use super::*;
    use crate::decode::rangecoder::{LenDecoder, RangeDecoder};
    use crate::{decode, encode};
    use std::io::BufReader;

    /// Encodes `bits` with a single adaptive probability starting at
    /// `prob_init`, then decodes the output and checks every bit.
    fn encode_decode(prob_init: u16, bits: &[bool]) {
        let mut encoded: Vec<u8> = vec![];
        {
            let mut encoder = RangeEncoder::new(&mut encoded);
            let mut enc_prob = prob_init;
            for bit in bits.iter().copied() {
                encoder.encode_bit(&mut enc_prob, bit).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut reader = BufReader::new(encoded.as_slice());
        let mut decoder = RangeDecoder::new(&mut reader).unwrap();
        let mut dec_prob = prob_init;
        for expected in bits.iter().copied() {
            assert_eq!(decoder.decode_bit(&mut dec_prob, true).unwrap(), expected);
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_zeros() {
        let all_false = [false; 10000];
        encode_decode(0x400, &all_false);
    }

    #[test]
    fn test_encode_decode_ones() {
        let all_true = [true; 10000];
        encode_decode(0x400, &all_true);
    }

    /// Round-trips `values` through the encoding and decoding bit trees,
    /// most significant bit first.
    fn encode_decode_bittree(num_bits: usize, values: &[u32]) {
        let mut encoded: Vec<u8> = vec![];
        {
            let mut encoder = RangeEncoder::new(&mut encoded);
            let mut enc_tree = encode::rangecoder::BitTree::new(num_bits);
            for value in values.iter().copied() {
                enc_tree.encode(&mut encoder, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut reader = BufReader::new(encoded.as_slice());
        let mut decoder = RangeDecoder::new(&mut reader).unwrap();
        let mut dec_tree = decode::rangecoder::BitTree::new(num_bits);
        for expected in values.iter().copied() {
            assert_eq!(dec_tree.parse(&mut decoder, true).unwrap(), expected);
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_bittree_zeros() {
        let zeros = [0u32; 10000];
        for num_bits in 0..16 {
            encode_decode_bittree(num_bits, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_bittree_ones() {
        for num_bits in 0..16 {
            // Largest value that fits in `num_bits` bits.
            let all_ones = [(1u32 << num_bits) - 1; 10000];
            encode_decode_bittree(num_bits, &all_ones);
        }
    }

    #[test]
    fn test_encode_decode_bittree_all() {
        for num_bits in 0..16 {
            // Every value representable in `num_bits` bits, in ascending order.
            let values: Vec<u32> = (0..(1u32 << num_bits)).collect();
            encode_decode_bittree(num_bits, &values);
        }
    }

    /// Round-trips `values` through the encoding and decoding bit trees,
    /// least significant bit first.
    fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) {
        let mut encoded: Vec<u8> = vec![];
        {
            let mut encoder = RangeEncoder::new(&mut encoded);
            let mut enc_tree = encode::rangecoder::BitTree::new(num_bits);
            for value in values.iter().copied() {
                enc_tree.encode_reverse(&mut encoder, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut reader = BufReader::new(encoded.as_slice());
        let mut decoder = RangeDecoder::new(&mut reader).unwrap();
        let mut dec_tree = decode::rangecoder::BitTree::new(num_bits);
        for expected in values.iter().copied() {
            assert_eq!(
                dec_tree.parse_reverse(&mut decoder, true).unwrap(),
                expected
            );
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_reverse_bittree_zeros() {
        let zeros = [0u32; 10000];
        for num_bits in 0..16 {
            encode_decode_reverse_bittree(num_bits, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_reverse_bittree_ones() {
        for num_bits in 0..16 {
            // Largest value that fits in `num_bits` bits.
            let all_ones = [(1u32 << num_bits) - 1; 10000];
            encode_decode_reverse_bittree(num_bits, &all_ones);
        }
    }

    #[test]
    fn test_encode_decode_reverse_bittree_all() {
        for num_bits in 0..16 {
            // Every value representable in `num_bits` bits, in ascending order.
            let values: Vec<u32> = (0..(1u32 << num_bits)).collect();
            encode_decode_reverse_bittree(num_bits, &values);
        }
    }

    /// Round-trips `values` through `LenEncoder` and `LenDecoder` for one
    /// fixed `pos_state`.
    fn encode_decode_length(pos_state: usize, values: &[u32]) {
        let mut encoded: Vec<u8> = vec![];
        {
            let mut encoder = RangeEncoder::new(&mut encoded);
            let mut len_encoder = LenEncoder::new();
            for value in values.iter().copied() {
                len_encoder.encode(&mut encoder, pos_state, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut reader = BufReader::new(encoded.as_slice());
        let mut decoder = RangeDecoder::new(&mut reader).unwrap();
        let mut len_decoder = LenDecoder::new();
        for expected in values.iter().copied() {
            assert_eq!(
                len_decoder.decode(&mut decoder, pos_state, true).unwrap(),
                expected as usize
            );
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_length_zeros() {
        let zeros = [0u32; 10000];
        for pos_state in 0..16 {
            encode_decode_length(pos_state, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_length_all() {
        // 16 values for the low and middle ranges plus 256 for the 8-bit high tree.
        let value_count: u32 = (1 << 8) + 16;
        for pos_state in 0..16 {
            let values: Vec<u32> = (0..value_count).collect();
            encode_decode_length(pos_state, &values);
        }
    }
}