1use byteorder::WriteBytesExt;
2use std::io;
3
/// Range encoder that emits its compressed output byte by byte into a
/// borrowed writer.
pub struct RangeEncoder<'a, W: 'a + io::Write> {
    /// Destination for the encoded bytes.
    stream: &'a mut W,
    /// Width of the current coding interval.
    range: u32,
    /// Lower end of the current coding interval. Kept in 64 bits so that a
    /// carry out of the low 32 bits stays observable until it is written.
    low: u64,
    /// Most recent output byte, held back until a possible carry is known.
    cache: u8,
    /// Number of output bytes currently held back (the cached byte plus any
    /// deferred bytes that follow it).
    cachesz: u32,
}
14
15impl<'a, W> RangeEncoder<'a, W>
16where
17 W: io::Write,
18{
19 #[allow(clippy::let_and_return)]
20 pub fn new(stream: &'a mut W) -> Self {
21 let enc = Self {
22 stream,
23 range: 0xFFFF_FFFF,
24 low: 0,
25 cache: 0,
26 cachesz: 1,
27 };
28 lzma_debug!("0 {{ range: {:08x}, low: {:010x} }}", enc.range, enc.low);
29 enc
30 }
31
32 fn write_low(&mut self) -> io::Result<()> {
33 if self.low < 0xFF00_0000 || self.low > 0xFFFF_FFFF {
34 let mut tmp = self.cache;
35 loop {
36 let byte = tmp.wrapping_add((self.low >> 32) as u8);
37 self.stream.write_u8(byte)?;
38 lzma_debug!("> byte: {:02x}", byte);
39 tmp = 0xFF;
40 self.cachesz -= 1;
41 if self.cachesz == 0 {
42 break;
43 }
44 }
45 self.cache = (self.low >> 24) as u8;
46 }
47
48 self.cachesz += 1;
49 self.low = (self.low << 8) & 0xFFFF_FFFF;
50 Ok(())
51 }
52
53 pub fn finish(&mut self) -> io::Result<()> {
54 for _ in 0..5 {
55 self.write_low()?;
56
57 lzma_debug!("$ {{ range: {:08x}, low: {:010x} }}", self.range, self.low);
58 }
59 Ok(())
60 }
61
62 fn normalize(&mut self) -> io::Result<()> {
63 while self.range < 0x0100_0000 {
64 lzma_debug!(
65 "+ {{ range: {:08x}, low: {:010x}, cache: {:02x}, {} }}",
66 self.range,
67 self.low,
68 self.cache,
69 self.cachesz
70 );
71 self.range <<= 8;
72 self.write_low()?;
73 lzma_debug!(
74 "* {{ range: {:08x}, low: {:010x}, cache: {:02x}, {} }}",
75 self.range,
76 self.low,
77 self.cache,
78 self.cachesz
79 );
80 }
81 lzma_trace!(" {{ range: {:08x}, low: {:010x} }}", self.range, self.low);
82 Ok(())
83 }
84
85 pub fn encode_bit(&mut self, prob: &mut u16, bit: bool) -> io::Result<()> {
86 let bound: u32 = (self.range >> 11) * (*prob as u32);
87 lzma_trace!(
88 " bound: {:08x}, prob: {:04x}, bit: {}",
89 bound,
90 prob,
91 bit as u8
92 );
93
94 if bit {
95 *prob -= *prob >> 5;
96 self.low += bound as u64;
97 self.range -= bound;
98 } else {
99 *prob += (0x800_u16 - *prob) >> 5;
100 self.range = bound;
101 }
102
103 self.normalize()
104 }
105
106 #[cfg(test)]
107 fn encode_bit_tree(
108 &mut self,
109 num_bits: usize,
110 probs: &mut [u16],
111 value: u32,
112 ) -> io::Result<()> {
113 debug_assert!(value.leading_zeros() as usize + num_bits >= 32);
114 let mut tmp: usize = 1;
115 for i in 0..num_bits {
116 let bit = ((value >> (num_bits - i - 1)) & 1) != 0;
117 self.encode_bit(&mut probs[tmp], bit)?;
118 tmp = (tmp << 1) ^ (bit as usize);
119 }
120 Ok(())
121 }
122
123 #[cfg(test)]
124 pub fn encode_reverse_bit_tree(
125 &mut self,
126 num_bits: usize,
127 probs: &mut [u16],
128 offset: usize,
129 mut value: u32,
130 ) -> io::Result<()> {
131 debug_assert!(value.leading_zeros() as usize + num_bits >= 32);
132 let mut tmp: usize = 1;
133 for _ in 0..num_bits {
134 let bit = (value & 1) != 0;
135 value >>= 1;
136 self.encode_bit(&mut probs[offset + tmp], bit)?;
137 tmp = (tmp << 1) ^ (bit as usize);
138 }
139 Ok(())
140 }
141}
142
/// Test-only encoder-side bit tree: a bit width together with the adaptive
/// probabilities used while encoding values of that width.
#[cfg(test)]
#[derive(Clone)]
pub struct BitTree {
    /// Number of bits encoded per value.
    num_bits: usize,
    /// One probability per tree node, indexed by node number.
    probs: Vec<u16>,
}
150
#[cfg(test)]
impl BitTree {
    /// Builds a tree for `num_bits`-bit values with `1 << num_bits`
    /// probabilities, each initialised to `0x400`.
    pub fn new(num_bits: usize) -> Self {
        let node_count: usize = 1 << num_bits;
        Self {
            num_bits,
            probs: vec![0x400; node_count],
        }
    }

