[WASI-NN] Add support for a PyTorch backend for wasi-nn #9234
base: main
Conversation
This is a good start. The main thing to fix is the handling of the input and output tensors.
use tch::CModule;
use tch::{Device, Kind, TchError, Tensor as TchTensor};
nit: merge these
) -> Result<Graph, BackendError> {
    // Load the model from the file path
    let compiled_module =
        CModule::load_on_device(path, map_execution_target_to_string(target)).unwrap();
This seems incorrect: load_from_dir is going to pass the path to a directory, and this code will try to use it as a file.
    .iter()
    .map(|&dim| dim as i64)
    .collect::<Vec<_>>();
self.1 = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind);
This function needs to handle the passed index: a model can have multiple inputs and the job this function needs to do is map the incoming tensor to the right one.
fn compute(&mut self) -> Result<(), BackendError> {
    // Use the forward method on the compiled module/model after locking the mutex, and pass the input tensor to it
    self.1 = self.0.lock().unwrap().forward_ts(&[&self.1]).unwrap();
Storing the output tensor in the same location as the input tensor means that set_input followed immediately by get_output would return the input tensor... probably not what you want here. It looks like forward_ts only returns a single tensor, so perhaps just create an output field for that and another input: Vec<Tensor> for the inputs.
let data = vec![0f32; numel];
let mut data_u8: Vec<u8> = data
    .iter()
    .flat_map(|&x| x.to_le_bytes().to_vec())
    .collect();
self.1.copy_data_u8(&mut data_u8, numel);
This is incorrect: we need to retrieve the data regardless of the type, so we need to first figure out how many bytes each Kind is before constructing the receiving buffer, like:
-let data = vec![0f32; numel];
-let mut data_u8: Vec<u8> = data
-    .iter()
-    .flat_map(|&x| x.to_le_bytes().to_vec())
-    .collect();
-self.1.copy_data_u8(&mut data_u8, numel);
+let mut data = vec![0u8; size_of(ty) * numel];
+self.1.copy_data_u8(&mut data, numel);
match target {
    ExecutionTarget::Cpu => Device::Cpu,
    ExecutionTarget::Gpu => {
        unimplemented!("Pytorch does not yet support GPU execution targets")
-unimplemented!("Pytorch does not yet support GPU execution targets")
+unimplemented!("the pytorch backend does not yet support GPU execution targets")
...since PyTorch itself does indeed support GPU execution; it is this backend that does not yet.
        unimplemented!("Pytorch does not yet support GPU execution targets")
    }
    ExecutionTarget::Tpu => {
        unimplemented!("Pytorch does not yet support TPU execution targets")
-unimplemented!("Pytorch does not yet support TPU execution targets")
+unimplemented!("the pytorch backend does not yet support TPU execution targets")
This change adds a PyTorch backend for wasi-nn. The tch crate is used for the Libtorch bindings. I have added an image classification example to demonstrate its usage, which uses a TorchScript model.

This backend is currently gated behind a wasi-nn feature flag (--features pytorch) because, due to dynamic linking, a Libtorch v2.4.0 installation on the system (specified by LIBTORCH=/path/to/libtorch) is needed for building.