
[WASI-NN] Add support for a PyTorch backend for wasi-nn #9234

Open
wants to merge 1 commit into main

Conversation

rahulchaphalkar
Contributor

This change adds a PyTorch backend for wasi-nn.
The tch crate is used for the Libtorch bindings. I have added an image classification example, which uses a TorchScript model, to demonstrate its usage.
This backend is currently gated behind the wasi-nn feature flag --features pytorch because, due to dynamic linking, building requires a Libtorch v2.4.0 installation on the system (specified by LIBTORCH=/path/to/libtorch).

@rahulchaphalkar rahulchaphalkar requested review from alexcrichton and removed request for a team September 12, 2024 18:18
@abrown abrown self-assigned this Sep 12, 2024
@alexcrichton alexcrichton requested review from abrown and removed request for a team and alexcrichton September 12, 2024 20:09
@abrown left a comment


This is a good start. The main thing to fix is the handling of the input and output tensors.

Comment on lines +10 to +11
use tch::CModule;
use tch::{Device, Kind, TchError, Tensor as TchTensor};

nit: merge these
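That is, the two imports could be combined into something like:

    use tch::{CModule, Device, Kind, TchError, Tensor as TchTensor};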

) -> Result<Graph, BackendError> {
// Load the model from the file path
let compiled_module =
CModule::load_on_device(path, map_execution_target_to_string(target)).unwrap();

This seems incorrect: load_from_dir is going to pass the path to a directory and this code will try to use it as a file.
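A minimal sketch of one way to handle that, assuming the TorchScript file inside the directory follows a fixed naming convention (the model.pt filename and the BackendError::BackendAccess(anyhow::Error) conversion are assumptions, not taken from this PR):

    // Sketch only: `path` is the directory handed in by load_from_dir; join the
    // assumed model filename onto it before asking tch to load the module.
    let model_file = std::path::Path::new(path).join("model.pt");
    let compiled_module =
        CModule::load_on_device(model_file, map_execution_target_to_string(target))
            .map_err(|e| BackendError::BackendAccess(anyhow::anyhow!(e)))?;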

.iter()
.map(|&dim| dim as i64)
.collect::<Vec<_>>();
self.1 = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind);

This function needs to handle the passed index: a model can have multiple inputs and the job this function needs to do is map the incoming tensor to the right one.
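A hedged sketch of the indexing part (names here are illustrative, not from this PR), assuming the execution context keeps its converted inputs in a Vec<TchTensor>:

    use tch::Tensor as TchTensor;

    // Illustrative helper: place a converted tensor in the slot the guest asked
    // for, growing the vector so inputs may be set in any order.
    fn store_input(inputs: &mut Vec<TchTensor>, index: usize, tensor: TchTensor) {
        if inputs.len() <= index {
            inputs.resize_with(index + 1, TchTensor::new);
        }
        inputs[index] = tensor;
    }

set_input would then convert the incoming wasi-nn tensor exactly as the existing code does and call this with the passed index.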


fn compute(&mut self) -> Result<(), BackendError> {
// Use forward method on the compiled module/model after locking the mutex, and pass the input tensor to it
self.1 = self.0.lock().unwrap().forward_ts(&[&self.1]).unwrap();

Storing the output tensor in the same location as the input tensor means that set_input followed immediately by get_output would return the input tensor... probably not what you want here. It looks like forward_ts only returns a single tensor, so perhaps just create an output field for that and an input: Vec<Tensor> field for the inputs.
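A rough sketch of that shape, assuming the context becomes a struct with named fields and reusing the crate's existing BackendError::BackendAccess(anyhow::Error) variant (all names here are illustrative):

    use std::sync::{Arc, Mutex};
    use tch::{CModule, Tensor as TchTensor};

    // Illustrative layout: inputs and the single forward_ts output live in
    // separate fields, so get_output can no longer hand back an input tensor.
    struct PytorchExecutionContext {
        module: Arc<Mutex<CModule>>,
        inputs: Vec<TchTensor>,
        output: Option<TchTensor>,
    }

    impl PytorchExecutionContext {
        fn compute(&mut self) -> Result<(), BackendError> {
            let inputs: Vec<&TchTensor> = self.inputs.iter().collect();
            let output = self
                .module
                .lock()
                .unwrap()
                .forward_ts(&inputs)
                .map_err(|e| BackendError::BackendAccess(anyhow::anyhow!(e)))?;
            self.output = Some(output);
            Ok(())
        }
    }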

Comment on lines +106 to +111
let data = vec![0f32; numel];
let mut data_u8: Vec<u8> = data
.iter()
.flat_map(|&x| x.to_le_bytes().to_vec())
.collect();
self.1.copy_data_u8(&mut data_u8, numel);

This is incorrect: we need to retrieve the data regardless of the type, so we first need to figure out how many bytes each Kind occupies before constructing the receiving buffer, like:

Suggested change
let data = vec![0f32; numel];
let mut data_u8: Vec<u8> = data
.iter()
.flat_map(|&x| x.to_le_bytes().to_vec())
.collect();
self.1.copy_data_u8(&mut data_u8, numel);
let data = vec![0u8; size_of(ty) * numel];
self.1.copy_data_u8(&mut data, numel);
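The size_of used above would need a small helper; here is a hedged sketch keyed off tch::Kind (the helper name, signature, and variant coverage are assumptions for illustration):

    use tch::Kind;

    // Illustrative helper: bytes per element for the kinds this backend is
    // likely to see; extend the match as more tensor types are supported.
    fn size_of(kind: Kind) -> usize {
        match kind {
            Kind::Uint8 | Kind::Int8 | Kind::Bool => 1,
            Kind::Int16 | Kind::Half | Kind::BFloat16 => 2,
            Kind::Int | Kind::Float => 4,
            Kind::Int64 | Kind::Double => 8,
            _ => unimplemented!("unsupported tensor element type"),
        }
    }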

match target {
ExecutionTarget::Cpu => Device::Cpu,
ExecutionTarget::Gpu => {
unimplemented!("Pytorch does not yet support GPU execution targets")

Suggested change
unimplemented!("Pytorch does not yet support GPU execution targets")
unimplemented!("the pytorch backend does not yet support GPU execution targets")


...since the PyTorch backend does indeed support GPU execution.
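A sketch of what the GPU arm could do instead (not part of this PR; falling back to the CPU is just one possible choice):

    ExecutionTarget::Gpu => {
        // Use the first CUDA device when one is available at runtime,
        // otherwise fall back to the CPU.
        if tch::Cuda::is_available() {
            Device::Cuda(0)
        } else {
            Device::Cpu
        }
    }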

unimplemented!("Pytorch does not yet support GPU execution targets")
}
ExecutionTarget::Tpu => {
unimplemented!("Pytorch does not yet support TPU execution targets")

Suggested change
unimplemented!("Pytorch does not yet support TPU execution targets")
unimplemented!("the pytorch backend does not yet support TPU execution targets")

@abrown commented Sep 20, 2024

The cargo vet situation is a bit much:

    cargo vet diff zstd-safe 5.0.1+zstd.1.5.2 5.0.2+zstd.1.5.2   gyscos         zstd                        2 files changed, 4 insertions(+), 4 deletions(-)
    cargo vet diff zstd 0.11.1+zstd.1.5.2 0.11.2+zstd.1.5.2      gyscos         zip                         3 files changed, 5 insertions(+), 5 deletions(-)
    cargo vet diff num-complex 0.4.2 0.4.6                cuviper        ndarray                     6 files changed, 188 insertions(+), 48 deletions(-)
      NOTE: this project trusts Josh Stone (cuviper) - consider cargo vet trust num-complex or cargo vet trust --all cuviper
    cargo vet inspect constant_time_eq 0.1.5              cesarb         zip                         311 lines
    cargo vet diff sha1 0.10.5 0.10.6                     newpavlov      zip                         7 files changed, 302 insertions(+), 20 deletions(-)
    cargo vet inspect rawpointer 0.2.1                    bluss          ndarray and matrixmultiply  559 lines
    cargo vet diff zip 0.6.4 0.6.6                        Plecra         tch and torch-sys           14 files changed, 604 insertions(+), 109 deletions(-)
    cargo vet inspect inout 0.1.3                         newpavlov      cipher                      1112 lines
      NOTE: cargo vet import zcash would eliminate this
    cargo vet inspect pbkdf2 0.9.0                        tarcieri       zip                         1120 lines
    cargo vet inspect bzip2 0.4.4                         alexcrichton   zip                         2094 lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2 or cargo vet trust --all alexcrichton
    cargo vet inspect safetensors 0.3.3                   Narsil         tch                         2200 lines
    cargo vet inspect cipher 0.4.4                        newpavlov      aes                         2635 lines
      NOTE: cargo vet import zcash would reduce this to a 1300-line diff
    cargo vet inspect password-hash 0.3.2                 tarcieri       pbkdf2                      3139 lines
    cargo vet inspect base64ct 1.6.0                      tarcieri       password-hash               3381 lines
    cargo vet diff half 1.8.2 2.4.1                       starkat99      tch                         19 files changed, 2546 insertions(+), 958 deletions(-)
    cargo vet inspect time 0.1.44                         jhpratt        zip                         3915 lines
    cargo vet inspect aes 0.7.5                           tarcieri       zip                         6822 lines
    cargo vet inspect matrixmultiply 0.3.8                bluss          ndarray                     7934 lines
    cargo vet inspect ndarray 0.15.6                      jturner314     tch                         41996 lines
    cargo vet inspect torch-sys 0.17.0                    LaurentMazare  tch                         52119 lines
    cargo vet inspect bzip2-sys 0.1.11+1.0.8              alexcrichton   bzip2                       264133 lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2-sys or cargo vet trust --all alexcrichton
    cargo vet inspect tch 0.17.0                          LaurentMazare  wasmtime-wasi-nn            2287297 lines
