use std::cmp::min;
/// Accumulates bit fields MSB-first into a growable byte buffer.
///
/// The buffer always holds at least one byte, so a freshly created writer's
/// `as_bytes()` is `[0]` even though its `len()` is 0. Bits of the last byte
/// that have not been written yet are zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitWriter {
    /// Backing bytes; never empty.
    data: Vec<u8>,
    /// Number of bits written so far.
    cursor: usize,
}

impl Default for BitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl BitWriter {
    /// Creates an empty writer backed by a single zeroed byte.
    pub fn new() -> Self {
        BitWriter {
            data: vec![0],
            cursor: 0,
        }
    }

    /// Appends every bit of `data`, most significant bit first.
    pub fn append<D>(&mut self, data: D)
    where
        D: BitSource,
    {
        self.append_range(data, 0, D::LEN);
    }

    /// Appends the `len` least significant bits of `data`, most significant
    /// of those first.
    ///
    /// # Panics
    ///
    /// Panics if `len` exceeds the bit width of `D`.
    pub fn append_tail<D>(&mut self, data: D, len: u8)
    where
        D: BitSource,
    {
        // Without this check `D::LEN - len` underflows (debug panic, or a
        // wrapped garbage offset in release builds).
        assert!(
            len <= D::LEN,
            "append_tail: {} bits requested from a {}-bit source",
            len,
            D::LEN
        );
        self.append_range(data, D::LEN - len, len);
    }

    /// Appends `len` bits of `data` starting `offset` bits from its most
    /// significant end.
    ///
    /// The first bits fill whatever room is left in the current last byte;
    /// the rest are pushed as new bytes, at most eight bits at a time.
    pub fn append_range<D>(&mut self, data: D, offset: u8, len: u8)
    where
        D: BitSource,
    {
        let overlay_len = min(self.room(), len);
        // Left-align the overlay just after the bits already in the last byte.
        let overlay = data.slice_left(offset, overlay_len, (self.cursor % 8) as u8);
        let last = self
            .data
            .last_mut()
            .expect("BitWriter always holds at least one byte");
        *last |= overlay;

        let mut handled = overlay_len;
        while handled < len {
            let chunk = min(8, len - handled);
            self.data.push(data.slice_left(offset + handled, chunk, 0));
            handled += chunk;
        }
        self.cursor += usize::from(len);
    }

    /// Number of unused bits left in the last byte of the buffer (0..=8).
    fn room(&self) -> u8 {
        // `new` starts with one empty byte and `append_range` only pushes as
        // many bytes as it (partially) fills, so the difference stays <= 8.
        (self.data.len() * 8 - self.cursor) as u8
    }

    /// Returns the written bytes; the final byte is zero-padded on the right.
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Number of bits written so far.
    pub fn len(&self) -> usize {
        self.cursor
    }

    /// Returns `true` if no bits have been written yet.
    pub fn is_empty(&self) -> bool {
        self.cursor == 0
    }
}
/// Reads bit fields MSB-first from an owned byte buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitReader {
    /// Backing bytes.
    data: Vec<u8>,
    /// Index of the next bit to read; never exceeds `len`.
    cursor: usize,
    /// Number of readable bits; at most `data.len() * 8`.
    len: usize,
}
impl BitReader {
    /// Total number of readable bits.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the reader holds no readable bits at all.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of bits not yet consumed.
    pub fn remaining_len(&self) -> usize {
        self.len - self.cursor
    }

    /// Advances the cursor by `n` bits.
    ///
    /// Returns `false`, leaving the cursor untouched, if fewer than `n` bits
    /// remain.
    pub fn skip(&mut self, n: usize) -> bool {
        // Comparing against the remaining length cannot overflow, unlike
        // `cursor + n`.
        if n <= self.remaining_len() {
            self.cursor += n;
            true
        } else {
            false
        }
    }

    /// Reads the next `len` bits (at most 8) as an unsigned value.
    ///
    /// Returns `None`, consuming nothing, if `len > 8` or fewer than `len`
    /// bits remain. Reading 0 bits yields `Some(0)`.
    pub fn read_u8(&mut self, len: usize) -> Option<u8> {
        // `read_bits` limits the value to 8 bits, so the cast is lossless.
        self.read_bits(len, 8).map(|value| value as u8)
    }

