Skip to content

Commit

Permalink
Improve the implementation of __enumerable_thread_local_storage (#2049)
Browse files Browse the repository at this point in the history
Implement with CRTP.
Store construction arguments by value.
Replace unique_ptr with optional.

Co-authored-by: Dan Hoeflinger <[email protected]>
  • Loading branch information
akukanov and danhoeflinger authored Feb 6, 2025
1 parent b21e7ca commit e9c8b1e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 53 deletions.
44 changes: 27 additions & 17 deletions include/oneapi/dpl/pstl/omp/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,36 +152,46 @@ __process_chunk(const __chunk_metrics& __metrics, _Iterator __base, _Index __chu

namespace __detail
{
struct __get_num_threads

// Workaround for VS 2017: declare an alias to the CRTP base template
// The backend's storage template is forward-declared first so that the alias below can name it
// before its definition.
template <typename _ValueType, typename... _Args>
struct __enumerable_thread_local_storage;

// __etls_base<_Ts...> is __utils::__enumerable_thread_local_storage_base with this backend's
// __enumerable_thread_local_storage passed as the concrete (derived) class template.
template <typename... _Ts>
using __etls_base = __utils::__enumerable_thread_local_storage_base<__enumerable_thread_local_storage, _Ts...>;

template <typename _ValueType, typename... _Args>
struct __enumerable_thread_local_storage : public __etls_base<_ValueType, _Args...>
{
std::size_t
operator()() const

// Forwards the construction arguments to the base class as a single braced initializer list;
// the list initializes the base constructor's std::tuple<_Args...> parameter.
// NOTE(review): an unconstrained forwarding constructor can be selected instead of the copy constructor
// for a non-const lvalue of this type - presumably objects are only created through
// __make_enumerable_tls; confirm.
template <typename... _LocalArgs>
__enumerable_thread_local_storage(_LocalArgs&&... __args)
: __etls_base<_ValueType, _Args...>({std::forward<_LocalArgs>(__args)...})
{
}

// Thread count used by the base class to size its per-thread slot vector:
// omp_get_num_threads() when omp_in_parallel() reports an active parallel region,
// omp_get_max_threads() otherwise.
static std::size_t
get_num_threads()
{
    if (omp_in_parallel())
    {
        return omp_get_num_threads();
    }
    return omp_get_max_threads();
}
};

struct __get_thread_num
{
std::size_t
operator()() const
// Slot index for the calling thread: the value of omp_get_thread_num() converted to std::size_t.
static std::size_t
get_thread_num()
{
    const auto __thread_id = omp_get_thread_num();
    return static_cast<std::size_t>(__thread_id);
}
};

} // namespace __detail

// enumerable thread local storage should only be created from make function
template <typename _ValueType, typename... Args>
oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
_ValueType, oneapi::dpl::__omp_backend::__detail::__get_num_threads,
oneapi::dpl::__omp_backend::__detail::__get_thread_num, Args...>
__make_enumerable_tls(Args&&... __args)
// enumerable thread local storage should only be created with this make function
// Builds a __detail::__enumerable_thread_local_storage for _ValueType. The storage is parameterized on the
// argument types with references stripped (std::remove_reference_t), and the arguments themselves are
// perfectly forwarded to its constructor.
// NOTE(review): the first return statement in the body names `Args`, which the template header here does
// not declare (it declares `_Args`); it looks like a pre-change line carried over by the diff view - confirm.
template <typename _ValueType, typename... _Args>
__detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>
__make_enumerable_tls(_Args&&... __args)
{
return oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
_ValueType, oneapi::dpl::__omp_backend::__detail::__get_num_threads,
oneapi::dpl::__omp_backend::__detail::__get_thread_num, Args...>(std::forward<Args>(__args)...);
return __detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>(
std::forward<_Args>(__args)...);
}

} // namespace __omp_backend
Expand Down
49 changes: 30 additions & 19 deletions include/oneapi/dpl/pstl/parallel_backend_tbb.h
Original file line number Diff line number Diff line change
Expand Up @@ -1308,35 +1308,46 @@ __parallel_for_each(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy

namespace __detail
{
struct __get_num_threads

// Workaround for VS 2017: declare an alias to the CRTP base template
// The backend's storage template is forward-declared first so that the alias below can name it
// before its definition.
template <typename _ValueType, typename... _Args>
struct __enumerable_thread_local_storage;

// __etls_base<_Ts...> is __utils::__enumerable_thread_local_storage_base with this backend's
// __enumerable_thread_local_storage passed as the concrete (derived) class template.
template <typename... _Ts>
using __etls_base = __utils::__enumerable_thread_local_storage_base<__enumerable_thread_local_storage, _Ts...>;

template <typename _ValueType, typename... _Args>
struct __enumerable_thread_local_storage : public __etls_base<_ValueType, _Args...>
{
std::size_t
operator()() const

// Forwards the construction arguments to the base class as a single braced initializer list;
// the list initializes the base constructor's std::tuple<_Args...> parameter.
// NOTE(review): an unconstrained forwarding constructor can be selected instead of the copy constructor
// for a non-const lvalue of this type - presumably objects are only created through
// __make_enumerable_tls; confirm.
template <typename... _LocalArgs>
__enumerable_thread_local_storage(_LocalArgs&&... __args)
: __etls_base<_ValueType, _Args...>({std::forward<_LocalArgs>(__args)...})
{
}

// Thread count used by the base class to size its per-thread slot vector:
// the value of tbb::this_task_arena::max_concurrency() converted to std::size_t.
static std::size_t
get_num_threads()
{
    const auto __concurrency = tbb::this_task_arena::max_concurrency();
    return static_cast<std::size_t>(__concurrency);
}
};

struct __get_thread_num
{
std::size_t
operator()() const
// Slot index for the calling thread: the value of tbb::this_task_arena::current_thread_index()
// converted to std::size_t.
// NOTE(review): oneTBB documents a negative "not initialized" result for a thread that is outside any
// arena; converted to std::size_t that would become a very large index - presumably this is only
// called from within arena tasks; confirm.
static std::size_t
get_thread_num()
{
    const auto __slot_index = tbb::this_task_arena::current_thread_index();
    return static_cast<std::size_t>(__slot_index);
}
};
} //namespace __detail

// enumerable thread local storage should only be created from make function
template <typename _ValueType, typename... Args>
oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
_ValueType, oneapi::dpl::__tbb_backend::__detail::__get_num_threads,
oneapi::dpl::__tbb_backend::__detail::__get_thread_num, Args...>
__make_enumerable_tls(Args&&... __args)

} // namespace __detail

// enumerable thread local storage should only be created with this make function
// Builds a __detail::__enumerable_thread_local_storage for _ValueType. The storage is parameterized on the
// argument types with references stripped (std::remove_reference_t), and the arguments themselves are
// perfectly forwarded to its constructor.
// NOTE(review): the first return statement in the body names `Args`, which the template header here does
// not declare (it declares `_Args`); it looks like a pre-change line carried over by the diff view - confirm.
template <typename _ValueType, typename... _Args>
__detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>
__make_enumerable_tls(_Args&&... __args)
{
return oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
_ValueType, oneapi::dpl::__tbb_backend::__detail::__get_num_threads,
oneapi::dpl::__tbb_backend::__detail::__get_thread_num, Args...>(std::forward<Args>(__args)...);
return __detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>(
std::forward<_Args>(__args)...);
}

} // namespace __tbb_backend
Expand Down
29 changes: 12 additions & 17 deletions include/oneapi/dpl/pstl/parallel_backend_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <atomic>
#include <cstddef>
#include <iterator>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -306,16 +306,13 @@ __set_symmetric_difference_construct(_ForwardIterator1 __first1, _ForwardIterato
return __cc_range(__first2, __last2, __result);
}