    /// Encodes `value` through `rangecoder`, most significant bit first.
    pub fn encode<W>(&mut self, rangecoder: &mut RangeEncoder<W>, value: u32) -> io::Result<()>
    where
        W: io::Write,
    {
        let width = self.num_bits;
        rangecoder.encode_bit_tree(width, &mut self.probs, value)
    }

    /// Encodes `value` through `rangecoder`, least significant bit first,
    /// using the probabilities from offset 0.
    pub fn encode_reverse<W>(
        &mut self,
        rangecoder: &mut RangeEncoder<W>,
        value: u32,
    ) -> io::Result<()>
    where
        W: io::Write,
    {
        let width = self.num_bits;
        rangecoder.encode_reverse_bit_tree(width, &mut self.probs, 0, value)
    }
}
176
/// Test-only length encoder: two choice probabilities that select between
/// per-position low and mid trees and one shared high tree.
#[cfg(test)]
pub struct LenEncoder {
    /// Probability for the first choice bit (low range vs. everything else).
    choice: u16,
    /// Probability for the second choice bit (mid range vs. high range).
    choice2: u16,
    /// Trees for the low range, indexed by position state.
    low_coder: Vec<BitTree>,
    /// Trees for the mid range, indexed by position state.
    mid_coder: Vec<BitTree>,
    /// Single tree for the high range.
    high_coder: BitTree,
}
185
#[cfg(test)]
impl LenEncoder {
    /// Creates a length encoder with both choice probabilities at `0x400`,
    /// sixteen 3-bit trees each for the low and mid ranges, and one 8-bit
    /// tree for the high range.
    pub fn new() -> Self {
        let per_state_trees = vec![BitTree::new(3); 16];
        Self {
            choice: 0x400,
            choice2: 0x400,
            low_coder: per_state_trees.clone(),
            mid_coder: per_state_trees,
            high_coder: BitTree::new(8),
        }
    }

    /// Encodes `value` for the given `pos_state`.
    ///
    /// Values below 8 go to the low tree, values below 16 go to the mid tree
    /// (offset by 8), and everything else goes to the high tree (offset by
    /// 16). The choice bits are written first to announce which tree follows.
    pub fn encode<W>(
        &mut self,
        rangecoder: &mut RangeEncoder<W>,
        pos_state: usize,
        value: u32,
    ) -> io::Result<()>
    where
        W: io::Write,
    {
        match value {
            0..=7 => {
                rangecoder.encode_bit(&mut self.choice, false)?;
                self.low_coder[pos_state].encode(rangecoder, value)
            }
            8..=15 => {
                rangecoder.encode_bit(&mut self.choice, true)?;
                rangecoder.encode_bit(&mut self.choice2, false)?;
                self.mid_coder[pos_state].encode(rangecoder, value - 8)
            }
            _ => {
                rangecoder.encode_bit(&mut self.choice, true)?;
                rangecoder.encode_bit(&mut self.choice2, true)?;
                self.high_coder.encode(rangecoder, value - 16)
            }
        }
    }
}
219
#[cfg(test)]
mod test {
    use super::*;
    use crate::decode::rangecoder::{LenDecoder, RangeDecoder};
    use crate::{decode, encode};
    use std::io::BufReader;

