Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement common swizzle operations. #335

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
172 changes: 172 additions & 0 deletions crates/core_simd/src/swizzle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,4 +364,176 @@ where

(Even::swizzle2(self, other), Odd::swizzle2(self, other))
}

/// Splits a vector into its two halves.
///
/// Due to limitations in const generics, the length of the resulting vector cannot be inferred
/// from the input vectors. You must specify it explicitly or constrain it by context. A
/// compile-time error will be raised if `HALF_LANES * 2 != LANES`.
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd::Simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd::Simd;
/// let x = Simd::from_array([0, 1, 2, 3, 4, 5, 6, 7]);
/// let [y, z] = x.split();
/// assert_eq!(y.to_array(), [0, 1, 2, 3]);
/// assert_eq!(z.to_array(), [4, 5, 6, 7]);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
// TODO: when `generic_const_exprs` supports it, change signature to
// `pub fn split(self) -> [Simd<T, {LANES / 2}>; 2]`
pub fn split<const HALF_LANES: usize>(self) -> [Simd<T, HALF_LANES>; 2]
where
LaneCount<HALF_LANES>: SupportedLaneCount,
{
const fn slice_index<const LEN: usize>(hi_half: bool, lanes: usize) -> [usize; LEN] {
assert!(
LEN * 2 == lanes,
"x.split_to::<N>() must provide N=x.lanes()/2"
);
let offset = if hi_half { LEN } else { 0 };
let mut index = [0; LEN];
let mut i = 0;
while i < LEN {
index[i] = i + offset;
i += 1;
}
index
}
struct Split<const HI_HALF: bool>;
impl<const HI_HALF: bool, const LEN: usize, const LANES: usize> Swizzle<LANES, LEN>
for Split<HI_HALF>
{
const INDEX: [usize; LEN] = slice_index::<LEN>(HI_HALF, LANES);
}
[Split::<false>::swizzle(self), Split::<true>::swizzle(self)]
}

/// Concatenates two vectors of equal length.
///
/// Due to limitations in const generics, the length of the resulting vector cannot be inferred
/// from the input vectors. You must specify it explicitly or constrain it by context.
/// A compile time error will be raised if `LANES + LANES2 != OUTPUT_LANES`.
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd::Simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd::Simd;
/// let x = Simd::from_array([0, 1, 2, 3]);
/// let y = Simd::from_array([4, 5, 6, 7]);
/// let z = x.concat(y);
/// assert_eq!(z.to_array(), [0, 1, 2, 3, 4, 5, 6, 7]);
/// ```
///
/// Will be rejected at compile time if `LANES * 2 != DOUBLE_LANES`.
reinerp marked this conversation as resolved.
Show resolved Hide resolved
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
// TODO: when `generic_const_exprs` supports it, change signature to
// `pub fn concat<const LANES2>(self, other: Simd<T, LANES2>) -> Simd<T, {LANES + LANES2}>`
pub fn concat<const OUTPUT_LANES: usize, const LANES2: usize>(
self,
other: Simd<T, LANES2>,
) -> Simd<T, OUTPUT_LANES>
where
LaneCount<OUTPUT_LANES>: SupportedLaneCount,
LaneCount<LANES2>: SupportedLaneCount,
{
struct Extend;
impl<const I: usize, const O: usize> Swizzle<I, O> for Extend {
const INDEX: [usize; O] = {
assert!(I <= O);
let mut index = [0; O];
let mut i = 0;
while i < I {
index[i] = i;
i += 1;
}
index
};
}
struct Concat<const A: usize, const B: usize, const Y: usize>;
impl<const A: usize, const B: usize, const Y: usize> Swizzle2<Y, Y> for Concat<A, B, Y> {
const INDEX: [Which; Y] = {
assert!(
A + B == Y,
"concat: OUTPUT_LANES must be the sum of all input lane counts"
);
let mut retval = [Which::First(0); Y];
let mut i = 0;
while i < Y {
if i < A {
retval[i] = Which::First(i);
} else {
retval[i] = Which::Second(i - A);
}
i += 1;
}
retval
};
}
Concat::<LANES, LANES2, OUTPUT_LANES>::swizzle2(
Extend::swizzle(self),
Extend::swizzle(other),
)
}

