From 1cad690da29484a99c5e84f3b7eb109e879e4a6f Mon Sep 17 00:00:00 2001
From: Kristof Roomp
Date: Fri, 12 Apr 2024 06:51:27 +0200
Subject: [PATCH] Clean up new rayon code to separate multiplexer code from
 lepton parsing (#62)

* refactor threading logic into multiplexer, away from the lepton processing logic

* some cleanup

* move comment

* add comment

* add more comments
---
 src/main.rs                  |   3 +-
 src/structs/lepton_format.rs | 518 +++++++----------------------------
 src/structs/mod.rs           |   1 +
 src/structs/multiplexer.rs   | 386 ++++++++++++++++++++++++++
 4 files changed, 486 insertions(+), 422 deletions(-)
 create mode 100644 src/structs/multiplexer.rs

diff --git a/src/main.rs b/src/main.rs
index 16058b9e..a044e7a4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -126,8 +126,7 @@ fn main_with_result() -> anyhow::Result<()> {
 
             (block_image, _metrics) = lh
                 .decode_as_single_image(
-                    &mut reader,
-                    filelen,
+                    &mut reader.take(filelen - 4), // last 4 bytes are the length of the file
                     num_threads as usize,
                     &enabled_features,
                 )
diff --git a/src/structs/lepton_format.rs b/src/structs/lepton_format.rs
index 706f782d..853f3dfc 100644
--- a/src/structs/lepton_format.rs
+++ b/src/structs/lepton_format.rs
@@ -8,9 +8,6 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use log::{info, warn};
 use std::cmp;
 use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom, Write};
-use std::mem::swap;
-use std::sync::mpsc::{channel, Sender};
-use std::sync::mpsc::{Receiver, SendError};
 use std::time::Instant;
 
 use anyhow::{Context, Result};
@@ -31,6 +28,7 @@ use crate::structs::jpeg_header::JPegHeader;
 use crate::structs::jpeg_write::jpeg_write_row_range;
 use crate::structs::lepton_decoder::lepton_decode_row_range;
 use crate::structs::lepton_encoder::lepton_encode_row_range;
+use crate::structs::multiplexer::{multiplex_read, multiplex_write};
 use crate::structs::probability_tables_set::ProbabilityTablesSet;
 use crate::structs::quantization_tables::QuantizationTables;
 use crate::structs::thread_handoff::ThreadHandoff;
@@ -51,17 +49,37 @@ pub fn decode_lepton_wrapper<R: Read + Seek, W: Write>(
     let size = reader.seek(SeekFrom::End(0))?;
     reader.seek(SeekFrom::Start(orig_pos))?;
 
+    // last four bytes specify the file size
+    let mut reader_minus_trailer = reader.take(size - 4);
+
     let mut lh = LeptonHeader::new();
     let mut features_mut = enabled_features.clone();
 
-    lh.read_lepton_header(reader, &mut features_mut)
+    lh.read_lepton_header(&mut reader_minus_trailer, &mut features_mut)
         .context(here!())?;
 
     let metrics = lh
-        .recode_jpeg(writer, reader, size, num_threads, &features_mut)
+        .recode_jpeg(
+            writer,
+            &mut reader_minus_trailer,
+            num_threads,
+            &features_mut,
+        )
         .context(here!())?;
 
+    let expected_size = reader.read_u32::<LittleEndian>()?;
+    if expected_size != size as u32 {
+        return err_exit_code(
+            ExitCode::VerificationLengthMismatch,
+            format!(
+                "ERROR mismatch expected_size = {0}, actual_size = {1}",
+                expected_size, size
+            )
+            .as_str(),
+        );
+    }
+
     return Ok(metrics);
 }
 
@@ -304,10 +322,9 @@ pub fn read_jpeg(
     Ok((lp, image_data))
 }
 
-fn run_lepton_decoder_threads<R: Read + Seek, P: Send>(
+fn run_lepton_decoder_threads<R: Read, P: Send>(
     lh: &LeptonHeader,
     reader: &mut R,
-    last_data_position: u64,
     _max_threads_to_use: usize,
     features: &EnabledFeatures,
     process: fn(
@@ -331,121 +348,62 @@ fn run_lepton_decoder_threads<R: Read + Seek, P: Send>(
         qt.push(qtables);
     }
 
-    let mut thread_results = Vec::<Option<Result<(Metrics, P)>>>::new();
-    for _i in 0..lh.thread_handoff.len() {
-        thread_results.push(None);
-    }
-
-    // track if we got an error while trying to send to a thread
-    let mut error_sending: Option<SendError<Message>> = None;
-
-    rayon::in_place_scope(|s| -> Result<()> {
-        let mut channel_to_sender = Vec::new();
-
-        let pts_ref = &pts;
-        let q_ref = &qt[..];
-
-        info!("decoding {0} multiplexed streams", lh.thread_handoff.len());
-
-        // create a channel for each stream and spawn a work item to read from it
-        // the return value from each work item is stored in thread_results, which
-        // is collected at the end
-        for (t, result) in thread_results.iter_mut().enumerate() {
-            let (tx, rx) = channel();
-            channel_to_sender.push(tx);
-
-            s.spawn(move |_| {
-                *result = Some(decoder_thread(lh, t, rx, pts_ref, q_ref, features, process));
-            });
-        }
-
-        // now that the channels are waiting for input, read the stream and send all the buffers to their respective readers
-        while reader.stream_position().context(here!())? < last_data_position - 4 {
-            let thread_marker = reader.read_u8().context(here!())?;
-            let thread_id = (thread_marker & 0xf) as u8;
-
-            if thread_id >= channel_to_sender.len() as u8 {
-                return err_exit_code(
-                    ExitCode::BadLeptonFile,
-                    format!(
-                        "invalid thread_id at {0} of {1} at {2}",
-                        reader.stream_position().unwrap(),
-                        last_data_position,
-                        here!()
-                    )
-                    .as_str(),
-                );
-            }
-
-            let data_length = if thread_marker < 16 {
-                let b0 = reader.read_u8().context(here!())?;
-                let b1 = reader.read_u8().context(here!())?;
-
-                ((b1 as usize) << 8) + b0 as usize + 1
-            } else {
-                // This format is used by Lepton C++ to write encoded chunks with length of 4096, 16384 or 65536 bytes
-                let flags = (thread_marker >> 4) & 3;
-
-                1024 << (2 * flags)
-            };
-
-            //info!("offset {0} len {1}", reader.stream_position()?-2, data_length);
-
-            let mut buffer = Vec::<u8>::new();
-            buffer.resize(data_length as usize, 0);
-            reader.read_exact(&mut buffer).with_context(|| {
-                format!(
-                    "reading {0} bytes at {1} of {2} at {3}",
-                    buffer.len(),
-                    reader.stream_position().unwrap(),
-                    last_data_position,
-                    here!()
-                )
-            })?;
-
-            let e =
-                channel_to_sender[thread_id as usize].send(Message::WriteBlock(thread_id, buffer));
-
-            if let Err(e) = e {
-                error_sending = Some(e);
-                break;
-            }
-        }
-        //info!("done sending!");
-
-        for c in channel_to_sender {
-            // ignore the result of send, since a thread may have already blown up with an error and we will get it when we join (rather than exiting with a useless channel broken message)
-            let _ = c.send(Message::Eof);
-        }
-
-        Ok(())
-    })?;
+    let pts_ref = &pts;
+    let q_ref = &qt[..];
+
+    let mut thread_results = multiplex_read(
+        reader,
+        lh.thread_handoff.len(),
+        |thread_id, reader| -> Result<(Metrics, P)> {
+            let cpu_time = CpuTimeMeasure::new();
+
+            let mut image_data = Vec::new();
+            for i in 0..lh.jpeg_header.cmpc {
+                image_data.push(BlockBasedImage::new(
+                    &lh.jpeg_header,
+                    i,
+                    lh.thread_handoff[thread_id].luma_y_start,
+                    if thread_id == lh.thread_handoff.len() - 1 {
+                        // if this is the last thread, then the image should extend all the way to the bottom
+                        lh.jpeg_header.cmp_info[0].bcv
+                    } else {
+                        lh.thread_handoff[thread_id].luma_y_end
+                    },
+                ));
+            }
+
+            let mut metrics = Metrics::default();
+
+            metrics.merge_from(
+                lepton_decode_row_range(
+                    pts_ref,
+                    q_ref,
+                    &lh.truncate_components,
+                    &mut image_data,
+                    reader,
+                    lh.thread_handoff[thread_id].luma_y_start,
+                    lh.thread_handoff[thread_id].luma_y_end,
+                    thread_id == lh.thread_handoff.len() - 1,
+                    true,
+                    features,
+                )
+                .context(here!())?,
+            );
+
+            let process_result = process(&lh.thread_handoff[thread_id], image_data, lh)?;
+
+            metrics.record_cpu_worker_time(cpu_time.elapsed());
+
+            Ok((metrics, process_result))
+        },
+    )?;
 
     let mut metrics = Metrics::default();
 
     let mut result = Vec::new();
-    let mut thread_not_run = false;
-    for i in thread_results.drain(..) {
-        match i {
-            None => thread_not_run = true,
-            Some(Err(e)) => {
-                return Err(e).context(here!());
-            }
-            Some(Ok((m, r))) => {
-                metrics.merge_from(m);
-                result.push(r);
-            }
-        }
-    }
-
-    if thread_not_run {
-        return err_exit_code(ExitCode::GeneralFailure, "thread did not run").context(here!());
-    }
-
-    // if there was an error during send, it should have resulted in an error from one of the threads above and
-    // we wouldn't get here, but as an extra precaution, we check here to make sure we didn't miss anything
-    if let Some(e) = error_sending {
-        return Err(e).context(here!());
+    for (m, r) in thread_results.drain(..) {
+        metrics.merge_from(m);
+        result.push(r);
     }
 
     info!(
@@ -457,65 +415,6 @@ fn run_lepton_decoder_threads<R: Read + Seek, P: Send>(
     Ok((metrics, result))
 }
 
-fn decoder_thread<P: Send>(
-    lh: &LeptonHeader,
-    thread_id: usize,
-    rx: Receiver<Message>,
-    pts_ref: &ProbabilityTablesSet,
-    q_ref: &[QuantizationTables],
-    features: &EnabledFeatures,
-    process: fn(&ThreadHandoff, Vec<BlockBasedImage>, &LeptonHeader) -> Result<P>,
-) -> Result<(Metrics, P), anyhow::Error> {
-    let cpu_time = CpuTimeMeasure::new();
-
-    let mut image_data = Vec::new();
-    for i in 0..lh.jpeg_header.cmpc {
-        image_data.push(BlockBasedImage::new(
-            &lh.jpeg_header,
-            i,
-            lh.thread_handoff[thread_id].luma_y_start,
-            if thread_id == lh.thread_handoff.len() - 1 {
-                // if this is the last thread, then the image should extend all the way to the bottom
-                lh.jpeg_header.cmp_info[0].bcv
-            } else {
-                lh.thread_handoff[thread_id].luma_y_end
-            },
-        ));
-    }
-
-    let mut metrics = Metrics::default();
-
-    // get the appropriate receiver so we can read out data from it
-    let mut reader = MessageReceiver {
-        thread_id: thread_id as u8,
-        current_buffer: Cursor::new(Vec::new()),
-        receiver: rx,
-        end_of_file: false,
-    };
-
-    metrics.merge_from(
-        lepton_decode_row_range(
-            pts_ref,
-            q_ref,
-            &lh.truncate_components,
-            &mut image_data,
-            &mut reader,
-            lh.thread_handoff[thread_id].luma_y_start,
-            lh.thread_handoff[thread_id].luma_y_end,
-            thread_id == lh.thread_handoff.len() - 1,
-            true,
-            features,
-        )
-        .context(here!())?,
-    );
-
-    let process_result = process(&lh.thread_handoff[thread_id], image_data, lh)?;
-
-    metrics.record_cpu_worker_time(cpu_time.elapsed());
-
-    Ok((metrics, process_result))
-}
-
 /// runs the encoding threads and returns the total amount of CPU time consumed (including worker threads)
 fn run_lepton_encoder_threads<W: Write>(
     jpeg_header: &JPegHeader,
@@ -551,87 +450,36 @@ fn run_lepton_encoder_threads<W: Write>(
     let pts_ref = &pts;
     let q_ref = &quantization_tables[..];
 
-    let mut sizes = Vec::<u64>::new();
-    sizes.resize(thread_handoffs.len(), 0);
-
-    let mut thread_results = Vec::<Option<Result<Metrics>>>::new();
-    for _i in 0..thread_handoffs.len() {
-        thread_results.push(None);
-    }
-
-    rayon::in_place_scope(|s| -> Result<()> {
-        let (tx, rx) = channel();
-
-        for (thread_id, result) in thread_results.iter_mut().enumerate() {
-            let cloned_sender = tx.clone();
-
-            s.spawn(move |_| {
-                *result = Some(encode_thread_action(
-                    thread_id,
-                    cloned_sender,
-                    pts_ref,
-                    q_ref,
-                    image_data,
-                    colldata,
-                    thread_handoffs,
-                    features,
-                ));
-            });
-        }
-
-        // drop the sender so that the channel breaks when all the threads exit
-        drop(tx);
-
-        // wait to collect work and done messages from all the threads
-        let mut threads_left = thread_handoffs.len();
-
-        while threads_left > 0 {
-            let value = rx.recv().context(here!());
-            match value {
-                Ok(Message::Eof) => {
-                    threads_left -= 1;
-                }
-                Ok(Message::WriteBlock(thread_id, b)) => {
-                    let l = b.len() - 1;
-
-                    writer.write_u8(thread_id).context(here!())?;
-                    writer.write_u8((l & 0xff) as u8).context(here!())?;
-                    writer.write_u8(((l >> 8) & 0xff) as u8).context(here!())?;
-                    writer.write_all(&b[..]).context(here!())?;
-
-                    sizes[thread_id as usize] += b.len() as u64;
-                }
-                Err(_) => {
-                    break;
-                }
-            }
-        }
-
-        return Ok(());
-    })
-    .context(here!())?;
+    let mut thread_results =
+        multiplex_write(writer, thread_handoffs.len(), |thread_writer, thread_id| {
+            let cpu_time = CpuTimeMeasure::new();
+
+            let mut range_metrics = lepton_encode_row_range(
+                pts_ref,
+                q_ref,
+                image_data,
+                thread_writer,
+                thread_id as i32,
+                colldata,
+                thread_handoffs[thread_id].luma_y_start,
+                thread_handoffs[thread_id].luma_y_end,
+                thread_id == thread_handoffs.len() - 1,
+                true,
+                features,
+            )
+            .context(here!())?;
 
-    let mut thread_not_run = false;
+            range_metrics.record_cpu_worker_time(cpu_time.elapsed());
+
+            Ok(range_metrics)
+        })?;
+
     let mut merged_metrics = Metrics::default();
     for result in thread_results.drain(..) {
-        match result {
-            None => thread_not_run = true,
-            Some(Ok(metrics)) => merged_metrics.merge_from(metrics),
-            // if there was an error processing anything, return it
-            Some(Err(e)) => return Err(e),
-        }
-    }
-
-    if thread_not_run {
-        return err_exit_code(ExitCode::GeneralFailure, "thread did not run");
+        merged_metrics.merge_from(result);
     }
 
-    info!(
-        "scan portion of JPEG uncompressed size = {0}",
-        sizes.iter().sum::<u64>()
-    );
-
     info!(
         "worker threads {0}ms of CPU time in {1}ms of wall time",
         merged_metrics.get_cpu_time_worker_time().as_millis(),
@@ -641,48 +489,6 @@ fn run_lepton_encoder_threads<W: Write>(
     Ok(merged_metrics)
 }
 
-fn encode_thread_action(
-    thread_id: usize,
-    cloned_sender: Sender<Message>,
-    pts_ref: &ProbabilityTablesSet,
-    q_ref: &[QuantizationTables],
-    image_data: &[BlockBasedImage],
-    colldata: &TruncateComponents,
-    thread_handoffs: &[ThreadHandoff],
-    features: &EnabledFeatures,
-) -> std::prelude::v1::Result<Metrics, anyhow::Error> {
-    let cpu_time = CpuTimeMeasure::new();
-
-    let mut thread_writer = MessageSender {
-        thread_id: thread_id as u8,
-        sender: cloned_sender,
-        buffer: Vec::with_capacity(WRITE_BUFFER_SIZE),
-    };
-
-    let mut range_metrics = lepton_encode_row_range(
-        pts_ref,
-        q_ref,
-        image_data,
-        &mut thread_writer,
-        thread_id as i32,
-        colldata,
-        thread_handoffs[thread_id].luma_y_start,
-        thread_handoffs[thread_id].luma_y_end,
-        thread_id == thread_handoffs.len() - 1,
-        true,
-        features,
-    )
-    .context(here!())?;
-
-    thread_writer.flush().context(here!())?;
-
-    thread_writer.sender.send(Message::Eof).context(here!())?;
-
-    range_metrics.record_cpu_worker_time(cpu_time.elapsed());
-
-    Ok(range_metrics)
-}
-
 #[derive(Debug)]
 pub struct LeptonHeader {
     /// raw jpeg header to be written back to the file when it is recreated
@@ -767,11 +573,10 @@ impl LeptonHeader {
         };
     }
 
-    fn recode_jpeg<R: Read + Seek, W: Write>(
+    fn recode_jpeg<R: Read, W: Write>(
         &mut self,
         writer: &mut W,
         reader: &mut R,
-        last_data_position: u64,
         num_threads: usize,
         enabled_features: &EnabledFeatures,
     ) -> Result<Metrics> {
@@ -783,18 +588,11 @@ impl LeptonHeader {
             .context(here!())?;
 
         let metrics = if self.jpeg_header.jpeg_type == JPegType::Progressive {
-            self.recode_progressive_jpeg(
-                reader,
-                last_data_position,
-                writer,
-                num_threads,
-                enabled_features,
-            )
-            .context(here!())?
+            self.recode_progressive_jpeg(reader, writer, num_threads, enabled_features)
+                .context(here!())?
         } else {
             self.recode_baseline_jpeg(
                 reader,
-                last_data_position,
                 writer,
                 self.plain_text_size as u64
                     - self.garbage_data.len() as u64
@@ -817,10 +615,9 @@ impl LeptonHeader {
     }
 
     /// decodes the entire image and merges the results into a single set of BlockBasedImage per component
-    pub fn decode_as_single_image<R: Read + Seek>(
+    pub fn decode_as_single_image<R: Read>(
         &mut self,
         reader: &mut R,
-        last_data_position: u64,
         num_threads: usize,
         features: &EnabledFeatures,
     ) -> Result<(Vec<BlockBasedImage>, Metrics)> {
@@ -828,7 +625,6 @@ impl LeptonHeader {
         let (metrics, mut results) = run_lepton_decoder_threads(
             self,
             reader,
-            last_data_position,
             num_threads,
             features,
             |_thread_handoff, image_data, _lh| {
@@ -868,17 +664,16 @@ impl LeptonHeader {
     }
 
     /// progressive decoder, requires that the entire lepton file is processed first
-    fn recode_progressive_jpeg<R: Read + Seek, W: Write>(
+    fn recode_progressive_jpeg<R: Read, W: Write>(
         &mut self,
         reader: &mut R,
-        last_data_position: u64,
         writer: &mut W,
         num_threads: usize,
         enabled_features: &EnabledFeatures,
     ) -> Result<Metrics> {
         // run the threads first, since we need everything before we can start decoding
         let (merged, metrics) = self
-            .decode_as_single_image(reader, last_data_position, num_threads, enabled_features)
+            .decode_as_single_image(reader, num_threads, enabled_features)
             .context(here!())?;
 
         loop {
@@ -908,10 +703,9 @@ impl LeptonHeader {
 
     // baseline decoder can run the jpeg encoder inside the worker thread vs progressive encoding which needs to get the entire set of coefficients first
     // since it runs through it multiple times.
-    fn recode_baseline_jpeg<R: Read + Seek, W: Write>(
+    fn recode_baseline_jpeg<R: Read, W: Write>(
         &mut self,
         reader: &mut R,
-        last_data_position: u64,
         writer: &mut W,
         size_limit: u64,
         num_threads: usize,
@@ -921,7 +715,6 @@ impl LeptonHeader {
         let (metrics, results) = run_lepton_decoder_threads(
             self,
             reader,
-            last_data_position,
             num_threads,
             enabled_features,
             |thread_handoff, image_data, lh| {
@@ -1534,121 +1327,6 @@ fn get_number_of_threads_for_encoding(
     return num_threads;
 }
 
-enum Message {
-    Eof,
-    WriteBlock(u8, Vec<u8>),
-}
-
-struct MessageSender {
-    thread_id: u8,
-    sender: Sender<Message>,
-    buffer: Vec<u8>,
-}
-
-const WRITE_BUFFER_SIZE: usize = 65536;
-
-impl Write for MessageSender {
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        let mut copy_start = 0;
-        while copy_start < buf.len() {
-            let amount_to_copy = cmp::min(
-                WRITE_BUFFER_SIZE - self.buffer.len(),
-                buf.len() - copy_start,
-            );
-            self.buffer
-                .extend_from_slice(&buf[copy_start..copy_start + amount_to_copy]);
-
-            if self.buffer.len() == WRITE_BUFFER_SIZE {
-                self.flush()?;
-            }
-
-            copy_start += amount_to_copy;
-        }
-
-        Ok(buf.len())
-    }
-
-    fn flush(&mut self) -> std::io::Result<()> {
-        if self.buffer.len() > 0 {
-            let mut new_buffer = Vec::with_capacity(WRITE_BUFFER_SIZE);
-            swap(&mut new_buffer, &mut self.buffer);
-
-            self.sender
-                .send(Message::WriteBlock(self.thread_id, new_buffer))
-                .unwrap();
-        }
-        Ok(())
-    }
-}
-
-/// used by the worker thread to read data for the given thread from the
-/// receiver. The thread_id is used only to assert that we are only
-/// getting the data that we are expecting
-struct MessageReceiver {
-    /// the multiplexed thread stream we are processing
-    thread_id: u8,
-
-    /// the receiver part of the channel to get more buffers
-    receiver: Receiver<Message>,
-
-    /// what we are reading. When this returns zero, we try to
-    /// refill the buffer if we haven't reached the end of the stream
-    current_buffer: Cursor<Vec<u8>>,
-
-    /// once we get told we are at the end of the stream, we just
-    /// always return 0 bytes
-    end_of_file: bool,
-}
-
-impl Read for MessageReceiver {
-    /// fast path for reads. If we get zero bytes, take the slow path
-    #[inline(always)]
-    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        let amount_read = self.current_buffer.read(buf)?;
-        if amount_read > 0 {
-            return Ok(amount_read);
-        }
-
-        self.read_slow(buf)
-    }
-}
-
-impl MessageReceiver {
-    /// slow path for reads, try to get a new buffer or
-    /// return zero if at the end of the stream
-    #[cold]
-    #[inline(never)]
-    fn read_slow(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        while !self.end_of_file {
-            let amount_read = self.current_buffer.read(buf)?;
-            if amount_read > 0 {
-                return Ok(amount_read);
-            }
-
-            match self.receiver.recv() {
-                Ok(r) => match r {
-                    Message::Eof => {
-                        self.end_of_file = true;
-                    }
-                    Message::WriteBlock(tid, block) => {
-                        debug_assert_eq!(
-                            tid, self.thread_id,
-                            "incoming thread must be equal to processing thread"
-                        );
-                        self.current_buffer = Cursor::new(block);
-                    }
-                },
-                Err(e) => {
-                    return Result::Err(std::io::Error::new(std::io::ErrorKind::Other, e));
-                }
-            }
-        }
-
-        // nothing if we reached the end of file
-        return Ok(0);
-    }
-}
-
 // internal utility we use to collect the header that we read for later
 struct Mirror<'a, R, W> {
     read: &'a mut R,
diff --git a/src/structs/mod.rs b/src/structs/mod.rs
index b168a238..e42c6bbd 100644
--- a/src/structs/mod.rs
+++ b/src/structs/mod.rs
@@ -19,6 +19,7 @@ mod lepton_decoder;
 mod lepton_encoder;
 pub mod lepton_format;
 mod model;
+mod multiplexer;
 mod neighbor_summary;
 mod probability_tables;
 mod probability_tables_coefficient_context;
diff --git a/src/structs/multiplexer.rs b/src/structs/multiplexer.rs
new file mode 100644
index 00000000..53543560
--- /dev/null
+++ b/src/structs/multiplexer.rs
@@ -0,0 +1,386 @@
+/// Implements a multiplexer that reads and writes blocks to a stream from multiple threads.
+///
+/// The write implementation identifies the blocks by thread_id and tries to write in 64K blocks. The file
+/// ends up with an interleaved stream of blocks from each thread.
+///
+/// The read implementation reads the blocks from the file and sends them to the appropriate worker thread.
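+///
+/// As an illustrative sketch derived from the write and read logic below (not a
+/// separate spec): each block is framed as a one-byte thread_id, a two-byte
+/// little-endian (length - 1), and then the payload itself:
+///
+/// ```text
+/// [thread_id: u8][len_lo: u8][len_hi: u8][payload: ((len_hi << 8) | len_lo) + 1 bytes]
+/// ```
+///
+/// so a full 65536-byte buffer from thread 3 is framed as 0x03 0xFF 0xFF followed by
+/// the payload. The thread_id must fit in the low 4 bits, since marker bytes of 16 or
+/// more select the Lepton C++ style blocks whose length is implied by flag bits instead.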
+use crate::{helpers::*, ExitCode};
+use anyhow::{Context, Result};
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use std::{
+    cmp,
+    io::{Cursor, Read, Write},
+    mem::swap,
+    sync::mpsc::{channel, Receiver, SendError, Sender},
+};
+
+/// The message that is sent between the threads
+enum Message {
+    Eof,
+    WriteBlock(u8, Vec<u8>),
+}
+
+pub struct MultiplexWriter {
+    thread_id: u8,
+    sender: Sender<Message>,
+    buffer: Vec<u8>,
+}
+
+const WRITE_BUFFER_SIZE: usize = 65536;
+
+impl Write for MultiplexWriter {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let mut copy_start = 0;
+        while copy_start < buf.len() {
+            let amount_to_copy = cmp::min(
+                WRITE_BUFFER_SIZE - self.buffer.len(),
+                buf.len() - copy_start,
+            );
+            self.buffer
+                .extend_from_slice(&buf[copy_start..copy_start + amount_to_copy]);
+
+            if self.buffer.len() == WRITE_BUFFER_SIZE {
+                self.flush()?;
+            }
+
+            copy_start += amount_to_copy;
+        }
+
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        if self.buffer.len() > 0 {
+            let mut new_buffer = Vec::with_capacity(WRITE_BUFFER_SIZE);
+            swap(&mut new_buffer, &mut self.buffer);
+
+            self.sender
+                .send(Message::WriteBlock(self.thread_id, new_buffer))
+                .unwrap();
+        }
+        Ok(())
+    }
+}
+
+/// Given an arbitrary writer, this function will launch the given number of threads and call the processor function
+/// on each of them, and collect the output written by each thread to the writer in blocks identified by the thread_id.
+///
+/// This output stream can be processed by multiplex_read to get the data back, using the same number of threads.
+pub fn multiplex_write<WRITE, FN, RESULT>(
+    writer: &mut WRITE,
+    num_threads: usize,
+    processor: FN,
+) -> Result<Vec<RESULT>>
+where
+    WRITE: Write,
+    FN: Fn(&mut MultiplexWriter, usize) -> Result<RESULT> + Send + Copy,
+    RESULT: Send,
+{
+    let mut thread_results = Vec::<Option<Result<RESULT>>>::new();
+
+    for _i in 0..num_threads {
+        thread_results.push(None);
+    }
+
+    rayon::in_place_scope(|s| -> Result<()> {
+        let (tx, rx) = channel();
+
+        for (thread_id, result) in thread_results.iter_mut().enumerate() {
+            let cloned_sender = tx.clone();
+
+            let mut thread_writer = MultiplexWriter {
+                thread_id: thread_id as u8,
+                sender: cloned_sender,
+                buffer: Vec::with_capacity(WRITE_BUFFER_SIZE),
+            };
+
+            let mut f = move || -> Result<RESULT> {
+                let r = processor(&mut thread_writer, thread_id)?;
+
+                thread_writer.flush().context(here!())?;
+
+                thread_writer.sender.send(Message::Eof).context(here!())?;
+                Ok(r)
+            };
+
+            s.spawn(move |_| {
+                *result = Some(f());
+            });
+        }
+
+        // drop the sender so that the channel breaks when all the threads exit
+        drop(tx);
+
+        // wait to collect work and done messages from all the threads
+        let mut threads_left = num_threads;
+
+        while threads_left > 0 {
+            let value = rx.recv().context(here!());
+            match value {
+                Ok(Message::Eof) => {
+                    threads_left -= 1;
+                }
+                Ok(Message::WriteBlock(thread_id, b)) => {
+                    let l = b.len() - 1;
+
+                    writer.write_u8(thread_id).context(here!())?;
+                    writer.write_u8((l & 0xff) as u8).context(here!())?;
+                    writer.write_u8(((l >> 8) & 0xff) as u8).context(here!())?;
+                    writer.write_all(&b[..]).context(here!())?;
+                }
+                Err(_) => {
+                    // if we get a receiving error here, this means that one of the threads broke
+                    // with an error, and this error will be collected when we join the threads
+                    break;
+                }
+            }
+        }
+
+        // in place scope will join all the threads before it exits
+        return Ok(());
+    })
+    .context(here!())?;
+
+    let mut thread_not_run = false;
+    let mut results = Vec::new();
+
+    for result in thread_results.drain(..) {
+        match result {
+            None => thread_not_run = true,
+            Some(Ok(r)) => results.push(r),
+            // if there was an error processing anything, return it
+            Some(Err(e)) => return Err(e),
+        }
+    }
+
+    if thread_not_run {
+        return err_exit_code(ExitCode::GeneralFailure, "thread did not run");
+    }
+
+    Ok(results)
+}
+
+/// Used by the processor thread to read data in a blocking way.
+/// The thread_id is used only to assert that we are only
+/// getting the data that we are expecting.
+pub struct MultiplexReader {
+    /// the multiplexed thread stream we are processing
+    thread_id: u8,
+
+    /// the receiver part of the channel to get more buffers
+    receiver: Receiver<Message>,
+
+    /// what we are reading. When this returns zero, we try to
+    /// refill the buffer if we haven't reached the end of the stream
+    current_buffer: Cursor<Vec<u8>>,
+
+    /// once we get told we are at the end of the stream, we just
+    /// always return 0 bytes
+    end_of_file: bool,
+}
+
+impl Read for MultiplexReader {
+    /// fast path for reads. If we run out of data, take the slow path
+    #[inline(always)]
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        let amount_read = self.current_buffer.read(buf)?;
+        if amount_read > 0 {
+            return Ok(amount_read);
+        }
+
+        self.read_slow(buf)
+    }
+}
+
+impl MultiplexReader {
+    /// slow path for reads, try to get a new buffer or
+    /// return zero if at the end of the stream
+    #[cold]
+    #[inline(never)]
+    fn read_slow(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        while !self.end_of_file {
+            let amount_read = self.current_buffer.read(buf)?;
+            if amount_read > 0 {
+                return Ok(amount_read);
+            }
+
+            match self.receiver.recv() {
+                Ok(r) => match r {
+                    Message::Eof => {
+                        self.end_of_file = true;
+                    }
+                    Message::WriteBlock(tid, block) => {
+                        debug_assert_eq!(
+                            tid, self.thread_id,
+                            "incoming thread must be equal to processing thread"
+                        );
+                        self.current_buffer = Cursor::new(block);
+                    }
+                },
+                Err(e) => {
+                    return Result::Err(std::io::Error::new(std::io::ErrorKind::Other, e));
+                }
+            }
+        }
+
+        // nothing if we reached the end of file
+        return Ok(0);
+    }
+}
+
+/// Reads data in multiplexed format and sends it to the appropriate processor, each
+/// running on its own thread. The processor function is called with the thread_id and
+/// a blocking reader that it can use to read its own data.
+///
+/// Once the multiplexed data is finished reading, we break the channel to the worker threads,
+/// causing any processor that is still trying to read from the channel to error out and exit.
+/// After all the readers have exited, we collect the results/errors from all the processors
+/// and return a vector of the results back to the caller.
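+///
+/// A minimal usage sketch (`decode_stream` here is a hypothetical stand-in for the
+/// per-thread work; the end-to-end test at the bottom of this file shows the same
+/// pattern with real calls):
+///
+/// ```ignore
+/// let results = multiplex_read(&mut reader, num_threads, |thread_id, r| {
+///     // r implements Read and blocks until the next block for this thread arrives
+///     decode_stream(thread_id, r)
+/// })?;
+/// ```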
+pub fn multiplex_read<READ, FN, RESULT>(
+    reader: &mut READ,
+    num_threads: usize,
+    processor: FN,
+) -> Result<Vec<RESULT>>
+where
+    READ: Read,
+    FN: Fn(usize, &mut MultiplexReader) -> Result<RESULT> + Send + Copy,
+    RESULT: Send,
+{
+    // track if we got an error while trying to send to a thread
+    let mut error_sending: Option<SendError<Message>> = None;
+
+    let mut thread_results = Vec::<Option<Result<RESULT>>>::new();
+    for _i in 0..num_threads {
+        thread_results.push(None);
+    }
+
+    rayon::in_place_scope(|s| -> Result<()> {
+        let mut channel_to_sender = Vec::new();
+
+        // create a channel for each stream and spawn a work item to read from it
+        // the return value from each work item is stored in thread_results, which
+        // is collected at the end
+        for (thread_id, result) in thread_results.iter_mut().enumerate() {
+            let (tx, rx) = channel();
+            channel_to_sender.push(tx);
+
+            s.spawn(move |_| {
+                // get the appropriate receiver so we can read out data from it
+                let mut proc_reader = MultiplexReader {
+                    thread_id: thread_id as u8,
+                    current_buffer: Cursor::new(Vec::new()),
+                    receiver: rx,
+                    end_of_file: false,
+                };
+                *result = Some(processor(thread_id, &mut proc_reader));
+            });
+        }
+
+        // now that the channels are waiting for input, read the stream and send all the buffers to their respective readers
+        loop {
+            let mut thread_marker_a = [0; 1];
+            if reader.read(&mut thread_marker_a)? == 0 {
+                break;
+            }
+
+            let thread_marker = thread_marker_a[0];
+
+            let thread_id = (thread_marker & 0xf) as u8;
+
+            if thread_id >= channel_to_sender.len() as u8 {
+                return err_exit_code(
+                    ExitCode::BadLeptonFile,
+                    format!("invalid thread_id {0}", thread_id).as_str(),
+                );
+            }
+
+            let data_length = if thread_marker < 16 {
+                let b0 = reader.read_u8().context(here!())?;
+                let b1 = reader.read_u8().context(here!())?;
+
+                ((b1 as usize) << 8) + b0 as usize + 1
+            } else {
+                // This format is used by Lepton C++ to write encoded chunks with a length of 4096, 16384 or 65536 bytes
+                let flags = (thread_marker >> 4) & 3;
+
+                1024 << (2 * flags)
+            };
+
+            //info!("offset {0} len {1}", reader.stream_position()?-2, data_length);
+
+            let mut buffer = Vec::<u8>::new();
+            buffer.resize(data_length as usize, 0);
+            reader
+                .read_exact(&mut buffer)
+                .with_context(|| format!("reading {0} bytes", buffer.len()))?;
+
+            let e =
+                channel_to_sender[thread_id as usize].send(Message::WriteBlock(thread_id, buffer));
+
+            if let Err(e) = e {
+                error_sending = Some(e);
+                break;
+            }
+        }
+        //info!("done sending!");
+
+        for c in channel_to_sender {
+            // ignore the result of send, since a thread may have already blown up with an error and we will get it when we join (rather than exiting with a useless channel broken message)
+            let _ = c.send(Message::Eof);
+        }
+
+        Ok(())
+    })?;
+
+    let mut result = Vec::new();
+    let mut thread_not_run = false;
+    for i in thread_results.drain(..) {
+        match i {
+            None => thread_not_run = true,
+            Some(Err(e)) => {
+                return Err(e).context(here!());
+            }
+            Some(Ok(r)) => {
+                result.push(r);
+            }
+        }
+    }
+
+    if thread_not_run {
+        return err_exit_code(ExitCode::GeneralFailure, "thread did not run").context(here!());
+    }
+
+    // if there was an error during send, it should have resulted in an error from one of the threads above and
+    // we wouldn't get here, but as an extra precaution, we check here to make sure we didn't miss anything
+    if let Some(e) = error_sending {
+        return Err(e).context(here!());
+    }
+
+    Ok(result)
+}
+
+/// simple end to end test that writes the thread id and reads it back
+#[test]
+fn test_multiplex_end_to_end() {
+    let mut output = Vec::new();
+
+    let w = multiplex_write(&mut output, 10, |writer, thread_id| -> Result<usize> {
+        writer.write_u32::<LittleEndian>(thread_id as u32)?;
+
+        Ok(thread_id)
+    })
+    .unwrap();
+
+    assert_eq!(w[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+
+    let mut reader = Cursor::new(output);
+
+    let r = multiplex_read(&mut reader, 10, |thread_id, reader| -> Result<usize> {
+        let read_thread_id = reader.read_u32::<LittleEndian>()?;
+        assert_eq!(read_thread_id, thread_id as u32);
+        Ok(thread_id)
+    })
+    .unwrap();
+
+    assert_eq!(r[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+}
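+
+/// illustrative sketch of the Lepton C++ compatibility framing (a hedged example
+/// derived from the flags logic in multiplex_read; the 0x23 marker and 42-byte filler
+/// are arbitrary choices): marker 0x23 encodes thread_id = 0x23 & 0xf = 3 and
+/// flags = (0x23 >> 4) & 3 = 2, so the block carries 1024 << (2 * 2) = 16384 payload
+/// bytes and no explicit length bytes follow the marker
+#[test]
+fn test_multiplex_read_cpp_style_block() {
+    // one marker byte followed by the implied 16384 bytes of payload for thread 3
+    let mut input = vec![0x23u8];
+    input.resize(1 + 16384, 42);
+
+    let mut reader = Cursor::new(input);
+
+    // each processor drains its own virtual stream; only thread 3 has any data
+    let r = multiplex_read(&mut reader, 4, |_thread_id, reader| -> Result<usize> {
+        let mut buf = Vec::new();
+        reader.read_to_end(&mut buf)?;
+        Ok(buf.len())
+    })
+    .unwrap();
+
+    assert_eq!(r[..], [0, 0, 0, 16384]);
+}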