Skip to content

Commit

Permalink
Re-implement SYCL backend parallel_for to improve bandwidth utilization (#1976)
Browse files Browse the repository at this point in the history

Signed-off-by: Matthew Michel <[email protected]>
  • Loading branch information
mmichel11 authored Jan 31, 2025
1 parent 7bbaf83 commit 83c3741
Show file tree
Hide file tree
Showing 35 changed files with 1,376 additions and 217 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace oneapi::dpl::experimental::kt::gpu::esimd::__impl
{

//------------------------------------------------------------------------
// Please see the comment for __parallel_for_submitter for optional kernel name explanation
// Please see the comment above __parallel_for_small_submitter for optional kernel name explanation
//------------------------------------------------------------------------

template <bool __is_ascending, ::std::uint8_t __radix_bits, ::std::uint16_t __data_per_work_item,
Expand Down
18 changes: 12 additions & 6 deletions include/oneapi/dpl/internal/async_impl/async_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ __pattern_walk1_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For

auto __future_obj = oneapi::dpl::__par_backend_hetero::__parallel_for(
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, __buf.all_view());
unseq_backend::walk1_vector_or_scalar<_ExecutionPolicy, _Function, decltype(__buf.all_view())>{
__f, static_cast<std::size_t>(__n)},
__n, __buf.all_view());
return __future_obj;
}

Expand All @@ -67,7 +69,9 @@ __pattern_walk2_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For

auto __future = oneapi::dpl::__par_backend_hetero::__parallel_for(
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, __buf1.all_view(), __buf2.all_view());
unseq_backend::walk2_vectors_or_scalars<_ExecutionPolicy, _Function, decltype(__buf1.all_view()),
decltype(__buf2.all_view())>{__f, static_cast<std::size_t>(__n)},
__n, __buf1.all_view(), __buf2.all_view());

return __future.__make_future(__first2 + __n);
}
Expand All @@ -91,10 +95,12 @@ __pattern_walk3_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _ForwardIterator3>();
auto __buf3 = __keep3(__first3, __first3 + __n);

auto __future =
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
__buf1.all_view(), __buf2.all_view(), __buf3.all_view());
auto __future = oneapi::dpl::__par_backend_hetero::__parallel_for(
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk3_vectors_or_scalars<_ExecutionPolicy, _Function, decltype(__buf1.all_view()),
decltype(__buf2.all_view()), decltype(__buf3.all_view())>{
__f, static_cast<size_t>(__n)},
__n, __buf1.all_view(), __buf2.all_view(), __buf3.all_view());

return __future.__make_future(__first3 + __n);
}
Expand Down
37 changes: 26 additions & 11 deletions include/oneapi/dpl/internal/binary_search_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,19 @@ enum class search_algorithm
binary_search
};

template <typename Comp, typename T, search_algorithm func>
struct custom_brick
#if _ONEDPL_BACKEND_SYCL
template <typename Comp, typename T, typename _Range, search_algorithm func>
struct __custom_brick : oneapi::dpl::unseq_backend::walk_scalar_base<_Range>
{
Comp comp;
T size;
bool use_32bit_indexing;

// Construct the search brick from the user-provided comparator, the size of
// the sequence being searched, and a runtime flag selecting 32-bit vs 64-bit
// index arithmetic (the caller sets the flag when the size fits in uint32_t).
__custom_brick(Comp comp, T size, bool use_32bit_indexing)
: comp(std::move(comp)), size(size), use_32bit_indexing(use_32bit_indexing)
{
}

template <typename _Size, typename _ItemId, typename _Acc>
void
search_impl(_ItemId idx, _Acc acc) const
Expand All @@ -68,17 +74,23 @@ struct custom_brick
get<2>(acc[idx]) = (value != end_orig) && (get<1>(acc[idx]) == get<0>(acc[value]));
}
}

template <typename _ItemId, typename _Acc>
template <typename _IsFull, typename _ItemId, typename _Acc>
void
operator()(_ItemId idx, _Acc acc) const
__scalar_path_impl(_IsFull, _ItemId idx, _Acc acc) const
{
if (use_32bit_indexing)
search_impl<std::uint32_t>(idx, acc);
else
search_impl<std::uint64_t>(idx, acc);
}
// Entry point invoked by the parallel_for backend; this brick defines only a
// scalar path (it derives from walk_scalar_base), so the call is forwarded
// unconditionally to __scalar_path_impl regardless of _IsFull.
template <typename _IsFull, typename _ItemId, typename _Acc>
void
operator()(_IsFull __is_full, _ItemId idx, _Acc acc) const
{
__scalar_path_impl(__is_full, idx, acc);
}
};
#endif

template <class _Tag, typename Policy, typename InputIterator1, typename InputIterator2, typename OutputIterator,
typename StrictWeakOrdering>
Expand Down Expand Up @@ -155,7 +167,8 @@ lower_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
__bknd::__parallel_for(
_BackendTag{}, ::std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::lower_bound>{comp, size, use_32bit_indexing},
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::lower_bound>{
comp, size, use_32bit_indexing},
value_size, zip_vw)
.__deferrable_wait();
return result + value_size;
Expand Down Expand Up @@ -187,7 +200,8 @@ upper_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
__bknd::__parallel_for(
_BackendTag{}, std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::upper_bound>{comp, size, use_32bit_indexing},
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::upper_bound>{
comp, size, use_32bit_indexing},
value_size, zip_vw)
.__deferrable_wait();
return result + value_size;
Expand Down Expand Up @@ -217,10 +231,11 @@ binary_search_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, Input
auto result_buf = keep_result(result, result + value_size);
auto zip_vw = make_zip_view(input_buf.all_view(), value_buf.all_view(), result_buf.all_view());
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
__bknd::__parallel_for(_BackendTag{}, std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::binary_search>{
comp, size, use_32bit_indexing},
value_size, zip_vw)
__bknd::__parallel_for(
_BackendTag{}, std::forward<decltype(policy)>(policy),
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::binary_search>{
comp, size, use_32bit_indexing},
value_size, zip_vw)
.__deferrable_wait();
return result + value_size;
}
Expand Down
Loading

0 comments on commit 83c3741

Please sign in to comment.