Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix sum implementation WIP #14

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 83 additions & 43 deletions compute-shader-hello/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ use wgpu::util::DeviceExt;

use bytemuck;

const N_DATA: usize = 1 << 25;
const WG_SIZE: usize = 1 << 12;

// Verify that the data is OEIS A000217
// Verify that the data is OEIS A000217 (the triangular numbers):
// expect data[i] == i * (i + 1) / 2, truncated to u32 to match the
// shader's wrapping 32-bit arithmetic.
//
// Returns the index of the first mismatching element, or None if every
// element is correct. Prints a diagnostic line for the first mismatch.
fn verify(data: &[u32]) -> Option<usize> {
    data.iter().enumerate().position(|(i, val)| {
        // Compute the expected value in u64 so the product i * (i + 1)
        // cannot overflow even on 32-bit targets (i can reach 2^25, so the
        // product approaches 2^51, which exceeds usize on 32-bit hosts).
        // Truncating to u32 afterwards matches the GPU's u32 wraparound.
        let expected = ((i as u64 * (i as u64 + 1)) / 2) as u32;
        let wrong = expected != *val;
        if wrong {
            println!("diff @ {}: {} != {}", i, expected, *val);
        }
        wrong
    })
}

async fn run() {
let instance = wgpu::Instance::new(wgpu::Backends::PRIMARY);
let adapter = instance.request_adapter(&Default::default()).await.unwrap();
Expand All @@ -30,7 +44,7 @@ async fn run() {
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: features & wgpu::Features::TIMESTAMP_QUERY,
features: features & (wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::CLEAR_COMMANDS),
limits: Default::default(),
},
None,
Expand All @@ -54,13 +68,12 @@ async fn run() {
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
});
println!("shader compilation {:?}", start_instant.elapsed());
let input_f = &[1.0f32, 2.0f32];
let input : &[u8] = bytemuck::bytes_of(input_f);
let input_f: Vec<u32> = (0..N_DATA as u32).collect();
let input: &[u8] = bytemuck::cast_slice(&input_f);
let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: input,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
});
let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
Expand All @@ -69,6 +82,15 @@ async fn run() {
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
const N_WG: usize = N_DATA / WG_SIZE;
const STATE_SIZE: usize = N_WG * 3 + 1;
// TODO: round this up
let state_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 4 * STATE_SIZE as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// This works if the buffer is initialized, otherwise reads all 0, for some reason.
let query_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
Expand All @@ -87,48 +109,66 @@ async fn run() {
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: input_buf.as_entire_binding(),
}],
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: state_buf.as_entire_binding(),
},
],
});

let mut encoder = device.create_command_encoder(&Default::default());
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 0);
}
{
let mut cpass = encoder.begin_compute_pass(&Default::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch(input_f.len() as u32, 1, 1);
}
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 1);
}
encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
if let Some(query_set) = &query_set {
encoder.resolve_query_set(query_set, 0..2, &query_buf, 0);
}
queue.submit(Some(encoder.finish()));
for i in 0..100 {
let mut encoder = device.create_command_encoder(&Default::default());
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 0);
}
encoder.clear_buffer(&state_buf, 0, None);
{
let mut cpass = encoder.begin_compute_pass(&Default::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch(N_WG as u32, 1, 1);
}
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 1);
}
if i == 0 {
encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
}
if let Some(query_set) = &query_set {
encoder.resolve_query_set(query_set, 0..2, &query_buf, 0);
}
queue.submit(Some(encoder.finish()));

let buf_slice = output_buf.slice(..);
let buf_future = buf_slice.map_async(wgpu::MapMode::Read);
let query_slice = query_buf.slice(..);
let _query_future = query_slice.map_async(wgpu::MapMode::Read);
println!("pre-poll {:?}", std::time::Instant::now());
device.poll(wgpu::Maintain::Wait);
println!("post-poll {:?}", std::time::Instant::now());
if buf_future.await.is_ok() {
let data_raw = &*buf_slice.get_mapped_range();
let data : &[f32] = bytemuck::cast_slice(data_raw);
println!("data: {:?}", &*data);
}
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
let ts_period = queue.get_timestamp_period();
let ts_data_raw = &*query_slice.get_mapped_range();
let ts_data : &[u64] = bytemuck::cast_slice(ts_data_raw);
println!("compute shader elapsed: {:?}ms", (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6);
let buf_slice = output_buf.slice(..);
let buf_future = buf_slice.map_async(wgpu::MapMode::Read);
let query_slice = query_buf.slice(..);
let query_future = query_slice.map_async(wgpu::MapMode::Read);
device.poll(wgpu::Maintain::Wait);
if buf_future.await.is_ok() {
if i == 0 {
let data_raw = &*buf_slice.get_mapped_range();
let data: &[u32] = bytemuck::cast_slice(data_raw);
println!("results correct: {:?}", verify(data));
}
output_buf.unmap();
}
if query_future.await.is_ok() {
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
let ts_period = queue.get_timestamp_period();
let ts_data_raw = &*query_slice.get_mapped_range();
let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
println!(
"compute shader elapsed: {:?}ms",
(ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
);
}
}
query_buf.unmap();
}
}