    /// Reads the next `len` bits (at most 16) as an unsigned value.
    ///
    /// Returns `None`, consuming nothing, if `len > 16` or fewer than `len`
    /// bits remain. Reading 0 bits yields `Some(0)`.
    pub fn read_u16(&mut self, len: usize) -> Option<u16> {
        // `read_bits` limits the value to 16 bits, so the cast is lossless.
        self.read_bits(len, 16).map(|value| value as u16)
    }

    /// Reads the next `len` bits (at most 32) as an unsigned value.
    ///
    /// Returns `None`, consuming nothing, if `len > 32` or fewer than `len`
    /// bits remain. Reading 0 bits yields `Some(0)`.
    pub fn read_u32(&mut self, len: usize) -> Option<u32> {
        // `read_bits` limits the value to 32 bits, so the cast is lossless.
        self.read_bits(len, 32).map(|value| value as u32)
    }

    /// Reads the next `len` bits MSB-first into the low bits of a `u64`.
    ///
    /// Returns `None`, consuming nothing, if `len > max_len` or fewer than
    /// `len` bits remain. Callers pass `max_len <= 64`.
    fn read_bits(&mut self, len: usize, max_len: usize) -> Option<u64> {
        if len > max_len || len > self.remaining_len() {
            return None;
        }
        let mut value = 0u64;
        let mut pending = len;
        while pending > 0 {
            let offset = self.cursor % 8;
            // Take what is left of the current byte, but no more than needed.
            let take = min(8 - offset, pending);
            // Drop the `offset` already-consumed high bits, keep the top `take`.
            let bits = (self.data[self.cursor / 8] << offset) >> (8 - take);
            value = (value << take) | u64::from(bits);
            self.cursor += take;
            pending -= take;
        }
        Some(value)
    }
}
/// A fixed-width unsigned integer whose bits can be extracted MSB-first, at
/// most eight at a time.
pub trait BitSource {
    /// Width of the implementing type in bits.
    const LEN: u8;

    /// Returns `len` bits of `self` starting `offset` bits from the most
    /// significant end, right-aligned in the result.
    ///
    /// The provided implementations return 0 when the window is empty, runs
    /// past the end of the value, or is wider than the 8-bit result.
    fn slice(&self, offset: u8, len: u8) -> u8;

    /// Like [`BitSource::slice`], but left-aligns the extracted bits so that
    /// exactly `margin` zero bits precede them in the returned byte.
    ///
    /// Returns 0 when `len` is 0. Callers must keep `len + margin <= 8`;
    /// otherwise the shift amount `8 - len - margin` underflows.
    fn slice_left(&self, offset: u8, len: u8, margin: u8) -> u8 {
        if len == 0 {
            0
        } else {
            self.slice(offset, len) << (8 - len - margin)
        }
    }
}

// Implements `BitSource` for primitive unsigned integers of the given widths.
macro_rules! impl_bit_source {
    ($($ty:ty => $bits:expr),* $(,)?) => {
        $(
            impl BitSource for $ty {
                const LEN: u8 = $bits;

                fn slice(&self, offset: u8, len: u8) -> u8 {
                    // The result is a single byte, so wider windows are
                    // rejected up front.
                    if len > 8 {
                        return 0;
                    }
                    // Checked subtraction: a window running past the end of
                    // the value yields 0 rather than an overflow panic.
                    let shift = Self::LEN
                        .checked_sub(offset)
                        .and_then(|rest| rest.checked_sub(len));
                    match shift {
                        // `shift == LEN` only for an empty window at offset 0;
                        // shifting by the full width would overflow.
                        Some(shift) if shift < Self::LEN => {
                            let mask = u8::MAX.checked_shr(u32::from(8 - len)).unwrap_or(0);
                            ((self >> shift) as u8) & mask
                        }
                        _ => 0,
                    }
                }
            }
        )*
    };
}

