use crate::decode::lzbuffer::{LzBuffer, LzCircularBuffer};
use crate::decode::rangecoder::{BitTree, LenDecoder, RangeDecoder};
use crate::decompress::{Options, UnpackedSize};
use crate::error;
use crate::util::vec2d::Vec2D;
use byteorder::{LittleEndian, ReadBytesExt};
use std::io;
/// Size in bytes of the carry-over input buffer (`DecoderState::partial_input_buf`).
///
/// `DecoderState::process_mode` also uses it as a threshold: in partial (streaming) mode, when
/// fewer than this many input bytes are available, a dry-run decode is attempted before
/// committing to a real decode step.
/// NOTE(review): presumably an upper bound on the input a single decode step can consume — confirm.
const MAX_REQUIRED_INPUT: usize = 20;
/// How `DecoderState::process_mode` should treat the end of the currently available input.
#[derive(Debug, PartialEq)]
enum ProcessingMode {
/// Streaming mode (selected by `process_stream`): when too few bytes are available to decode
/// the next packet, processing returns successfully so it can resume later, and the final
/// unpacked-size check is skipped.
Partial,
/// One-shot mode (selected by `process`): decoding runs to the end of the stream and the
/// decoded length is checked against the expected unpacked size, if one is known.
Finish,
}
/// Outcome of decoding a single packet in `DecoderState::process_next_inner`.
#[derive(Debug, PartialEq)]
enum ProcessingStatus {
/// A packet was decoded (or dry-run) and decoding may go on.
Continue,
/// The end-of-stream marker was decoded and the range decoder reported a clean finish.
Finished,
}
/// The three LZMA coding parameters that size and index the decoder's probability tables.
#[derive(Debug, Copy, Clone)]
pub struct LzmaProperties {
    /// Number of high bits of the previous byte used to select a literal context (at most 8).
    pub lc: u32,
    /// Number of low bits of the output position used to select a literal context (at most 4).
    pub lp: u32,
    /// Number of low bits of the output position used as the position state (at most 4).
    pub pb: u32,
}

