diff --git a/onemath/sycl/blas/include/blas_meta.h b/onemath/sycl/blas/include/blas_meta.h index 9e813da..03b19c6 100644 --- a/onemath/sycl/blas/include/blas_meta.h +++ b/onemath/sycl/blas/include/blas_meta.h @@ -30,10 +30,10 @@ #ifdef BLAS_ENABLE_COMPLEX #define SYCL_EXT_ONEAPI_COMPLEX #include -#if __has_include() -#include +#if __has_include() +#include #else -#include +#include #endif #endif diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp index 61977df..291fc16 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp @@ -116,12 +116,24 @@ class vec_complex { m_Data = *(Ptr + Offset * NumElements); } + // Load + template + void load(size_t Offset, const DataT *Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + // Store template void store(size_t Offset, sycl::multi_ptr Ptr) const { *(Ptr + Offset * NumElements) = m_Data; } + + // Store + template + void store(size_t Offset, DataT *Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } }; /*! @brief Partial specialization of the Packetize class dedicated to diff --git a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp index c4c2165..356da25 100644 --- a/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp +++ b/onemath/sycl/blas/src/operations/blas3/gemm_local.hpp @@ -527,12 +527,31 @@ class Gemm( - 0, sycl::multi_ptr(reg)); - out_vec *= alpha_; - - out_vec.template store( - 0, sycl::multi_ptr(out_ptr)); + // This if-statement is necessary starting from late 2024 nightly, because + // an update made casting raw pointers of sycl::complex to multi_ptr + // ambiguous. + if constexpr (std::is_same_v< + element_t, + sycl::ext::oneapi::experimental::complex> || + std::is_same_v< + element_t, + sycl::ext::oneapi::experimental::complex>) { + out_vec.template load(0, reg); + out_vec *= alpha_; + + out_vec.template store(0, out_ptr); + } else { + out_vec.template load( + 0, sycl::multi_ptr(reg)); + out_vec *= alpha_; + + out_vec.template store( + 0, sycl::multi_ptr(out_ptr)); + } } /*! * @brief Store the computed gemm result to the C matrix