namespace __detail
template <template <typename, typename...> typename _Concrete, typename _ValueType, typename... _Args>
struct __enumerable_thread_local_storage_base
{
using _Derived = _Concrete<_ValueType, _Args...>;

template <typename _ValueType, typename _GetNumThreads, typename _GetThreadNum, typename... _Args>
struct __enumerable_thread_local_storage
{

template <typename... _LocalArgs>
__enumerable_thread_local_storage(_LocalArgs&&... __args)
: __thread_specific_storage(_GetNumThreads{}()), __num_elements(0), __args(std::forward<_LocalArgs>(__args)...)
// Sizes the per-thread slot vector using the derived class' static get_num_threads(), starts the
// populated-slot counter at zero, and keeps the construction arguments for later use.
// __tp is taken by value, so it is moved (not copied a second time) into the stored tuple.
__enumerable_thread_local_storage_base(std::tuple<_Args...> __tp)
    : __thread_specific_storage(_Derived::get_num_threads()), __num_elements(0), __args(std::move(__tp))
{
}

Expand Down Expand Up @@ -359,24 +356,22 @@ struct __enumerable_thread_local_storage
// Returns a reference to the calling thread's element, creating it on first use.
// The slot is selected by the thread index; if the slot is empty, the element is constructed from the
// stored __args and __num_elements is incremented (relaxed ordering).
// NOTE(review): this listing carries two variants of several statements (a unique_ptr-based one and an
// optional-based one, e.g. two declarations of __i and two returns); they look like before/after lines
// of a diff - confirm which set is live.
// NOTE(review): no bounds check on __i is visible here - assumes the thread index is always smaller
// than the size chosen at construction; confirm.
_ValueType&
get_for_current_thread()
{
const std::size_t __i = _GetThreadNum{}();
std::unique_ptr<_ValueType>& __thread_local_storage = __thread_specific_storage[__i];
if (!__thread_local_storage)
const std::size_t __i = _Derived::get_thread_num();
std::optional<_ValueType>& __local = __thread_specific_storage[__i];
if (!__local)
{
// create temporary storage on first usage to avoid extra parallel region and unnecessary instantiation
__thread_local_storage =
std::apply([](_Args... __arg_pack) { return std::make_unique<_ValueType>(__arg_pack...); }, __args);
// NOTE(review): the lambda takes _Args... by value, so each stored argument is copied once before
// emplace() sees it; `const _Args&...` would likely avoid that copy - confirm.
std::apply([&__local](_Args... __arg_pack) { __local.emplace(__arg_pack...); }, __args);
__num_elements.fetch_add(1, std::memory_order_relaxed);
}
return *__thread_local_storage;
return *__local;
}

std::vector<std::unique_ptr<_ValueType>> __thread_specific_storage;
// One slot per thread index, sized from the derived class' get_num_threads() at construction;
// a slot stays empty until get_for_current_thread() populates it.
std::vector<std::optional<_ValueType>> __thread_specific_storage;
// Number of slots populated so far; incremented when get_for_current_thread() creates an element.
std::atomic_size_t __num_elements;
// Construction arguments, stored in a tuple and applied to emplace() when a slot is first populated.
const std::tuple<_Args...> __args;
};

} // namespace __detail
} // namespace __utils
} // namespace dpl
} // namespace oneapi
Expand Down

0 comments on commit e9c8b1e

Please sign in to comment.