diff --git a/juice/src/layers/loss/negative_log_likelihood.rs b/juice/src/layers/loss/negative_log_likelihood.rs
index 05be58d10..a56f6c428 100644
--- a/juice/src/layers/loss/negative_log_likelihood.rs
+++ b/juice/src/layers/loss/negative_log_likelihood.rs
@@ -81,10 +81,12 @@ impl<B: IBackend + LayerOps<f32>> ComputeOutput<f32, B> for NegativeLogLikelihood {
         let native_labels = labels.read(native.device()).unwrap().as_slice::<f32>();
         let native_probabilities = probabilities.read(native.device()).unwrap().as_slice::<f32>();
 
-        let mut writable_loss = Vec::<f32>::new();
+        let mut writable_loss = Vec::<f32>::with_capacity(native_labels.len());
+        let mut offset = 0;
         for &label_value in native_labels {
-            let probability_value = native_probabilities[label_value as usize];
+            let probability_value = native_probabilities[offset + label_value as usize];
             writable_loss.push(-probability_value);
+            offset += batch_size;
         }
 
         let mut loss = writable_loss.iter().fold(0f32, |sum, &val| sum + val);
@@ -159,4 +161,4 @@ impl Into<LayerType> for NegativeLogLikelihoodConfig {
     fn into(self) -> LayerType {
        LayerType::NegativeLogLikelihood(self)
     }
-}
+}
\ No newline at end of file
diff --git a/juice/tests/layer_specs.rs b/juice/tests/layer_specs.rs
index b05f7a2f5..1f22eb066 100644
--- a/juice/tests/layer_specs.rs
+++ b/juice/tests/layer_specs.rs
@@ -385,4 +385,61 @@ mod layer_spec {
         )
         .is_err());
     }
+
+    use juice::layers::SequentialConfig;
+    use juice::layers::NegativeLogLikelihoodConfig;
+
+    #[test]
+    fn nll_basic() {
+        const BATCH_SIZE: usize = 7;
+        const KLASS_COUNT: usize = 10;
+        let native_backend = native_backend();
+        let mut classifier_cfg = SequentialConfig::default();
+        classifier_cfg.add_input("network_out", &[BATCH_SIZE, KLASS_COUNT]);
+        classifier_cfg.add_input("label", &[BATCH_SIZE, 1]);
+        // set up nll loss
+        let nll_layer_cfg = NegativeLogLikelihoodConfig { num_classes: 10 };
+        let nll_cfg = LayerConfig::new("nll", nll_layer_cfg);
+        classifier_cfg.add_layer(nll_cfg);
+        let mut network = Layer::from_config(
+            native_backend.clone(),
+            &LayerConfig::new("foo", classifier_cfg),
+        );
+        let labels_data = (0..(BATCH_SIZE * KLASS_COUNT))
+            .into_iter()
+            .map(|x| x as f32)
+            .collect::<Vec<f32>>();
+        let desc = [BATCH_SIZE, KLASS_COUNT];
+        let desc: &[usize] = &desc[..];
+        let mut input = SharedTensor::<f32>::new(&desc);
+        let mem = input.write_only(native_backend.device()).unwrap();
+        let input_data = (0..(KLASS_COUNT * BATCH_SIZE)).into_iter().map(|x| x as f32 * 3.77).collect::<Vec<f32>>();
+        let input_data = &input_data[..];
+        juice::util::write_to_memory(mem, input_data);
+
+        // each input has exactly one label
+        let labels_desc = [BATCH_SIZE, 1];
+        let labels_desc = &labels_desc[..];
+        let mut labels = SharedTensor::<f32>::new(&labels_desc);
+
+        // pretend they have all different classes
+        let labels_data = (1..=(BATCH_SIZE * 1))
+            .into_iter()
+            .map(|x| x as f32)
+            .collect::<Vec<f32>>();
+        let mem = labels.write_only(native_backend.device()).unwrap();
+        juice::util::write_to_memory(mem, labels_data.as_slice());
+
+        let input = vec![
+            std::sync::Arc::new(std::sync::RwLock::new(input)),
+            std::sync::Arc::new(std::sync::RwLock::new(labels)),
+        ];
+
+        let output = network.forward(input.as_slice());
+
+        let x = output[0].read().unwrap();
+        dbg!(&x);
+        let out = x.read(native_backend.device()).unwrap();
+        dbg!(out.as_slice::<f32>());
+    }
 }
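
The core of the fix is the loss hunk: before the patch, every sample in the batch indexed into `native_probabilities[label_value as usize]`, i.e. into the first row of the flattened probability tensor, so only the first sample ever read its own predictions. The patch adds a running `offset` so each sample reads from its own slice of the buffer. Below is a minimal standalone sketch of that indexing scheme in plain Rust, not the juice API; the names `nll_sum` and `stride` are illustrative, and `stride` is left as a parameter (the hunk above advances `offset` by `batch_size` per sample).

```rust
/// Sum of negative (log-)probability values that each sample assigns to its
/// true class. `probabilities` is assumed to be a flattened
/// `[batch_size, num_classes]` buffer; `stride` is the distance between
/// consecutive samples' rows.
fn nll_sum(probabilities: &[f32], labels: &[f32], stride: usize) -> f32 {
    let mut offset = 0;
    let mut loss = 0.0f32;
    for &label_value in labels {
        // Each sample reads from its own row instead of all samples
        // indexing into the first `num_classes` entries.
        loss += -probabilities[offset + label_value as usize];
        offset += stride;
    }
    loss
}

fn main() {
    // Two samples, three classes, row-major layout.
    let probabilities = [0.7, 0.2, 0.1, 0.1, 0.8, 0.1];
    let labels = [0.0, 1.0];
    // With a row stride of 3 (the row length of this 2x3 layout), sample 0
    // reads index 0 and sample 1 reads index 3 + 1 = 4.
    println!("{}", nll_sum(&probabilities, &labels, 3));
}
```

The new `nll_basic` test exercises exactly this path end to end: it feeds a `[BATCH_SIZE, KLASS_COUNT]` input and one label per sample through a `SequentialConfig` wrapping the NLL layer, where the old code would have computed every sample's loss from the first row.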