/// For each lane `i`, swaps it with lane `i ^ SWAP_MASK`.
///
/// This is a powerful swizzle operation that can implement many common patterns as special cases.
/// For power-of-2 swap masks, this produces the [butterfly shuffles](https://en.wikipedia.org/wiki/Butterfly_network)
/// that are often useful for horizontal reductions.
///
/// A similar operation (operating on bits instead of lanes) is known as `grev` in the RISC-V
/// Bitmanip specification.
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd::Simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd::Simd;
/// let x = Simd::from_array([0, 1, 2, 3, 4, 5, 6, 7]);
/// // Swap adjacent lanes:
/// assert_eq!(x.general_reverse::<1>().to_array(), [1, 0, 3, 2, 5, 4, 7, 6]);
/// // Swap lanes separated by distance 2:
/// assert_eq!(x.general_reverse::<2>().to_array(), [2, 3, 0, 1, 6, 7, 4, 5]);
/// // Swap lanes separated by distance 4:
/// assert_eq!(x.general_reverse::<4>().to_array(), [4, 5, 6, 7, 0, 1, 2, 3]);
/// // Reverse lanes, within each 4-lane group:
/// assert_eq!(x.general_reverse::<3>().to_array(), [3, 2, 1, 0, 7, 6, 5, 4]);
/// ```
///
/// Commonly useful for horizontal reductions, for example:
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd::Simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd::Simd;
/// let x = Simd::from_array([0u32, 1, 2, 3, 4, 5, 6, 7]);
/// let x = x + x.general_reverse::<1>();
/// let x = x + x.general_reverse::<2>();
/// let x = x + x.general_reverse::<4>();
/// assert_eq!(x.to_array(), [28, 28, 28, 28, 28, 28, 28, 28]);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
#[doc(alias = "grev")]
#[doc(alias = "butterfly")]
#[doc(alias = "bfly")]
pub fn general_reverse<const SWAP_MASK: usize>(self) -> Self {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this should include something like lanewise in it's name, since I expect Rust to gain a bitwise grev at some point and we'd want the bitwise simd and scalar integer operations to have matching names (probably gen_rev or generalized_reverse or grev?).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For something that effectively implements a very specific xor-striding pattern, general_reverse is a non-descript name.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, grev is what that particular bitwise op has become named, thanks to RISC-V afaict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grev -- general bit reverse

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just because an ISA has picked a terrible name doesn't mean we need to copy it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the precedent of u32::count_ones, it looks like the Rust standard library takes the convention of providing a more meaningful name (count_ones) while maintaining searchability for the common name popcnt using #[doc(alias = "popcnt")], https://github.com/rust-lang/rust/blob/f1b1ed7e18f1fbe5226a96626827c625985f8285/library/core/src/num/int_macros.rs#L104. I think such an approach could be warranted here too, keeping grev as an alias. Indeed, grev is a much less established term than popcnt, so the precedent from grev is weaker.

Other names I considered:

  • butterfly_shuffle. The idea here is that, in the common case of SWAP_MASK being a power of 2, it implements one stage of a butterfly network. https://en.wikipedia.org/wiki/Butterfly_network
  • swap_lanes_xor. The "swap lanes" part is pretty self-explanatory. The "xor" part is confusing, however: it suggests that the data bits are being xored, whereas it's actually the lane indices that are being xored.

Currently I lean towards swap_lanes_xor. Open to other suggestions!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

swizzle_to_xor_indexes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverse_pow2_lane_groups? Conceptually this operation performs the reversal of blocks of k lanes within n-lane groups, where k and n are both powers of two, k ≤ n ≤ LANES, and the choice of k and n is determined uniquely up to operation uniqueness by choosing where index 0 will be swizzled to (it’ll be exactly SWAP_MASK).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverse_pow2_lane_groups?

afaict grev is actually more powerful than that, it can do any arbitrary combination of those k-n reversals for arbitrary k and n.

e.g. grev(v, 5) is equivalent to simd_swizzle!(v, [5, 4, 7, 6, 1, 0, 3, 2]) which is a combination of reversing adjacent pairs (blocks of 1) and swapping blocks of 4.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks all for the suggestions. Of the proposals so far, my preference order is:

  1. swizzle_to_xor_indices
  2. butterfly_swizzle
  3. grev

I have gone with swizzle_to_xor_indices. Let me know what you think.

const fn general_reverse_index<const LANES: usize>(swap_mask: usize) -> [usize; LANES] {
let mut index = [0; LANES];
let mut i = 0;
while i < LANES {
index[i] = i ^ swap_mask;
i += 1;
}
index
}
struct GeneralReverse<const DISTANCE: usize>;
impl<const LANES: usize, const DISTANCE: usize> Swizzle<LANES, LANES> for GeneralReverse<DISTANCE> {
const INDEX: [usize; LANES] = general_reverse_index::<LANES>(DISTANCE);
}
GeneralReverse::<SWAP_MASK>::swizzle(self)
}
}
63 changes: 63 additions & 0 deletions crates/core_simd/tests/swizzle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,66 @@ fn interleave_one() {
assert_eq!(even, a);
assert_eq!(odd, b);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn slice() {
let a = Simd::from_array([0, 1, 2, 3, 4, 5, 6, 7]);
let [lo, hi] = a.split();
assert_eq!(lo.to_array(), [0, 1, 2, 3]);
assert_eq!(hi.to_array(), [4, 5, 6, 7]);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn concat_equal_width() {
let x = Simd::from_array([0, 1, 2, 3]);
let y = Simd::from_array([4, 5, 6, 7]);
let z = x.concat(y);
assert_eq!(z.to_array(), [0, 1, 2, 3, 4, 5, 6, 7]);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn concat_different_width() {
reinerp marked this conversation as resolved.
Show resolved Hide resolved
let x = Simd::from_array([0, 1, 2, 3]);
let y = Simd::from_array([4, 5]);
let z = x.concat(y);
assert_eq!(z.to_array(), [0, 1, 2, 3, 4, 5]);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn general_reverse() {
let x = Simd::from_array([0, 1, 2, 3, 4, 5, 6, 7]);
// Swap adjacent lanes:
assert_eq!(
x.general_reverse::<1>().to_array(),
[1, 0, 3, 2, 5, 4, 7, 6]
);
// Swap lanes separated by distance 2:
assert_eq!(
x.general_reverse::<2>().to_array(),
[2, 3, 0, 1, 6, 7, 4, 5]
);
// Swap lanes separated by distance 4:
assert_eq!(
x.general_reverse::<4>().to_array(),
[4, 5, 6, 7, 0, 1, 2, 3]
);
// Reverse lanes, within each 4-lane group:
assert_eq!(
x.general_reverse::<3>().to_array(),
[3, 2, 1, 0, 7, 6, 5, 4]
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn general_reverse_sum() {
let x = Simd::from_array([0u32, 1, 2, 3, 4, 5, 6, 7]);
let x = x + x.general_reverse::<1>();
let x = x + x.general_reverse::<2>();
let x = x + x.general_reverse::<4>();
assert_eq!(x.to_array(), [28, 28, 28, 28, 28, 28, 28, 28]);
}