Skip to content

Commit

Permalink
Add NPU support for wasi-nn WinML backend.
Browse files Browse the repository at this point in the history
This change adds support for NPU (Neural Processing Unit) to the wasi-nn
WinML backend. Since NPU support in DirectML is still in developer
preview, only a subset of learning models are supported.
  • Loading branch information
jianjunz committed Jul 17, 2024
1 parent c69ab34 commit d6ad4f8
Showing 1 changed file with 74 additions and 8 deletions.
82 changes: 74 additions & 8 deletions crates/wasi-nn/src/backend/winml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::Win32::Graphics::DXCore::{
DXCoreCreateAdapterFactory, IDXCoreAdapter, IDXCoreAdapterFactory, IDXCoreAdapterList,
DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS,
};
use windows::Win32::Graphics::{
Direct3D::D3D_FEATURE_LEVEL_1_0_CORE,
Direct3D12::{
D3D12CreateDevice, ID3D12CommandQueue, ID3D12Device, D3D12_COMMAND_LIST_TYPE_COMPUTE,
D3D12_COMMAND_QUEUE_DESC, D3D12_COMMAND_QUEUE_FLAG_NONE,
},
};
use windows::Win32::System::WinRT::ML::ILearningModelDeviceFactoryNative;
use windows::AI::MachineLearning::{
ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
Expand Down Expand Up @@ -45,12 +57,66 @@ impl BackendInner for WinMLBackend {
let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
&model_stream,
)?)?;
let device_kind = match target {
ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
ExecutionTarget::Tpu => unimplemented!(),
let device = match target {
ExecutionTarget::Cpu => LearningModelDevice::Create(LearningModelDeviceKind::Cpu),
ExecutionTarget::Gpu => LearningModelDevice::Create(LearningModelDeviceKind::DirectX),
ExecutionTarget::Tpu => unsafe {
// Enumerate adapters with DXCore APIs so MCDM (Microsoft Compute Driver Model) devices can be found.
let dx_adapter_factory: IDXCoreAdapterFactory = DXCoreCreateAdapterFactory()?;
let adapter_list =
dx_adapter_factory.CreateAdapterList::<IDXCoreAdapterList>(&[
DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
])?;
let mut selected_device: Option<IDXCoreAdapter> = None;
for i in 0..adapter_list.GetAdapterCount() {
let adapter = adapter_list.GetAdapter::<IDXCoreAdapter>(i)?;
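// Select a compute-only device, i.e. an adapter that supports D3D12 core compute but not D3D12 graphics.
// DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML looks more suitable here, but it is defined only in the
// DirectX headers (presumably not exposed by the `windows` crate bindings used here; TODO confirm).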
if adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE)
&& !adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)
{
selected_device = Some(adapter);
break;
}
}
if selected_device.is_none() {
return Err(BackendError::BackendAccess(anyhow::Error::msg(
"NPU is not available on this device.",
)));
}

let mut d3d12_device: Option<ID3D12Device> = None;
D3D12CreateDevice(
&selected_device.unwrap(),
D3D_FEATURE_LEVEL_1_0_CORE,
&mut d3d12_device,
)?;
if d3d12_device.is_none() {
return Err(BackendError::BackendAccess(anyhow::Error::msg(
"Failed to create D3D12 device.",
)));
}
let d3d12_command_queue_desc: D3D12_COMMAND_QUEUE_DESC = D3D12_COMMAND_QUEUE_DESC {
Type: D3D12_COMMAND_LIST_TYPE_COMPUTE,
Flags: D3D12_COMMAND_QUEUE_FLAG_NONE,
NodeMask: 0,
Priority: 0,
};
let d3d12_command_queue = d3d12_device
.unwrap()
.CreateCommandQueue::<ID3D12CommandQueue>(&d3d12_command_queue_desc)?;
let factory = windows::core::factory::<
LearningModelDevice,
ILearningModelDeviceFactoryNative,
>()?;
factory
.CreateFromD3D12CommandQueue(&d3d12_command_queue)?
.cast::<LearningModelDevice>()
},
};
let graph = WinMLGraph {
model,
device: device?,
};
let graph = WinMLGraph { model, device_kind };

let box_: Box<dyn BackendGraph> = Box::new(graph);
Ok(box_.into())
Expand All @@ -74,16 +140,16 @@ impl BackendFromDir for WinMLBackend {

/// A model loaded by the WinML backend, kept together with the information needed to
/// open sessions for it.
///
/// NOTE(review): this listing is a diff view with its +/- markers stripped, so both the
/// old and the new device field appear below. Presumably only `device` exists after the
/// change; confirm against the actual file.
struct WinMLGraph {
// The model that sessions are created from.
model: LearningModel,
// Device kind from which a `LearningModelDevice` is created on demand.
// NOTE(review): looks like the pre-change field that `device` replaces.
device_kind: LearningModelDeviceKind,
// Device selected for the requested execution target when the graph is loaded;
// sessions are created directly on it.
device: LearningModelDevice,
}

// NOTE(review): no justification for these manual impls is visible here. They assert that
// the WinRT objects held by `WinMLGraph` may be moved to, and shared between, threads.
// Confirm this against the WinML threading documentation and record the reasoning as a
// `// SAFETY:` comment.
unsafe impl Send for WinMLGraph {}
unsafe impl Sync for WinMLGraph {}

impl BackendGraph for WinMLGraph {
/// Creates an execution context by opening a `LearningModelSession` for this graph's
/// model on the device stored in the graph.
///
/// # Errors
///
/// Returns the error reported by `LearningModelSession::CreateFromModelOnDevice` if the
/// session cannot be created.
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
    // Session creation is a fallible runtime call, so propagate the failure to the caller
    // with `?` rather than panicking with `unwrap()`.
    let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &self.device)?;
    let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
    Ok(box_.into())
}
Expand Down

0 comments on commit d6ad4f8

Please sign in to comment.