Skip to content

Commit

Permalink
Merge pull request #4 from s-Nick/fix_header_and_complex_issue
Browse files Browse the repository at this point in the history
[BLAS] Fix complex header inclusion and multi_ptr cast
  • Loading branch information
s-Nick authored Jan 14, 2025
2 parents e44e2b6 + f367f92 commit 6411407
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
6 changes: 3 additions & 3 deletions onemath/sycl/blas/include/blas_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
#ifdef BLAS_ENABLE_COMPLEX
#define SYCL_EXT_ONEAPI_COMPLEX
#include <complex>
#if __has_include(<ext/oneapi/experimental/complex/complex.hpp>)
#include <ext/oneapi/experimental/complex/complex.hpp>
#if __has_include(<sycl/ext/oneapi/experimental/complex/complex.hpp>)
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
#else
#include <ext/oneapi/experimental/sycl_complex.hpp>
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
#endif
#endif

Expand Down
12 changes: 12 additions & 0 deletions onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,24 @@ class vec_complex {
m_Data = *(Ptr + Offset * NumElements);
}

/*!
 * @brief Load from a raw pointer: copies the element found at index
 * Offset * NumElements (relative to Ptr) into m_Data.
 * @tparam Space Not used by this overload; presumably kept so call sites can
 * pass the same explicit template arguments as for the multi_ptr overload.
 * @tparam DecorateAddress Not used by this overload (see Space).
 * @param Offset Index that is scaled by NumElements before being applied.
 * @param Ptr Raw pointer to read from.
 */
template <address_t Space, decorated_t DecorateAddress>
void load(size_t Offset, const DataT *Ptr) {
  m_Data = Ptr[Offset * NumElements];
}

/*!
 * @brief Store through a sycl::multi_ptr: writes m_Data to the location
 * Ptr + Offset * NumElements.
 * @tparam Space Address space of the destination multi_ptr.
 * @tparam DecorateAddress Decoration mode of the destination multi_ptr.
 * @param Offset Index that is scaled by NumElements before being applied.
 * @param Ptr Destination multi_ptr.
 */
template <address_t Space, decorated_t DecorateAddress>
void store(size_t Offset,
           sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
  // Scale the logical offset into an element index first.
  const auto ElemIndex = Offset * NumElements;
  *(Ptr + ElemIndex) = m_Data;
}

/*!
 * @brief Store through a raw pointer: writes m_Data to the element at index
 * Offset * NumElements (relative to Ptr).
 * @tparam Space Not used by this overload; presumably kept so call sites can
 * pass the same explicit template arguments as for the multi_ptr overload.
 * @tparam DecorateAddress Not used by this overload (see Space).
 * @param Offset Index that is scaled by NumElements before being applied.
 * @param Ptr Raw pointer to write to.
 */
template <address_t Space, decorated_t DecorateAddress>
void store(size_t Offset, DataT *Ptr) const {
  Ptr[Offset * NumElements] = m_Data;
}
};

/*! @brief Partial specialization of the Packetize class dedicated to
Expand Down
31 changes: 25 additions & 6 deletions onemath/sycl/blas/src/operations/blas3/gemm_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,31 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
element_t *reg, OutputPointerType out_ptr) {
vector_out_t out_vec{};

out_vec.template load<address_t::private_space>(
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
out_vec *= alpha_;

out_vec.template store<address_t::global_space>(
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
// This if-statement has been necessary since the late-2024 nightly builds,
// where an update made converting raw sycl::complex pointers to multi_ptr
// ambiguous.
if constexpr (std::is_same_v<
element_t,
sycl::ext::oneapi::experimental::complex<float>> ||
std::is_same_v<
element_t,
sycl::ext::oneapi::experimental::complex<double>>) {
out_vec.template load<address_t::private_space,
sycl::access::decorated::legacy>(0, reg);
out_vec *= alpha_;

out_vec.template store<address_t::global_space,
sycl::access::decorated::legacy>(0, out_ptr);
} else {
out_vec.template load<address_t::private_space,
sycl::access::decorated::legacy>(
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
out_vec *= alpha_;

out_vec.template store<address_t::global_space,
sycl::access::decorated::legacy>(
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
}
}
/*!
* @brief Store the computed gemm result to the C matrix
Expand Down

0 comments on commit 6411407

Please sign in to comment.