Expand Down
120 changes: 114 additions & 6 deletions compute-shader-hello/src/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,122 @@

[[block]]
struct DataBuf {
data: [[stride(4)]] array<f32>;
data: [[stride(4)]] array<u32>;
};

[[block]]
// Per-dispatch scratch state for the decoupled look-back scan.
// Layout (as used by the kernel below): state[0] is an atomically
// incremented partition-id counter; then for each partition p,
// state[p*3 + 1] holds the ready flag (FLAG_*), state[p*3 + 2] the
// workgroup aggregate, and state[p*3 + 3] the inclusive prefix.
struct StateBuf {
state: [[stride(4)]] array<atomic<u32>>;
};

[[group(0), binding(0)]]
var<storage, read_write> v_indices: DataBuf;
var<storage, read_write> main_buf: DataBuf;

[[group(0), binding(1)]]
var<storage, read_write> state_buf: StateBuf;

let FLAG_NOT_READY = 0u;
let FLAG_AGGREGATE_READY = 1u;
let FLAG_PREFIX_READY = 2u;

let workgroup_size: u32 = 512u;
let N_SEQ = 8u;

var<workgroup> part_id: u32;
var<workgroup> scratch: array<u32, workgroup_size>;
var<workgroup> shared_prefix: u32;
var<workgroup> shared_flag: u32;

[[stage(compute), workgroup_size(512)]]
fn main([[builtin(local_invocation_id)]] local_id: vec3<u32>) {
if (local_id.x == 0u) {
part_id = atomicAdd(&state_buf.state[0], 1u);
}
workgroupBarrier();
let my_part_id = part_id;
let mem_base = my_part_id * workgroup_size;
var local: array<u32, N_SEQ>;
var el = main_buf.data[(mem_base + local_id.x) * N_SEQ];
local[0] = el;
for (var i: u32 = 1u; i < N_SEQ; i = i + 1u) {
el = el + main_buf.data[(mem_base + local_id.x) * N_SEQ + i];
local[i] = el;
}
scratch[local_id.x] = el;
// This must be lg2(workgroup_size)
for (var i: u32 = 0u; i < 9u; i = i + 1u) {
workgroupBarrier();
if (local_id.x >= (1u << i)) {
el = el + scratch[local_id.x - (1u << i)];
}
workgroupBarrier();
scratch[local_id.x] = el;
}
var exclusive_prefix = 0u;

var flag = FLAG_AGGREGATE_READY;
if (local_id.x == workgroup_size - 1u) {
atomicStore(&state_buf.state[my_part_id * 3u + 2u], el);
if (my_part_id == 0u) {
atomicStore(&state_buf.state[my_part_id * 3u + 3u], el);
flag = FLAG_PREFIX_READY;
}
}
// make sure these barriers are in uniform control flow
storageBarrier();
if (local_id.x == workgroup_size - 1u) {
atomicStore(&state_buf.state[my_part_id * 3u + 1u], flag);
}

if (my_part_id != 0u) {
// decoupled look-back
var look_back_ix = my_part_id - 1u;
loop {
if (local_id.x == workgroup_size - 1u) {
shared_flag = atomicOr(&state_buf.state[look_back_ix * 3u + 1u], 0u);
}
workgroupBarrier();
flag = shared_flag;
storageBarrier();
if (flag == FLAG_PREFIX_READY) {
if (local_id.x == workgroup_size - 1u) {
let their_prefix = atomicOr(&state_buf.state[look_back_ix * 3u + 3u], 0u);
exclusive_prefix = their_prefix + exclusive_prefix;
}
break;
} elseif (flag == FLAG_AGGREGATE_READY) {
if (local_id.x == workgroup_size - 1u) {
let their_agg = atomicOr(&state_buf.state[look_back_ix * 3u + 2u], 0u);
exclusive_prefix = their_agg + exclusive_prefix;
}
look_back_ix = look_back_ix - 1u;
}
// else spin
}

// compute inclusive prefix
if (local_id.x == workgroup_size - 1u) {
let inclusive_prefix = exclusive_prefix + el;
shared_prefix = exclusive_prefix;
atomicStore(&state_buf.state[my_part_id * 3u + 3u], inclusive_prefix);
}
storageBarrier();
if (local_id.x == workgroup_size - 1u) {
atomicStore(&state_buf.state[my_part_id * 3u + 1u], FLAG_PREFIX_READY);
}
}
var prefix = 0u;
workgroupBarrier();
if (my_part_id != 0u) {
prefix = shared_prefix;
}

[[stage(compute), workgroup_size(1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
// TODO: a more interesting computation than this.
v_indices.data[global_id.x] = v_indices.data[global_id.x] + 42.0;
// do final output
for (var i: u32 = 0u; i < N_SEQ; i = i + 1u) {
var old = 0u;
if (local_id.x > 0u) {
old = scratch[local_id.x - 1u];
}
main_buf.data[(mem_base + local_id.x) * N_SEQ + i] = prefix + old + local[i];
}
}