    /// Round-trips `bits` through the encoder and decoder, both starting
    /// from the same adaptive probability `prob_init`.
    fn encode_decode(prob_init: u16, bits: &[bool]) {
        let mut buf: Vec<u8> = vec![];

        {
            let mut encoder = RangeEncoder::new(&mut buf);
            let mut enc_prob = prob_init;
            for bit in bits.iter().copied() {
                encoder.encode_bit(&mut enc_prob, bit).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut bufread = BufReader::new(buf.as_slice());
        let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
        let mut dec_prob = prob_init;
        for expected in bits.iter().copied() {
            assert_eq!(decoder.decode_bit(&mut dec_prob, true).unwrap(), expected);
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_zeros() {
        let zeros = [false; 10000];
        encode_decode(0x400, &zeros);
    }

    #[test]
    fn test_encode_decode_ones() {
        let ones = [true; 10000];
        encode_decode(0x400, &ones);
    }

    /// Round-trips `values` through the MSB-first bit tree coders.
    fn encode_decode_bittree(num_bits: usize, values: &[u32]) {
        let mut buf: Vec<u8> = vec![];

        {
            let mut encoder = RangeEncoder::new(&mut buf);
            let mut enc_tree = encode::rangecoder::BitTree::new(num_bits);
            for value in values.iter().copied() {
                enc_tree.encode(&mut encoder, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut bufread = BufReader::new(buf.as_slice());
        let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
        let mut dec_tree = decode::rangecoder::BitTree::new(num_bits);
        for expected in values.iter().copied() {
            assert_eq!(dec_tree.parse(&mut decoder, true).unwrap(), expected);
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_bittree_zeros() {
        let zeros = [0u32; 10000];
        for num_bits in 0..16 {
            encode_decode_bittree(num_bits, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_bittree_ones() {
        for num_bits in 0..16 {
            let all_ones: u32 = (1 << num_bits) - 1;
            encode_decode_bittree(num_bits, &[all_ones; 10000]);
        }
    }

    #[test]
    fn test_encode_decode_bittree_all() {
        for num_bits in 0..16 {
            let upper: u32 = 1 << num_bits;
            let values: Vec<u32> = (0..upper).collect();
            encode_decode_bittree(num_bits, &values);
        }
    }

    /// Round-trips `values` through the LSB-first (reverse) bit tree coders.
    fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) {
        let mut buf: Vec<u8> = vec![];

        {
            let mut encoder = RangeEncoder::new(&mut buf);
            let mut enc_tree = encode::rangecoder::BitTree::new(num_bits);
            for value in values.iter().copied() {
                enc_tree.encode_reverse(&mut encoder, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut bufread = BufReader::new(buf.as_slice());
        let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
        let mut dec_tree = decode::rangecoder::BitTree::new(num_bits);
        for expected in values.iter().copied() {
            assert_eq!(
                dec_tree.parse_reverse(&mut decoder, true).unwrap(),
                expected
            );
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_reverse_bittree_zeros() {
        let zeros = [0u32; 10000];
        for num_bits in 0..16 {
            encode_decode_reverse_bittree(num_bits, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_reverse_bittree_ones() {
        for num_bits in 0..16 {
            let all_ones: u32 = (1 << num_bits) - 1;
            encode_decode_reverse_bittree(num_bits, &[all_ones; 10000]);
        }
    }

    #[test]
    fn test_encode_decode_reverse_bittree_all() {
        for num_bits in 0..16 {
            let upper: u32 = 1 << num_bits;
            let values: Vec<u32> = (0..upper).collect();
            encode_decode_reverse_bittree(num_bits, &values);
        }
    }

    /// Round-trips length `values` through the length coders for one
    /// position state.
    fn encode_decode_length(pos_state: usize, values: &[u32]) {
        let mut buf: Vec<u8> = vec![];

        {
            let mut encoder = RangeEncoder::new(&mut buf);
            let mut len_encoder = LenEncoder::new();
            for value in values.iter().copied() {
                len_encoder.encode(&mut encoder, pos_state, value).unwrap();
            }
            encoder.finish().unwrap();
        }

        let mut bufread = BufReader::new(buf.as_slice());
        let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
        let mut len_decoder = LenDecoder::new();
        for expected in values.iter().copied() {
            assert_eq!(
                len_decoder.decode(&mut decoder, pos_state, true).unwrap(),
                expected as usize
            );
        }
        assert!(decoder.is_finished_ok().unwrap());
    }

    #[test]
    fn test_encode_decode_length_zeros() {
        let zeros = [0u32; 10000];
        for pos_state in 0..16 {
            encode_decode_length(pos_state, &zeros);
        }
    }

    #[test]
    fn test_encode_decode_length_all() {
        for pos_state in 0..16 {
            // Covers the low (0..8), mid (8..16) and full 8-bit high ranges.
            let upper: u32 = (1 << 8) + 16;
            let values: Vec<u32> = (0..upper).collect();
            encode_decode_length(pos_state, &values);
        }
    }
}