impl_bit_source!(u8 => 8, u16 => 16, u32 => 32, u64 => 64, u128 => 128);
/// Largest value that fits in 3 bits (2^3 - 1).
pub const MAX_3BIT: u8 = 0x07;
/// Largest value that fits in 4 bits (2^4 - 1).
pub const MAX_4BIT: u8 = 0x0F;
/// Largest value that fits in 7 bits (2^7 - 1).
pub const MAX_7BIT: u8 = 0x7F;
/// Largest value that fits in 10 bits (2^10 - 1).
pub const MAX_10BIT: u16 = 0x03FF;
/// Largest value that fits in 11 bits (2^11 - 1).
pub const MAX_11BIT: u16 = 0x07FF;
/// Largest value that fits in 12 bits (2^12 - 1).
pub const MAX_12BIT: u16 = 0x0FFF;
/// Largest value that fits in 20 bits (2^20 - 1).
pub const MAX_20BIT: u32 = 0x000F_FFFF;
/// `LEFT_MASK_U8[n]` keeps the `n` most significant bits of a byte.
pub const LEFT_MASK_U8: [u8; 9] = [0x00, 0x80, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC, 0xFE, 0xFF];
/// `RIGHT_MASK_U8[n]` keeps the `n` least significant bits of a byte.
pub const RIGHT_MASK_U8: [u8; 9] = [0x00, 0x01, 0x03, 0x07, 0x0F, 0x1F, 0x3F, 0x7F, 0xFF];
impl From<&[u8]> for BitReader {
    /// Copies `input` into a new reader positioned at its first bit; every
    /// bit of every byte is readable.
    fn from(input: &[u8]) -> Self {
        BitReader {
            len: input.len() * 8,
            cursor: 0,
            data: input.to_vec(),
        }
    }
}
impl From<Vec<u8>> for BitReader {
fn from(input: Vec<u8>) -> Self {
let len = input.len() * 8;
BitReader {
data: input,
cursor: 0,
len,
}
}
}
use std::{fs::File, io, io::Read, path::PathBuf};
impl BitReader {
pub fn from_file(path: PathBuf) -> Result<Self, io::Error> {
let mut file = File::open(path)?;
let mut vec_output: Vec<u8> = Vec::new();
file.read_to_end(&mut vec_output)?;
let reader_output: BitReader = BitReader::from(vec_output);
Ok(reader_output)
}
}
impl From<BitWriter> for BitReader {
    /// Turns a writer into a reader over exactly the bits that were written;
    /// the zero padding at the end of the final byte is not readable.
    fn from(writer: BitWriter) -> Self {
        let written_bits = writer.len();
        let bytes = writer.data;
        BitReader {
            data: bytes,
            cursor: 0,
            len: written_bits,
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // --- BitSource::slice_left ---

    #[test]
    fn slice_left_u8() {
        assert_eq!(0b1000_0000u8, 0b0000_0001u8.slice_left(7, 1, 0));
        assert_eq!(0b0010_0000u8, 0b0000_0001u8.slice_left(5, 3, 0));
        assert_eq!(0b1110_0000u8, 0b0000_1110u8.slice_left(4, 3, 0));
        assert_eq!(0b1110_0000u8, 0b1110_1110u8.slice_left(0, 3, 0));
    }

    #[test]
    fn slice_left_margin_u8() {
        assert_eq!(0b0110_0000u8, 0b0000_1100u8.slice_left(4, 2, 1));
        assert_eq!(0b0011_0000u8, 0b0000_1100u8.slice_left(4, 2, 2));
        assert_eq!(0b0001_1000u8, 0b0000_1100u8.slice_left(4, 2, 3));
        assert_eq!(0b0000_1100u8, 0b0000_1100u8.slice_left(4, 2, 4));
        assert_eq!(0b0000_0110u8, 0b0000_1100u8.slice_left(4, 2, 5));
        assert_eq!(0b0000_0011u8, 0b0000_1100u8.slice_left(4, 2, 6));
    }

    #[test]
    fn slice_left_u16() {
        assert_eq!(0b1000_0000u8, 0b0000_0001_0000_0000u16.slice_left(7, 1, 0));
        assert_eq!(0b0011_0000u8, 0b0011_0110_1001_1100u16.slice_left(4, 4, 1));
        assert_eq!(0b0111_0000u8, 0b0000_1110_0000_0000u16.slice_left(4, 3, 1));
        assert_eq!(0b1111_1111u8, 0b0000_1111_1111_0000u16.slice_left(4, 8, 0));
    }

    // --- BitWriter ---

    #[test]
    fn add_whole_range() {
        let mut buff = BitWriter::new();
        buff.append_range(0x0Fu8, 0, 8);
        buff.append_range(0x0F0Fu16, 0, 16);
        buff.append_range(0x0F0F0F0Fu32, 0, 32);
        let expected: Vec<u8> = vec![0x0F; 7];
        assert_eq!(expected, buff.as_bytes());
    }

    #[test]
    fn add_single_bits() {
        let mut buff = BitWriter::new();
        buff.append_range(0b10000000u8, 0, 1);
        buff.append_range(0b00001000u8, 4, 1);
        buff.append_range(0b01000000u8, 1, 1);
        buff.append_range(0b00000001u8, 7, 1);
        buff.append_range(0b00010000u8, 3, 1);
        buff.append_range(0b00100000u8, 2, 1);
        buff.append_range(0b00000100u8, 5, 1);
        buff.append_range(0b00000010u8, 6, 1);
        let expected: Vec<u8> = vec![0xFF];
        assert_eq!(expected, buff.as_bytes());
    }

    #[test]
    fn add_random_bits() {
        let mut buff = BitWriter::new();
        buff.append_range(0b10000000u8, 0, 1);
        buff.append_range(0b00000100u8, 5, 1);
        buff.append_range(0b00000100u8, 4, 1);
        buff.append_range(0b00000010u8, 6, 1);
        let expected: Vec<u8> = vec![0b11010000u8];
        assert_eq!(expected, buff.as_bytes());
    }

    #[test]
    fn add_triples() {
        let mut buff = BitWriter::new();
        buff.append_range(0b10111111u8, 0, 3);
        buff.append_range(0b11111011u8, 4, 3);
        buff.append_range(0b11011111u8, 1, 3);
        buff.append_range(0b11110111u8, 3, 3);
        buff.append_range(0b11101111u8, 2, 3);
        buff.append_range(0b11111101u8, 5, 3);
        let expected: Vec<u8> = vec![0b10110110, 0b11011011, 0b01000000];
        assert_eq!(expected, buff.as_bytes());
    }

    // --- BitWriter -> BitReader round trips ---

    #[test]
    fn test_read_u8() {
        let factor = 100;
        let mut writer: BitWriter = BitWriter::new();
        for _i in 0..factor {
            writer.append_range(0b00011110u8, 3, 5);
        }
        let mut reader: BitReader = writer.into();
        for _i in 0..factor {
            assert_eq!(Some(0b00011110u8), reader.read_u8(5));
        }
    }

    #[test]
    fn test_read_u16() {
        let factor = 100;
        let mut writer: BitWriter = BitWriter::new();
        for _i in 0..factor {
            writer.append_tail(0b00000011_00011110u16, 10);
            writer.append_tail(0b00110011_00001110u16, 14);
            writer.append_tail(0b00000000_00010110u16, 7);
            writer.append_tail(0b10101010_10101010u16, 16);
        }
        // 10 + 14 + 7 + 16 bits per iteration.
        assert_eq!(factor * 47, writer.len());
        let mut reader: BitReader = writer.into();
        for i in 0..factor {
            assert_eq!(
                Some(0b00000011_00011110u16),
                reader.read_u16(10),
                "Iteration i={}",
                i
            );
            assert_eq!(
                Some(0b00110011_00001110u16),
                reader.read_u16(14),
                "Iteration i={}",
                i
            );
            assert_eq!(
                Some(0b00000000_00010110u16),
                reader.read_u16(7),
                "Iteration i={}",
                i
            );
            assert_eq!(
                Some(0b10101010_10101010u16),
                reader.read_u16(16),
                "Iteration i={}",
                i
            );
        }
        assert_eq!(factor * 47, reader.cursor);
    }

    #[test]
    fn test_read_u32() {
        let factor = 100;
        let mut writer: BitWriter = BitWriter::new();
        for _i in 0..factor {
            writer.append_tail(0b00000011_00011110u16, 10);
            writer.append_tail(0b00110011_00001110u16, 14);
            writer.append_tail(0b00000000_00010110u16, 7);
            writer.append_tail(0b10101010_10101010u16, 16);
            writer.append_tail(0b00000000_00011111u16, 5);
        }
        // 10 + 14 + 7 + 16 + 5 bits per iteration, read back as 32 + 20.
        assert_eq!(factor * 52, writer.len());
        let mut reader: BitReader = writer.into();
        for i in 0..factor {
            assert_eq!(
                Some(0b11000111_10110011_00001110_00101101u32),
                reader.read_u32(32),
                "Iteration i={}",
                i
            );
            assert_eq!(
                Some(0b01010101_01010101_1111u32),
                reader.read_u32(20),
                "Iteration i={}",
                i
            );
        }
        assert_eq!(factor * 52, reader.cursor);
    }
}