impl LzmaProperties {
    /// Checks that every property is within the range the decoder supports.
    ///
    /// # Panics
    ///
    /// Panics if `lc > 8`, `lp > 4` or `pb > 4`.
    pub(crate) fn validate(&self) {
        // The assertion expressions are part of the panic message, so they are spelled out
        // one property at a time.
        assert!(self.lc <= 8);
        assert!(self.lp <= 4);
        assert!(self.pb <= 4);
    }
}
/// Everything needed to set up a decoder for one LZMA stream, as produced by
/// `LzmaParams::read_header`.
#[derive(Debug, Copy, Clone)]
pub struct LzmaParams {
/// The `lc` / `lp` / `pb` coding properties.
pub(crate) properties: LzmaProperties,
/// Dictionary size in bytes; `read_header` never stores a value below `0x1000`.
pub(crate) dict_size: u32,
/// Expected size of the decompressed data, or `None` when it is not known in advance.
pub(crate) unpacked_size: Option<u64>,
}
impl LzmaParams {
#[cfg(feature = "raw_decoder")]
pub fn new(
properties: LzmaProperties,
dict_size: u32,
unpacked_size: Option<u64>,
) -> LzmaParams {
Self {
properties,
dict_size,
unpacked_size,
}
}
pub fn read_header<R>(input: &mut R, options: &Options) -> error::Result<LzmaParams>
where
R: io::BufRead,
{
let props = input.read_u8().map_err(error::Error::HeaderTooShort)?;
let mut pb = props as u32;
if pb >= 225 {
return Err(error::Error::LzmaError(format!(
"LZMA header invalid properties: {} must be < 225",
pb
)));
}
let lc: u32 = pb % 9;
pb /= 9;
let lp: u32 = pb % 5;
pb /= 5;
lzma_info!("Properties {{ lc: {}, lp: {}, pb: {} }}", lc, lp, pb);
let dict_size_provided = input
.read_u32::<LittleEndian>()
.map_err(error::Error::HeaderTooShort)?;
let dict_size = if dict_size_provided < 0x1000 {
0x1000
} else {
dict_size_provided
};
lzma_info!("Dict size: {}", dict_size);
let unpacked_size: Option<u64> = match options.unpacked_size {
UnpackedSize::ReadFromHeader => {
let unpacked_size_provided = input
.read_u64::<LittleEndian>()
.map_err(error::Error::HeaderTooShort)?;
let marker_mandatory: bool = unpacked_size_provided == 0xFFFF_FFFF_FFFF_FFFF;
if marker_mandatory {
None
} else {
Some(unpacked_size_provided)
}
}
UnpackedSize::ReadHeaderButUseProvided(x) => {
input
.read_u64::<LittleEndian>()
.map_err(error::Error::HeaderTooShort)?;
x
}
UnpackedSize::UseProvided(x) => x,
};
lzma_info!("Unpacked size: {:?}", unpacked_size);
let params = LzmaParams {
properties: LzmaProperties { lc, lp, pb },
dict_size,
unpacked_size,
};
Ok(params)
}
}
/// Mutable state of the LZMA decoder: the probability tables, the packet state machine and
/// the recent-distance history, plus a small buffer for input carried over between calls.
#[derive(Debug)]
pub(crate) struct DecoderState {
/// Input bytes carried over between `process_mode` calls; the cursor position is the number
/// of valid bytes currently held.
partial_input_buf: std::io::Cursor<[u8; MAX_REQUIRED_INPUT]>,
/// Coding properties currently in effect.
pub(crate) lzma_props: LzmaProperties,
/// Expected decompressed size, or `None` if unknown.
unpacked_size: Option<u64>,
/// Literal probabilities: `1 << (lc + lp)` rows of `0x300` entries, initialised to `0x400`.
literal_probs: Vec2D<u16>,
/// Distance-slot decoders, selected by match length clamped to `0..=3`.
pos_slot_decoder: [BitTree; 4],
/// Decoder for the low 4 bits of distances whose slot is 14 or above.
align_decoder: BitTree,
/// Probabilities for the reverse bit tree used by distance slots 4 to 13.
pos_decoders: [u16; 115],
// `is_match` is indexed by `(state << 4) + pos_state` and distinguishes literal from match;
// `is_rep` is indexed by `state` and distinguishes a repeated distance from a new one.
is_match: [u16; 192], is_rep: [u16; 12],
/// Indexed by `state`: bit clear means the repeat uses `rep[0]`.
is_rep_g0: [u16; 12],
/// Indexed by `state`: bit clear means the repeat uses `rep[1]`.
is_rep_g1: [u16; 12],
/// Indexed by `state`: bit clear means the repeat uses `rep[2]`, bit set means `rep[3]`.
is_rep_g2: [u16; 12],
/// Indexed by `(state << 4) + pos_state`: bit clear means a one-byte repeat of `rep[0]`.
is_rep_0long: [u16; 192],
/// Packet state machine value; values below 7 are only set after a literal.
state: usize,
/// The four most recently used match distances (stored minus one), most recent first.
rep: [usize; 4],
/// Length decoder for matches with a new distance.
len_decoder: LenDecoder,
/// Length decoder for matches with a repeated distance.
rep_len_decoder: LenDecoder,
}
impl DecoderState {
pub fn new(lzma_props: LzmaProperties, unpacked_size: Option<u64>) -> Self {
lzma_props.validate();
DecoderState {
partial_input_buf: std::io::Cursor::new([0; MAX_REQUIRED_INPUT]),
lzma_props,
unpacked_size,
literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)),
pos_slot_decoder: [
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
],
align_decoder: BitTree::new(4),
pos_decoders: [0x400; 115],
is_match: [0x400; 192],
is_rep: [0x400; 12],
is_rep_g0: [0x400; 12],
is_rep_g1: [0x400; 12],
is_rep_g2: [0x400; 12],
is_rep_0long: [0x400; 192],
state: 0,
rep: [0; 4],
len_decoder: LenDecoder::new(),
rep_len_decoder: LenDecoder::new(),
}
}
/// Returns the decoder to its initial state under `new_props`.
///
/// All probabilities go back to `0x400`, the packet state to 0 and the distance history to
/// zeros. The carry-over input buffer and the expected unpacked size are left as they are.
///
/// # Panics
///
/// Panics if `new_props` fails `LzmaProperties::validate`.
pub fn reset_state(&mut self, new_props: LzmaProperties) {
    new_props.validate();

    let current_ctx_bits = self.lzma_props.lc + self.lzma_props.lp;
    let wanted_ctx_bits = new_props.lc + new_props.lp;
    if current_ctx_bits != wanted_ctx_bits {
        // The number of literal contexts changed, so the table has to be rebuilt.
        self.literal_probs = Vec2D::init(0x400, (1 << wanted_ctx_bits, 0x300));
    } else {
        // Same dimensions: the existing table can be reused in place.
        self.literal_probs.fill(0x400);
    }
    self.lzma_props = new_props;

    for tree in self.pos_slot_decoder.iter_mut() {
        tree.reset();
    }
    self.align_decoder.reset();
    self.len_decoder.reset();
    self.rep_len_decoder.reset();

    self.pos_decoders = [0x400; 115];
    self.is_match = [0x400; 192];
    self.is_rep_0long = [0x400; 192];
    self.is_rep = [0x400; 12];
    self.is_rep_g0 = [0x400; 12];
    self.is_rep_g1 = [0x400; 12];
    self.is_rep_g2 = [0x400; 12];

    self.state = 0;
    self.rep = [0; 4];
}
/// Replaces the expected decompressed size; `None` means the size is unknown, in which case
/// `process_mode` stops on end of input or on the end-of-stream marker instead.
pub fn set_unpacked_size(&mut self, unpacked_size: Option<u64>) {
self.unpacked_size = unpacked_size;
}
/// Decodes from `rangecoder` into `output` in `ProcessingMode::Finish`, i.e. treating the
/// input as the complete stream and verifying the expected unpacked size (when known) at the
/// end. See `process_mode` for the error conditions.
pub fn process<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
&mut self,
output: &mut LZB,
rangecoder: &mut RangeDecoder<'a, R>,
) -> error::Result<()> {
self.process_mode(output, rangecoder, ProcessingMode::Finish)
}
/// Decodes as much as the currently available input allows, in `ProcessingMode::Partial`.
/// Unlike `process`, running short of input is not an error and the unpacked size is not
/// verified; leftover bytes are kept in the state for the next call.
#[cfg(feature = "stream")]
pub fn process_stream<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
&mut self,
output: &mut LZB,
rangecoder: &mut RangeDecoder<'a, R>,
) -> error::Result<()> {
self.process_mode(output, rangecoder, ProcessingMode::Partial)
}
/// Decodes one LZMA packet (a literal, a repeated-distance match, or a new-distance match)
/// from `rangecoder`.
///
/// When `update` is `true` the packet is applied: bytes are appended to `output` and
/// `state` / `rep` are advanced. When `update` is `false` this is a dry run: the same
/// decisions are decoded but every `if update` section is skipped, so `output`, `state` and
/// `rep` are left untouched. `update` is also forwarded to every range-coder call
/// (presumably to suppress probability adaptation — confirm in `rangecoder`).
///
/// Returns `Finished` only when the end-of-stream marker is decoded and
/// `rangecoder.is_finished_ok()` holds; every other successful packet returns `Continue`.
///
/// # Errors
///
/// Propagates range-coder and output-buffer errors, and returns `Error::LzmaError` when the
/// end-of-stream marker is found while the range decoder does not report a clean finish.
fn process_next_inner<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
&mut self,
output: &mut LZB,
rangecoder: &mut RangeDecoder<'a, R>,
update: bool,
) -> error::Result<ProcessingStatus> {
// Low `pb` bits of the output position.
let pos_state = output.len() & ((1 << self.lzma_props.pb) - 1);
// is_match bit clear => literal packet.
if !rangecoder.decode_bit(
&mut self.is_match[(self.state << 4) + pos_state],
update,
)? {
let byte: u8 = self.decode_literal(output, rangecoder, update)?;
if update {
lzma_debug!("Literal: {}", byte);
output.append_literal(byte)?;
// State transition after a literal; the result is always below 7.
self.state = if self.state < 4 {
0
} else if self.state < 10 {
self.state - 3
} else {
self.state - 6
};
}
return Ok(ProcessingStatus::Continue);
}
// Match packet: work out the length and which distance to use.
let mut len: usize;
// is_rep bit set => reuse one of the four remembered distances.
if rangecoder.decode_bit(&mut self.is_rep[self.state], update)? {
// is_rep_g0 bit clear => distance is rep[0].
if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state], update)? {
// is_rep_0long bit clear => "short rep": copy a single byte from rep[0].
if !rangecoder.decode_bit(
&mut self.is_rep_0long[(self.state << 4) + pos_state],
update,
)? {
if update {
self.state = if self.state < 7 { 9 } else { 11 };
let dist = self.rep[0] + 1;
output.append_lz(1, dist)?;
}
return Ok(ProcessingStatus::Continue);
}
} else {
// Distance is rep[1], rep[2] or rep[3], chosen by up to two more bits.
let idx: usize;
if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state], update)? {
idx = 1;
} else if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state], update)? {
idx = 2;
} else {
idx = 3;
}
if update {
// Move the chosen distance to the front, shifting the more recent ones back.
let dist = self.rep[idx];
for i in (0..idx).rev() {
self.rep[i + 1] = self.rep[i];
}
self.rep[0] = dist
}
}
len = self.rep_len_decoder.decode(rangecoder, pos_state, update)?;
if update {
self.state = if self.state < 7 { 8 } else { 11 };
}
} else {
// New distance: age the history first, then decode length and distance.
if update {
self.rep[3] = self.rep[2];
self.rep[2] = self.rep[1];
self.rep[1] = self.rep[0];
}
len = self.len_decoder.decode(rangecoder, pos_state, update)?;
if update {
self.state = if self.state < 7 { 7 } else { 10 };
}
let rep_0 = self.decode_distance(rangecoder, len, update)?;
if update {
self.rep[0] = rep_0;
// A decoded distance of 0xFFFF_FFFF is the end-of-stream marker, not a real match.
if self.rep[0] == 0xFFFF_FFFF {
if rangecoder.is_finished_ok()? {
return Ok(ProcessingStatus::Finished);
}
return Err(error::Error::LzmaError(String::from(
"Found end-of-stream marker but more bytes are available",
)));
}
}
}
if update {
// Decoded lengths are stored minus 2 and distances minus 1.
len += 2;
let dist = self.rep[0] + 1;
output.append_lz(len, dist)?;
}
Ok(ProcessingStatus::Continue)
}
/// Decodes and applies one packet: `process_next_inner` with `update = true`.
fn process_next<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
&mut self,
output: &mut LZB,
rangecoder: &mut RangeDecoder<'a, R>,
) -> error::Result<ProcessingStatus> {
self.process_next_inner(output, rangecoder, true)
}
/// Dry-runs the next packet against `buf` to find out whether it can be decoded from those
/// bytes alone.
///
/// A throwaway range decoder is seeded with `range` / `code` and reads from `buf`; the
/// packet is decoded with `update = false`, so `output`, `state` and `rep` are not modified.
/// An `Err` means the packet could not be decoded from `buf` (typically because it ran out
/// of bytes); the decoded status itself is discarded.
fn try_process_next<W: io::Write, LZB: LzBuffer<W>>(
    &mut self,
    output: &mut LZB,
    buf: &[u8],
    range: u32,
    code: u32,
) -> error::Result<()> {
    let mut scratch_input = std::io::Cursor::new(buf);
    let mut scratch_coder = RangeDecoder::from_parts(&mut scratch_input, range, code);
    let outcome = self.process_next_inner(output, &mut scratch_coder, false);
    outcome.map(|_status| ())
}
/// Tops up the carry-over buffer from the range decoder's input.
///
/// Bytes are read into the unused tail of `partial_input_buf` (everything from the current
/// cursor position onwards) and the cursor is advanced by the number of bytes obtained.
fn read_partial_input_buf<'a, R: io::BufRead>(
    &mut self,
    rangecoder: &mut RangeDecoder<'a, R>,
) -> error::Result<()> {
    let filled = self.partial_input_buf.position();
    let free_tail = &mut self.partial_input_buf.get_mut()[filled as usize..];
    let gained = rangecoder.read_into(free_tail)? as u64;
    self.partial_input_buf.set_position(filled + gained);
    Ok(())
}
/// Main decode loop: decodes packets from `rangecoder` into `output` until the expected size
/// is reached, the input is exhausted, or the end-of-stream marker is seen.
///
/// In `ProcessingMode::Partial`, whenever fewer than `MAX_REQUIRED_INPUT` bytes are at hand
/// the next packet is first dry-run (`try_process_next`); if the dry run fails, the function
/// returns early and keeps the pending bytes in `partial_input_buf` for the next call.
///
/// # Errors
///
/// Propagates decoding and I/O errors. In `ProcessingMode::Finish`, returns
/// `Error::LzmaError` if an unpacked size is known and differs from the decoded length.
fn process_mode<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
&mut self,
output: &mut LZB,
rangecoder: &mut RangeDecoder<'a, R>,
mode: ProcessingMode,
) -> error::Result<()> {
loop {
// Termination: with a known size, stop once that many bytes were produced; otherwise stop
// when the input is done and nothing is left in the carry-over buffer.
if let Some(unpacked_size) = self.unpacked_size {
if output.len() as u64 >= unpacked_size {
break;
}
} else if match mode {
ProcessingMode::Partial => {
rangecoder.is_eof()? && self.partial_input_buf.position() as usize == 0
}
ProcessingMode::Finish => {
rangecoder.is_finished_ok()? && self.partial_input_buf.position() as usize == 0
}
} {
break;
}
if self.partial_input_buf.position() as usize > 0 {
// Bytes were carried over from an earlier call: top the buffer up and decode from it.
self.read_partial_input_buf(rangecoder)?;
// Copy the array out so slices of it can be used while `self` is borrowed mutably.
let tmp = *self.partial_input_buf.get_ref();
// Still not a full buffer: dry-run first, and wait for more input if that fails.
if mode == ProcessingMode::Partial
&& (self.partial_input_buf.position() as usize) < MAX_REQUIRED_INPUT
&& self
.try_process_next(
output,
&tmp[..self.partial_input_buf.position() as usize],
rangecoder.range,
rangecoder.code,
)
.is_err()
{
return Ok(());
}
// Decode one packet from the copied bytes with a temporary range decoder seeded from
// the real one.
let mut tmp_reader =
io::Cursor::new(&tmp[..self.partial_input_buf.position() as usize]);
let mut tmp_rangecoder =
RangeDecoder::from_parts(&mut tmp_reader, rangecoder.range, rangecoder.code);
let res = self.process_next(output, &mut tmp_rangecoder)?;
// Carry the resulting coder state back to the real range decoder.
rangecoder.set(tmp_rangecoder.range, tmp_rangecoder.code);
// Drop the consumed bytes: move the unread remainder to the front of the buffer.
let end = self.partial_input_buf.position();
let new_len = end - tmp_reader.position();
self.partial_input_buf.get_mut()[..new_len as usize]
.copy_from_slice(&tmp[tmp_reader.position() as usize..end as usize]);
self.partial_input_buf.set_position(new_len);
if res == ProcessingStatus::Finished {
break;
};
} else {
// Nothing carried over: look at what the stream has buffered without consuming it.
let buf: &[u8] = rangecoder.stream.fill_buf()?;
// Too little buffered to be sure: dry-run, and on failure stash the bytes for later.
if mode == ProcessingMode::Partial
&& buf.len() < MAX_REQUIRED_INPUT
&& self
.try_process_next(output, buf, rangecoder.range, rangecoder.code)
.is_err()
{
return self.read_partial_input_buf(rangecoder);
}
if self.process_next(output, rangecoder)? == ProcessingStatus::Finished {
break;
};
}
}
// The size check only applies in Finish mode; a partial run may legitimately stop short.
if let Some(len) = self.unpacked_size {
if mode == ProcessingMode::Finish && len != output.len() as u64 {
return Err(error::Error::LzmaError(format!(
"Expected unpacked size of {} but decompressed to {}",
len,
output.len()
)));
}
}
Ok(())
}
/// Decodes a single literal byte.
///
/// The probability row is chosen from the low `lp` bits of the output position and the high
/// `lc` bits of the previous output byte (0 when there is none). If the previous packet was
/// a match (`state >= 7`), the byte at distance `rep[0] + 1` guides the decoding bit by bit
/// until the first bit that differs from it; the remaining bits are decoded plainly.
///
/// Nothing is written to `output`; `update` is only passed on to the range coder.
fn decode_literal<'a, W: io::Write, LZB: LzBuffer<W>, R: io::BufRead>(
    &mut self,
    output: &mut LZB,
    rangecoder: &mut RangeDecoder<'a, R>,
    update: bool,
) -> error::Result<u8> {
    let prev_byte = output.last_or(0u8) as usize;
    let lit_state = ((output.len() & ((1 << self.lzma_props.lp) - 1)) << self.lzma_props.lc)
        + (prev_byte >> (8 - self.lzma_props.lc));
    let probs = &mut self.literal_probs[lit_state];

    // Starts at 1 and collects one bit per step; reaching 0x100 means eight bits were read.
    let mut symbol: usize = 1;

    if self.state >= 7 {
        let mut match_byte = output.last_n(self.rep[0] + 1)? as usize;
        while symbol < 0x100 {
            let match_bit = (match_byte >> 7) & 1;
            match_byte <<= 1;
            let prob = &mut probs[((1 + match_bit) << 8) + symbol];
            let bit = rangecoder.decode_bit(prob, update)? as usize;
            symbol = (symbol << 1) | bit;
            if bit != match_bit {
                // Prediction failed: fall through to the plain loop for the rest.
                break;
            }
        }
    }

    while symbol < 0x100 {
        let bit = rangecoder.decode_bit(&mut probs[symbol], update)? as usize;
        symbol = (symbol << 1) | bit;
    }

    Ok((symbol - 0x100) as u8)
}
/// Decodes a match distance (stored minus one) for a match whose decoded length is `length`.
///
/// A 6-bit slot is decoded first, using the tree selected by `length` clamped to 3. Slots
/// below 4 are the distance itself. Larger slots give the two leading bits and the number
/// of remaining low bits: for slots below 14 those come from a reverse bit tree over
/// `pos_decoders`; otherwise they are direct bits followed by 4 bits from `align_decoder`.
fn decode_distance<'a, R: io::BufRead>(
    &mut self,
    rangecoder: &mut RangeDecoder<'a, R>,
    length: usize,
    update: bool,
) -> error::Result<usize> {
    let len_state = length.min(3);
    let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder, update)? as usize;
    if pos_slot < 4 {
        return Ok(pos_slot);
    }

    let num_direct_bits = (pos_slot >> 1) - 1;
    // Leading bits "1x" (x = low bit of the slot), shifted above the low bits still to come.
    let base = (2 | (pos_slot & 1)) << num_direct_bits;

    let low_bits = if pos_slot < 14 {
        rangecoder.parse_reverse_bit_tree(
            num_direct_bits,
            &mut self.pos_decoders,
            base - pos_slot,
            update,
        )? as usize
    } else {
        let direct = (rangecoder.get(num_direct_bits - 4)? as usize) << 4;
        let aligned = self.align_decoder.parse_reverse(rangecoder, update)? as usize;
        direct + aligned
    };

    Ok(base + low_bits)
}
}
/// A decoder for one LZMA stream: the stream parameters, a memory limit for the output
/// window, and the adaptive decoder state.
#[derive(Debug)]
pub struct LzmaDecoder {
/// Parameters the decoder was created with (properties, dictionary size, unpacked size).
params: LzmaParams,
/// Limit passed to `LzCircularBuffer::from_stream`; `usize::MAX` when none was given.
memlimit: usize,
/// Probability tables and packet state, reused across `reset` calls.
state: DecoderState,
}
impl LzmaDecoder {
pub fn new(params: LzmaParams, memlimit: Option<usize>) -> error::Result<LzmaDecoder> {
Ok(Self {
params,
memlimit: memlimit.unwrap_or(usize::MAX),
state: DecoderState::new(params.properties, params.unpacked_size),
})
}
#[cfg(feature = "raw_decoder")]
pub fn reset(&mut self, unpacked_size: Option<Option<u64>>) {
self.state.reset_state(self.params.properties);
if let Some(unpacked_size) = unpacked_size {
self.state.set_unpacked_size(unpacked_size);
}
}
pub fn decompress<W: io::Write, R: io::BufRead>(
&mut self,
input: &mut R,
output: &mut W,
) -> error::Result<()> {
let mut output =
LzCircularBuffer::from_stream(output, self.params.dict_size as usize, self.memlimit);
let mut rangecoder = RangeDecoder::new(input)
.map_err(|e| error::Error::LzmaError(format!("LZMA stream too short: {}", e)))?;
self.state.process(&mut output, &mut rangecoder)?;
output.finish()?;
Ok(())
}
}