-
-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/reshape: Add evaluate_shape to reshape.rs to allow for variable #148
base: master
Are you sure you want to change the base?
Conversation
…reshape inputs. Update examples to use isize and variable batch_size.
Hey @oxctdev :) thanks for opening this PR. The commit seems to be in need of some additional attention, I'll add comments inline. |
dbg!(&network); | ||
panic!("End"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stray dbg statement, this should not end up in master
, not even in a test.
@@ -262,7 +262,7 @@ pub(crate) fn test(backend: Rc<Backend<Cuda>>, batch_size: usize, file: &Path) - | |||
} | |||
|
|||
fn main() { | |||
env_logger::builder().filter_level(log::LevelFilter::Info).init(); | |||
env_logger::builder().filter_level(log::LevelFilter::Trace).init(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Info
was picked intentionally to not flood first time users with lot's of overwhelming (=useless) info.
@@ -28,6 +28,7 @@ rand = "0.8" | |||
num = "0.4" | |||
capnp = "0.14" | |||
timeit = "0.1" | |||
anyhow = "1.0" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is a library, thiserror
is the preferred approach. Anyhow is for applications.
fn evaluate_shape(&self, input_shape: &TensorDesc) -> Result<Vec<usize>> { | ||
dbg!(&self.shape); | ||
dbg!(input_shape); | ||
let unknown_dimensions: usize = self.shape.iter().filter(|x| **x == -1).count(); | ||
let invalid_dimensions: usize = self.shape.iter().filter(|x| **x < -1).count(); | ||
if invalid_dimensions > 0 { | ||
return Err(anyhow!("Invalid elements provided to Reshape")) | ||
} | ||
return match unknown_dimensions { | ||
0 => Ok(self.shape.clone().into_iter().map(|x| x as usize).collect()), | ||
1 => { | ||
let total_prior_elements: usize = input_shape.iter().product(); | ||
let known_elements: usize = self.shape.iter().filter(|x| **x > -1).product::<isize>() as usize; | ||
dbg!(total_prior_elements); | ||
dbg!(known_elements); | ||
if total_prior_elements != (total_prior_elements / known_elements * known_elements) { | ||
Err(anyhow!( | ||
"Dimensions {:?} do not cleanly reshape into {:?}", | ||
input_shape, self.shape | ||
)) | ||
} else { | ||
let unknown_element: usize = total_prior_elements / known_elements; | ||
Ok(self.shape | ||
.iter() | ||
.map(|x| if *x == -1 { unknown_element } else { *x as usize }) | ||
.collect()) | ||
} | ||
} | ||
_ => Err(anyhow!("More than 2 unknown elements provided to Reshape")), | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is in need of bunch of tests :) and also has stray dbg!
statements and requires some love for the Err
types (or the lack of those).
Gentle ping? |
…reshape inputs. Update examples to use isize and variable batch_size.
What does this PR accomplish?
Changes proposed by this PR:
Notes to reviewer:
📜 Checklist
juice-examples
run just fine