// lzma_rs/decode/rangecoder.rs

1use crate::decode::util;
2use crate::error;
3use byteorder::{BigEndian, ReadBytesExt};
4use std::io;
5
/// State of an LZMA range decoder reading from a borrowed buffered stream.
///
/// `range` and `code` are the decoder's two 32-bit registers; `stream` is the
/// source the decoder pulls further input bytes from.
pub struct RangeDecoder<'a, R: 'a + io::BufRead> {
    /// Buffered source of input bytes.
    pub stream: &'a mut R,
    /// Current width of the decoding interval.
    pub range: u32,
    /// Current code value taken from the input.
    pub code: u32,
}
14
15impl<'a, R> RangeDecoder<'a, R>
16where
17    R: io::BufRead,
18{
19    pub fn new(stream: &'a mut R) -> io::Result<Self> {
20        let mut dec = Self {
21            stream,
22            range: 0xFFFF_FFFF,
23            code: 0,
24        };
25        let _ = dec.stream.read_u8()?;
26        dec.code = dec.stream.read_u32::<BigEndian>()?;
27        lzma_debug!("0 {{ range: {:08x}, code: {:08x} }}", dec.range, dec.code);
28        Ok(dec)
29    }
30
31    pub fn from_parts(stream: &'a mut R, range: u32, code: u32) -> Self {
32        Self {
33            stream,
34            range,
35            code,
36        }
37    }
38
39    pub fn set(&mut self, range: u32, code: u32) {
40        self.range = range;
41        self.code = code;
42    }
43
44    pub fn read_into(&mut self, dst: &mut [u8]) -> io::Result<usize> {
45        self.stream.read(dst)
46    }
47
48    #[inline]
49    pub fn is_finished_ok(&mut self) -> io::Result<bool> {
50        Ok(self.code == 0 && self.is_eof()?)
51    }
52
53    #[inline]
54    pub fn is_eof(&mut self) -> io::Result<bool> {
55        util::is_eof(self.stream)
56    }
57
58    #[inline]
59    fn normalize(&mut self) -> io::Result<()> {
60        lzma_trace!("  {{ range: {:08x}, code: {:08x} }}", self.range, self.code);
61        if self.range < 0x0100_0000 {
62            self.range <<= 8;
63            self.code = (self.code << 8) ^ (self.stream.read_u8()? as u32);
64
65            lzma_debug!("+ {{ range: {:08x}, code: {:08x} }}", self.range, self.code);
66        }
67        Ok(())
68    }
69
70    #[inline]
71    fn get_bit(&mut self) -> error::Result<bool> {
72        self.range >>= 1;
73
74        let bit = self.code >= self.range;
75        if bit {
76            self.code -= self.range
77        }
78
79        self.normalize()?;
80        Ok(bit)
81    }
82
83    pub fn get(&mut self, count: usize) -> error::Result<u32> {
84        let mut result = 0u32;
85        for _ in 0..count {
86            result = (result << 1) ^ (self.get_bit()? as u32)
87        }
88        Ok(result)
89    }
90
91    #[inline]
92    pub fn decode_bit(&mut self, prob: &mut u16, update: bool) -> io::Result<bool> {
93        let bound: u32 = (self.range >> 11) * (*prob as u32);
94
95        lzma_trace!(
96            " bound: {:08x}, prob: {:04x}, bit: {}",
97            bound,
98            prob,
99            (self.code > bound) as u8
100        );
101        if self.code < bound {
102            if update {
103                *prob += (0x800_u16 - *prob) >> 5;
104            }
105            self.range = bound;
106
107            self.normalize()?;
108            Ok(false)
109        } else {
110            if update {
111                *prob -= *prob >> 5;
112            }
113            self.code -= bound;
114            self.range -= bound;
115
116            self.normalize()?;
117            Ok(true)
118        }
119    }
120
121    fn parse_bit_tree(
122        &mut self,
123        num_bits: usize,
124        probs: &mut [u16],
125        update: bool,
126    ) -> io::Result<u32> {
127        let mut tmp: u32 = 1;
128        for _ in 0..num_bits {
129            let bit = self.decode_bit(&mut probs[tmp as usize], update)?;
130            tmp = (tmp << 1) ^ (bit as u32);
131        }
132        Ok(tmp - (1 << num_bits))
133    }
134
135    pub fn parse_reverse_bit_tree(
136        &mut self,
137        num_bits: usize,
138        probs: &mut [u16],
139        offset: usize,
140        update: bool,
141    ) -> io::Result<u32> {
142        let mut result = 0u32;
143        let mut tmp: usize = 1;
144        for i in 0..num_bits {
145            let bit = self.decode_bit(&mut probs[offset + tmp], update)?;
146            tmp = (tmp << 1) ^ (bit as usize);
147            result ^= (bit as u32) << i;
148        }
149        Ok(result)
150    }
151}
152
// TODO: make the tree depth a const generic and store `[u16; 1 << N]` inline once
// array lengths may depend on generic parameters (currently the unstable
// `generic_const_exprs` feature).
/// Probability table for a binary tree that decodes `num_bits`-bit symbols.
#[derive(Clone, Debug)]
pub struct BitTree {
    /// Depth of the tree, i.e. bits produced per decoded symbol.
    num_bits: usize,
    /// Adaptive probabilities indexed by tree node.
    probs: Vec<u16>,
}
159
160impl BitTree {
161    pub fn new(num_bits: usize) -> Self {
162        BitTree {
163            num_bits,
164            probs: vec![0x400; 1 << num_bits],
165        }
166    }
167
168    pub fn parse<R: io::BufRead>(
169        &mut self,
170        rangecoder: &mut RangeDecoder<R>,
171        update: bool,
172    ) -> io::Result<u32> {
173        rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update)
174    }
175
176    pub fn parse_reverse<R: io::BufRead>(
177        &mut self,
178        rangecoder: &mut RangeDecoder<R>,
179        update: bool,
180    ) -> io::Result<u32> {
181        rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update)
182    }
183
184    pub fn reset(&mut self) {
185        self.probs.fill(0x400);
186    }
187}
188
189#[derive(Debug)]
190pub struct LenDecoder {
191    choice: u16,
192    choice2: u16,
193    low_coder: [BitTree; 16],
194    mid_coder: [BitTree; 16],
195    high_coder: BitTree,
196}
197
198impl LenDecoder {
199    pub fn new() -> Self {
200        LenDecoder {
201            choice: 0x400,
202            choice2: 0x400,
203            low_coder: [
204                BitTree::new(3),
205                BitTree::new(3),
206                BitTree::new(3),
207                BitTree::new(3),
208                BitTree::new(3),
209                BitTree::new(3),
210                BitTree::new(3),
211                BitTree::new(3),
212                BitTree::new(3),
213                BitTree::new(3),
214                BitTree::new(3),
215                BitTree::new(3),
216                BitTree::new(3),
217                BitTree::new(3),
218                BitTree::new(3),
219                BitTree::new(3),
220            ],
221            mid_coder: [
222                BitTree::new(3),
223                BitTree::new(3),
224                BitTree::new(3),
225                BitTree::new(3),
226                BitTree::new(3),
227                BitTree::new(3),
228                BitTree::new(3),
229                BitTree::new(3),
230                BitTree::new(3),
231                BitTree::new(3),
232                BitTree::new(3),
233                BitTree::new(3),
234                BitTree::new(3),
235                BitTree::new(3),
236                BitTree::new(3),
237                BitTree::new(3),
238            ],
239            high_coder: BitTree::new(8),
240        }
241    }
242
243    pub fn decode<R: io::BufRead>(
244        &mut self,
245        rangecoder: &mut RangeDecoder<R>,
246        pos_state: usize,
247        update: bool,
248    ) -> io::Result<usize> {
249        if !rangecoder.decode_bit(&mut self.choice, update)? {
250            Ok(self.low_coder[pos_state].parse(rangecoder, update)? as usize)
251        } else if !rangecoder.decode_bit(&mut self.choice2, update)? {
252            Ok(self.mid_coder[pos_state].parse(rangecoder, update)? as usize + 8)
253        } else {
254            Ok(self.high_coder.parse(rangecoder, update)? as usize + 16)
255        }
256    }
257
258    pub fn reset(&mut self) {
259        self.choice = 0x400;
260        self.choice2 = 0x400;
261        self.low_coder.iter_mut().for_each(|t| t.reset());
262        self.mid_coder.iter_mut().for_each(|t| t.reset());
263        self.high_coder.reset();
264    }
265}