diff --git a/src/main.rs b/src/main.rs index 16058b9e..a044e7a4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -126,8 +126,7 @@ fn main_with_result() -> anyhow::Result<()> { (block_image, _metrics) = lh .decode_as_single_image( - &mut reader, - filelen, + &mut reader.take(filelen - 4), // last 4 bytes are the length of the file num_threads as usize, &enabled_features, ) diff --git a/src/structs/lepton_format.rs b/src/structs/lepton_format.rs index 5a2a7bdd..853f3dfc 100644 --- a/src/structs/lepton_format.rs +++ b/src/structs/lepton_format.rs @@ -49,17 +49,37 @@ pub fn decode_lepton_wrapper( let size = reader.seek(SeekFrom::End(0))?; reader.seek(SeekFrom::Start(orig_pos))?; + // last four bytes specify the file size + let mut reader_minus_trailer = reader.take(size - 4); + let mut lh = LeptonHeader::new(); let mut features_mut = enabled_features.clone(); - lh.read_lepton_header(reader, &mut features_mut) + lh.read_lepton_header(&mut reader_minus_trailer, &mut features_mut) .context(here!())?; let metrics = lh - .recode_jpeg(writer, reader, size, num_threads, &features_mut) + .recode_jpeg( + writer, + &mut reader_minus_trailer, + num_threads, + &features_mut, + ) .context(here!())?; + let expected_size = reader.read_u32::()?; + if expected_size != size as u32 { + return err_exit_code( + ExitCode::VerificationLengthMismatch, + format!( + "ERROR mismatch expected_size = {0}, actual_size = {1}", + expected_size, size + ) + .as_str(), + ); + } + return Ok(metrics); } @@ -302,10 +322,9 @@ pub fn read_jpeg( Ok((lp, image_data)) } -fn run_lepton_decoder_threads( +fn run_lepton_decoder_threads( lh: &LeptonHeader, reader: &mut R, - last_data_position: u64, _max_threads_to_use: usize, features: &EnabledFeatures, process: fn( @@ -335,7 +354,6 @@ fn run_lepton_decoder_threads( let mut thread_results = multiplex_read( reader, lh.thread_handoff.len(), - last_data_position, |thread_id, reader| -> Result<(Metrics, P)> { let cpu_time = CpuTimeMeasure::new(); @@ -555,11 +573,10 @@ impl LeptonHeader { }; } - fn recode_jpeg( + fn recode_jpeg( &mut self, writer: &mut W, reader: &mut R, - last_data_position: u64, num_threads: usize, enabled_features: &EnabledFeatures, ) -> Result { @@ -571,18 +588,11 @@ impl LeptonHeader { .context(here!())?; let metrics = if self.jpeg_header.jpeg_type == JPegType::Progressive { - self.recode_progressive_jpeg( - reader, - last_data_position, - writer, - num_threads, - enabled_features, - ) - .context(here!())? + self.recode_progressive_jpeg(reader, writer, num_threads, enabled_features) + .context(here!())? } else { self.recode_baseline_jpeg( reader, - last_data_position, writer, self.plain_text_size as u64 - self.garbage_data.len() as u64 @@ -605,10 +615,9 @@ impl LeptonHeader { } /// decodes the entire image and merges the results into a single set of BlockBaseImage per component - pub fn decode_as_single_image( + pub fn decode_as_single_image( &mut self, reader: &mut R, - last_data_position: u64, num_threads: usize, features: &EnabledFeatures, ) -> Result<(Vec, Metrics)> { @@ -616,7 +625,6 @@ impl LeptonHeader { let (metrics, mut results) = run_lepton_decoder_threads( self, reader, - last_data_position, num_threads, features, |_thread_handoff, image_data, _lh| { @@ -656,17 +664,16 @@ impl LeptonHeader { } /// progressive decoder, requires that the entire lepton file is processed first - fn recode_progressive_jpeg( + fn recode_progressive_jpeg( &mut self, reader: &mut R, - last_data_position: u64, writer: &mut W, num_threads: usize, enabled_features: &EnabledFeatures, ) -> Result { // run the threads first, since we need everything before we can start decoding let (merged, metrics) = self - .decode_as_single_image(reader, last_data_position, num_threads, enabled_features) + .decode_as_single_image(reader, num_threads, enabled_features) .context(here!())?; loop { @@ -696,10 +703,9 @@ impl LeptonHeader { // baseline decoder can run the jpeg encoder inside the worker thread vs progressive encoding which needs to get the entire set of coefficients first // since it runs throught it multiple times. - fn recode_baseline_jpeg( + fn recode_baseline_jpeg( &mut self, reader: &mut R, - last_data_position: u64, writer: &mut W, size_limit: u64, num_threads: usize, @@ -709,7 +715,6 @@ impl LeptonHeader { let (metrics, results) = run_lepton_decoder_threads( self, reader, - last_data_position, num_threads, enabled_features, |thread_handoff, image_data, lh| { diff --git a/src/structs/multiplexer.rs b/src/structs/multiplexer.rs index cd454ab0..9e66270f 100644 --- a/src/structs/multiplexer.rs +++ b/src/structs/multiplexer.rs @@ -3,7 +3,7 @@ use anyhow::{Context, Result}; use byteorder::{ReadBytesExt, WriteBytesExt}; use std::{ cmp, - io::{Cursor, Read, Seek, Write}, + io::{Cursor, Read, Write}, mem::swap, sync::mpsc::{channel, Receiver, SendError, Sender}, }; @@ -15,12 +15,13 @@ use std::{ /// /// The read implementation reads the blocks from the file and sends them to the appropriate worker thread. +/// The message that is sent between the threads enum Message { Eof, WriteBlock(u8, Vec), } -pub struct MessageSender { +pub struct MultiplexWriter { thread_id: u8, sender: Sender, buffer: Vec, @@ -28,7 +29,7 @@ pub struct MessageSender { const WRITE_BUFFER_SIZE: usize = 65536; -impl Write for MessageSender { +impl Write for MultiplexWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { let mut copy_start = 0; while copy_start < buf.len() { @@ -62,10 +63,105 @@ impl Write for MessageSender { } } +/// Given an arbitrary writer, this function will launch the given number of threads and call the processor function +/// on each of them, and collect the output written by each thread to the writer in blocks identified by the thread_id. +/// +/// This output stream can be processed by multiple_read to get the data back, using the same number of threads. +pub fn multiplex_write( + writer: &mut WRITE, + num_threads: usize, + processor: FN, +) -> Result> +where + WRITE: Write, + FN: Fn(&mut MultiplexWriter, usize) -> Result + Send + Copy, + RESULT: Send, +{ + let mut thread_results = Vec::>>::new(); + + for _i in 0..num_threads { + thread_results.push(None); + } + + rayon::in_place_scope(|s| -> Result<()> { + let (tx, rx) = channel(); + + for (thread_id, result) in thread_results.iter_mut().enumerate() { + let cloned_sender = tx.clone(); + + let mut thread_writer = MultiplexWriter { + thread_id: thread_id as u8, + sender: cloned_sender, + buffer: Vec::with_capacity(WRITE_BUFFER_SIZE), + }; + + let mut f = move || -> Result { + let r = processor(&mut thread_writer, thread_id)?; + + thread_writer.flush().context(here!())?; + + thread_writer.sender.send(Message::Eof).context(here!())?; + Ok(r) + }; + + s.spawn(move |_| { + *result = Some(f()); + }); + } + + // drop the sender so that the channel breaks when all the threads exit + drop(tx); + + // wait to collect work and done messages from all the threads + let mut threads_left = num_threads; + + while threads_left > 0 { + let value = rx.recv().context(here!()); + match value { + Ok(Message::Eof) => { + threads_left -= 1; + } + Ok(Message::WriteBlock(thread_id, b)) => { + let l = b.len() - 1; + + writer.write_u8(thread_id).context(here!())?; + writer.write_u8((l & 0xff) as u8).context(here!())?; + writer.write_u8(((l >> 8) & 0xff) as u8).context(here!())?; + writer.write_all(&b[..]).context(here!())?; + } + Err(_) => { + break; + } + } + } + + return Ok(()); + }) + .context(here!())?; + + let mut thread_not_run = false; + let mut results = Vec::new(); + + for result in thread_results.drain(..) { + match result { + None => thread_not_run = true, + Some(Ok(r)) => results.push(r), + // if there was an error processing anything, return it + Some(Err(e)) => return Err(e), + } + } + + if thread_not_run { + return err_exit_code(ExitCode::GeneralFailure, "thread did not run"); + } + + Ok(results) +} + /// used by the worker thread to read data for the given thread from the /// receiver. The thread_id is used only to assert that we are only /// getting the data that we are expecting -pub struct MessageReceiver { +pub struct MultiplexReader { /// the multiplexed thread stream we are processing thread_id: u8, @@ -81,7 +177,7 @@ pub struct MessageReceiver { end_of_file: bool, } -impl Read for MessageReceiver { +impl Read for MultiplexReader { /// fast path for reads. If we get zero bytes, take the slow path #[inline(always)] fn read(&mut self, buf: &mut [u8]) -> std::io::Result { @@ -94,7 +190,7 @@ impl Read for MessageReceiver { } } -impl MessageReceiver { +impl MultiplexReader { /// slow path for reads, try to get a new buffer or /// return zero if at the end of the stream #[cold] @@ -133,12 +229,11 @@ impl MessageReceiver { pub fn multiplex_read( reader: &mut READ, num_threads: usize, - last_data_position: u64, processor: FN, ) -> Result> where - READ: Read + Seek, - FN: Fn(usize, &mut MessageReceiver) -> Result + Send + Copy, + READ: Read, + FN: Fn(usize, &mut MultiplexReader) -> Result + Send + Copy, RESULT: Send, { // track if we got an error while trying to send to a thread @@ -161,7 +256,7 @@ where s.spawn(move |_| { // get the appropriate receiver so we can read out data from it - let mut proc_reader = MessageReceiver { + let mut proc_reader = MultiplexReader { thread_id: thread_id as u8, current_buffer: Cursor::new(Vec::new()), receiver: rx, @@ -172,20 +267,20 @@ where } // now that the channels are waiting for input, read the stream and send all the buffers to their respective readers - while reader.stream_position().context(here!())? < last_data_position - 4 { - let thread_marker = reader.read_u8().context(here!())?; + loop { + let mut thread_marker_a = [0; 1]; + if reader.read(&mut thread_marker_a)? == 0 { + break; + } + + let thread_marker = thread_marker_a[0]; + let thread_id = (thread_marker & 0xf) as u8; if thread_id >= channel_to_sender.len() as u8 { return err_exit_code( ExitCode::BadLeptonFile, - format!( - "invalid thread_id at {0} of {1} at {2}", - reader.stream_position().unwrap(), - last_data_position, - here!() - ) - .as_str(), + format!("invalid thread_id {0}", thread_id).as_str(), ); } @@ -205,15 +300,9 @@ where let mut buffer = Vec::::new(); buffer.resize(data_length as usize, 0); - reader.read_exact(&mut buffer).with_context(|| { - format!( - "reading {0} bytes at {1} of {2} at {3}", - buffer.len(), - reader.stream_position().unwrap(), - last_data_position, - here!() - ) - })?; + reader + .read_exact(&mut buffer) + .with_context(|| format!("reading {0} bytes", buffer.len()))?; let e = channel_to_sender[thread_id as usize].send(Message::WriteBlock(thread_id, buffer)); @@ -260,93 +349,28 @@ where Ok(result) } -pub fn multiplex_write( - writer: &mut WRITE, - num_threads: usize, - processor: FN, -) -> Result> -where - WRITE: Write, - FN: Fn(&mut MessageSender, usize) -> Result + Send + Copy, - RESULT: Send, -{ - let mut thread_results = Vec::>>::new(); - - for _i in 0..num_threads { - thread_results.push(None); - } - - rayon::in_place_scope(|s| -> Result<()> { - let (tx, rx) = channel(); - - for (thread_id, result) in thread_results.iter_mut().enumerate() { - let cloned_sender = tx.clone(); - - let mut thread_writer = MessageSender { - thread_id: thread_id as u8, - sender: cloned_sender, - buffer: Vec::with_capacity(WRITE_BUFFER_SIZE), - }; - - let mut f = move || -> Result { - let r = processor(&mut thread_writer, thread_id)?; - - thread_writer.flush().context(here!())?; +/// simple end to end test that write the thread id and reads it back +#[test] +fn test_multiplex_end_to_end() { + let mut output = Vec::new(); - thread_writer.sender.send(Message::Eof).context(here!())?; - Ok(r) - }; - - s.spawn(move |_| { - *result = Some(f()); - }); - } + let w = multiplex_write(&mut output, 10, |writer, thread_id| -> Result { + writer.write_u32::(thread_id as u32)?; - // drop the sender so that the channel breaks when all the threads exit - drop(tx); - - // wait to collect work and done messages from all the threads - let mut threads_left = num_threads; - - while threads_left > 0 { - let value = rx.recv().context(here!()); - match value { - Ok(Message::Eof) => { - threads_left -= 1; - } - Ok(Message::WriteBlock(thread_id, b)) => { - let l = b.len() - 1; - - writer.write_u8(thread_id).context(here!())?; - writer.write_u8((l & 0xff) as u8).context(here!())?; - writer.write_u8(((l >> 8) & 0xff) as u8).context(here!())?; - writer.write_all(&b[..]).context(here!())?; - } - Err(_) => { - break; - } - } - } - - return Ok(()); + Ok(thread_id) }) - .context(here!())?; + .unwrap(); - let mut thread_not_run = false; - let mut results = Vec::new(); + assert_eq!(w[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); - for result in thread_results.drain(..) { - match result { - None => thread_not_run = true, - Some(Ok(r)) => results.push(r), - // if there was an error processing anything, return it - Some(Err(e)) => return Err(e), - } - } + let mut reader = Cursor::new(output); - if thread_not_run { - return err_exit_code(ExitCode::GeneralFailure, "thread did not run"); - } + let r = multiplex_read(&mut reader, 10, |thread_id, reader| -> Result { + let read_thread_id = reader.read_u32::()?; + assert_eq!(read_thread_id, thread_id as u32); + Ok(thread_id) + }) + .unwrap(); - Ok(results) + assert_eq!(r[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); }