diff --git a/hi_scripting/scripting/api/ScriptingApiObjects.cpp b/hi_scripting/scripting/api/ScriptingApiObjects.cpp index 69e4863e66..43954c0c07 100644 --- a/hi_scripting/scripting/api/ScriptingApiObjects.cpp +++ b/hi_scripting/scripting/api/ScriptingApiObjects.cpp @@ -5057,6 +5057,7 @@ struct ScriptingObjects::ScriptNeuralNetwork::Wrapper API_METHOD_WRAPPER_0(ScriptNeuralNetwork, getModelJSON); API_VOID_METHOD_WRAPPER_1(ScriptNeuralNetwork, loadTensorFlowModel); API_VOID_METHOD_WRAPPER_1(ScriptNeuralNetwork, loadPytorchModel); + API_VOID_METHOD_WRAPPER_1(ScriptNeuralNetwork, loadNAMModel); API_METHOD_WRAPPER_1(ScriptNeuralNetwork, createModelJSONFromTextFile); API_METHOD_WRAPPER_2(ScriptNeuralNetwork, loadOnnxModel); API_METHOD_WRAPPER_3(ScriptNeuralNetwork, processFFTSpectrum); @@ -5073,6 +5074,7 @@ ScriptingObjects::ScriptNeuralNetwork::ScriptNeuralNetwork(ProcessorWithScriptin ADD_API_METHOD_1(createModelJSONFromTextFile); ADD_API_METHOD_1(loadTensorFlowModel); ADD_API_METHOD_1(loadPytorchModel); + ADD_API_METHOD_1(loadNAMModel); ADD_API_METHOD_0(getModelJSON); ADD_API_METHOD_2(loadOnnxModel); ADD_API_METHOD_3(processFFTSpectrum); @@ -5296,6 +5298,16 @@ void ScriptingObjects::ScriptNeuralNetwork::loadPytorchModel(const var& modelJSO #endif } +void ScriptingObjects::ScriptNeuralNetwork::loadNAMModel(const var& modelJSON) +{ +#if HISE_INCLUDE_RT_NEURAL + nn->loadNAMModel(modelJSON); + postBuild(); +#else + reportScriptError("You must enable HISE_INCLUDE_RT_NEURAL"); +#endif +} + bool ScriptingObjects::ScriptNeuralNetwork::loadOnnxModel(const var& base64Data, int numOutputs) { if(onnx == nullptr) diff --git a/hi_scripting/scripting/api/ScriptingApiObjects.h b/hi_scripting/scripting/api/ScriptingApiObjects.h index fd5809310a..f230e94a64 100644 --- a/hi_scripting/scripting/api/ScriptingApiObjects.h +++ b/hi_scripting/scripting/api/ScriptingApiObjects.h @@ -1606,6 +1606,9 @@ namespace ScriptingObjects /** Loads the model layout and weights from a Pytorch model JSON. 
*/ void loadPytorchModel(const var& modelJSON); + /** Loads the model from a NAM file. */ + void loadNAMModel(const var& modelJSON); + /** Loads the ONNX runtime model for spectral analysis. */ bool loadOnnxModel(const var& base64Data, int numOutputValues); diff --git a/hi_scripting/scripting/scriptnode/ui/NodeComponent.cpp b/hi_scripting/scripting/scriptnode/ui/NodeComponent.cpp index 9cda75a8be..86fe2ee7e3 100644 --- a/hi_scripting/scripting/scriptnode/ui/NodeComponent.cpp +++ b/hi_scripting/scripting/scriptnode/ui/NodeComponent.cpp @@ -662,7 +662,7 @@ void NodeComponent::handlePopupMenuResult(int result) if (wType == 2) { - auto id = node->getId(); + auto name = snex::cppgen::Helpers::getValidCppVariableName(node->getName()); struct ConnectionState { @@ -724,23 +724,23 @@ void NodeComponent::handlePopupMenuResult(int result) c.removeOldConnection(node.get()); } - if (id == node->getPath().getIdentifier().toString()) + if (name == node->getPath().getIdentifier().toString()) { - id = PresetHandler::getCustomName(id, "Enter a customized name for the node"); + name = PresetHandler::getCustomName(name, "Enter a customized name for the node"); } - String newId = id + "_"; + String newId = name + "_"; node->setValueTreeProperty(PropertyIds::ID, newId); - PopupHelpers::wrapIntoChain(node.get(), MenuActions::WrapIntoChain, id); + PopupHelpers::wrapIntoChain(node.get(), MenuActions::WrapIntoChain, name); auto pn = node->getParentNode(); pn->getValueTree().setProperty(PropertyIds::ShowParameters, true, node->getUndoManager()); if (auto modNode = dynamic_cast(node.get())) { - String pmodId = id + "_pm"; + String pmodId = name + "_pm"; var pmodvar = node->getRootNetwork()->create("routing.public_mod", pmodId); auto pmod = dynamic_cast(pmodvar.getObject()); @@ -1035,7 +1035,9 @@ void NodeComponent::PopupHelpers::wrapIntoNetwork(NodeBase* node, bool makeCompi for (int i = 0; i < rootTree.getNumProperties(); i++) nData.setProperty(rootTree.getPropertyName(i), 
rootTree.getProperty(rootTree.getPropertyName(i)), nullptr); - nData.setProperty(PropertyIds::ID, node->getId(), nullptr); + auto name = snex::cppgen::Helpers::getValidCppVariableName(node->getName()); + + nData.setProperty(PropertyIds::ID, name, nullptr); nData.addChild(node->getValueTree().createCopy(), -1, nullptr); auto ndir = BackendDllManager::getSubFolder(node->getScriptProcessor()->getMainController_(), BackendDllManager::FolderSubType::Networks); @@ -1121,15 +1123,8 @@ void NodeComponent::PopupHelpers::wrapIntoChain(NodeBase* node, MenuActions resu auto parent = selection.getFirst()->getValueTree().getParent(); auto nIndex = parent.indexOf(selection.getFirst()->getValueTree()); - - for (auto n : selection) - { n->setParent(newContainer, -1); - - //n->getValueTree().getParent().removeChild(n->getValueTree(), um); - //containerTree.getChildWithName(PropertyIds::Nodes).addChild(n->getValueTree(), -1, um); - } parent.addChild(containerTree, nIndex, um); } diff --git a/hi_tools/hi_neural/RTNeural/RTNeural/conv1d/strided_conv1d.h b/hi_tools/hi_neural/RTNeural/RTNeural/conv1d/strided_conv1d.h new file mode 100644 index 0000000000..72b2f1010b --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/RTNeural/conv1d/strided_conv1d.h @@ -0,0 +1,210 @@ +#pragma once + +#include "conv1d.h" + +namespace RTNEURAL_NAMESPACE +{ +/** + * Dynamic implementation of a 1-dimensional convolutional layer + * with strides. + * + * Internally, this is just a wrapper around the Conv1D layer. + */ +template +class StridedConv1D final : public Layer +{ +public: + /** + * Constructs a strided convolution layer for the given dimensions. 
+ * + * @param in_size: the input size for the layer + * @param out_size: the output size for the layer + * @param kernel_size: the size of the convolution kernel + * @param dilation: the dilation rate to use for dilated convolution + * @param stride: the stride of the convolution + */ + StridedConv1D(int in_size, int out_size, int kernel_size, int dilation, int stride, int groups = 1) + : Layer(in_size, out_size) + , internal(in_size, out_size, kernel_size, dilation, groups) + , stride(stride) + { + skip_output.resize(out_size, T {}); + } + + StridedConv1D(std::initializer_list sizes) + : StridedConv1D(*sizes.begin(), *(sizes.begin() + 1), *(sizes.begin() + 2), + *(sizes.begin() + 3), *(sizes.begin() + 4), *(sizes.begin() + 5)) + { + } + + StridedConv1D(const StridedConv1D& other) = default; + StridedConv1D& operator=(const StridedConv1D& other) = default; + + /** Resets the layer state. */ + RTNEURAL_REALTIME void reset() override + { + strides_counter = 0; + std::fill(std::begin(skip_output), std::end(skip_output), T {}); + internal.reset(); + } + + /** Returns the name of this layer. */ + std::string getName() const noexcept override { return "strided_conv1d"; } + + /** Performs a stride step for this layer. */ + RTNEURAL_REALTIME inline void skip(const T* input) + { + internal.skip(input); + } + + /** Performs forward propagation for this layer. */ + RTNEURAL_REALTIME inline void forward(const T* input, T* h) noexcept override + { + if(strides_counter == 0) + { + internal.forward(input, h); + std::copy(h, h + Layer::out_size, std::begin(skip_output)); + } + else + { + internal.skip(input); + std::copy(std::begin(skip_output), std::end(skip_output), h); + } + + strides_counter = (strides_counter == stride - 1) ? 0 : strides_counter + 1; + } + + /** + * Sets the layer weights. 
+ * + * The weights vector must have size weights[out_size][in_size][kernel_size * dilation] + */ + RTNEURAL_REALTIME void setWeights(const std::vector>>& weights) + { + internal.setWeights(weights); + } + + /** + * Sets the layer biases. + * + * The bias vector must have size bias[out_size] + */ + RTNEURAL_REALTIME void setBias(const std::vector& biasVals) + { + internal.setBias(biasVals); + } + + /** Returns the size of the convolution kernel. */ + RTNEURAL_REALTIME int getKernelSize() const noexcept { return internal.getKernelSize(); } + + /** Returns the convolution dilation rate. */ + RTNEURAL_REALTIME int getDilationRate() const noexcept { return internal.getDilationRate(); } + + /** Returns the number of "groups" in the convolution. */ + int getGroups() const noexcept { return internal.getGroups(); } + +private: + Conv1D internal; + + const int stride; + int strides_counter = 0; + std::vector skip_output {}; +}; + +//==================================================== +/** + * Static implementation of a 1-dimensional convolution layer + * with strides. + * + * Internally, this is just a wrapper around the Conv1DT layer. + * + * @param in_sizet: the input size for the layer + * @param out_sizet: the output size for the layer + * @param kernel_size: the size of the convolution kernel + * @param dilation_rate: the dilation rate to use for dilated convolution + * @param stride: the stride of the convolution + * @param groups: controls connections between inputs and outputs + * @param dynamic_state: use dynamically allocated layer state + */ +template +class StridedConv1DT +{ + Conv1DT internal; + + int strides_counter = 0; + +public: + static constexpr auto in_size = in_sizet; + static constexpr auto out_size = out_sizet; + static constexpr auto filters_per_group = in_size / groups; + static constexpr auto channels_per_group = out_size / groups; + + StridedConv1DT() + : outs(internal.outs) + { + } + + /** Returns the name of this layer. 
*/ + std::string getName() const noexcept { return "strided_conv1d"; } + + /** Returns false since convolution is not an activation layer. */ + constexpr bool isActivation() const noexcept { return false; } + + /** Resets the layer state. */ + RTNEURAL_REALTIME void reset() + { + internal.reset(); + } + + /** Performs a stride step for this layer. */ + template + RTNEURAL_REALTIME inline void skip(const Inputs& ins) noexcept + { + internal.skip(ins); + } + + /** Performs forward propagation for this layer. */ + template + RTNEURAL_REALTIME inline void forward(const Inputs& ins) noexcept + { + if(strides_counter == 0) + internal.forward(ins); + else + internal.skip(ins); + + strides_counter = (strides_counter == stride - 1) ? 0 : strides_counter + 1; + } + + /** + * Sets the layer weights. + * + * The weights vector must have size weights[out_size][group_count][kernel_size * dilation] + */ + RTNEURAL_REALTIME void setWeights(const std::vector>>& weights) + { + internal.setWeights(weights); + } + + /** + * Sets the layer biases. + * + * The bias vector must have size bias[out_size] + */ + RTNEURAL_REALTIME void setBias(const std::vector& biasVals) + { + internal.setBias(biasVals); + } + + /** Returns the size of the convolution kernel. */ + RTNEURAL_REALTIME int getKernelSize() const noexcept { return kernel_size; } + + /** Returns the convolution dilation rate. */ + RTNEURAL_REALTIME int getDilationRate() const noexcept { return dilation_rate; } + + /** Returns the number of "groups" in the convolution. */ + int getGroups() const noexcept { return groups; } + + /** Reference to the internal layer weights. 
*/ + decltype(internal.outs)& outs; +}; +} diff --git a/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer.hpp b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer.hpp new file mode 100644 index 0000000000..d8a9db566f --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer.hpp @@ -0,0 +1,107 @@ +#pragma once + + + +namespace wavenet +{ +template > +// TODO: gated? +struct Wavenet_Layer +{ + RTNeural::Conv1DT conv; + RTNeural::DenseT input_mixin; + RTNeural::DenseT _1x1; + Activation activation; + +#if RTNEURAL_USE_EIGEN + Eigen::Matrix outs; +#elif RTNEURAL_USE_XSIMD + xsimd::batch outs[RTNeural::ceil_div (channels, (int) xsimd::batch::size)]; +#endif + + void reset() + { + conv.reset(); + } + + void load_weights (std::vector::iterator& weights) + { + conv.reset(); + + std::vector>> conv_weights (channels, std::vector> (channels, std::vector (kernel_size))); + for (int i = 0; i < channels; ++i) + for (int j = 0; j < channels; ++j) + for (int k = 0; k < kernel_size; k++) + conv_weights[i][j][k] = *(weights++); + RTNeural::torch_helpers::detail::reverseKernels (conv_weights); + conv.setWeights (conv_weights); + + std::vector conv_bias (channels); + for (int i = 0; i < channels; ++i) + conv_bias[i] = *(weights++); + conv.setBias (conv_bias); + + std::vector> input_mixin_weights (channels, std::vector (condition_size)); + for (int i = 0; i < channels; i++) + for (int j = 0; j < condition_size; j++) + input_mixin_weights[i][j] = *(weights++); + input_mixin.setWeights (input_mixin_weights); + + std::vector> _1x1_weights (channels, std::vector (channels)); + for (int i = 0; i < channels; i++) + for (int j = 0; j < channels; j++) + _1x1_weights[i][j] = *(weights++); + _1x1.setWeights (_1x1_weights); + + std::vector _1x1_bias (channels); + for (int i = 0; i < channels; i++) + _1x1_bias[i] = *(weights++); + _1x1.setBias (_1x1_bias.data()); + } + +#if RTNEURAL_USE_EIGEN + void forward (const Eigen::Matrix& ins, + const Eigen::Matrix& 
condition, + Eigen::Map, RTNeural::RTNeuralEigenAlignment>& head_io) +#elif RTNEURAL_USE_XSIMD + void forward (const xsimd::batch (&ins)[RTNeural::ceil_div (channels, (int) xsimd::batch::size)], + const xsimd::batch (&condition)[RTNeural::ceil_div (condition_size, (int) xsimd::batch::size)], + xsimd::batch (&head_io)[RTNeural::ceil_div (channels, (int) xsimd::batch::size)]) +#endif + { + conv.forward (ins); + input_mixin.forward (condition); + +#if RTNEURAL_USE_EIGEN + outs = conv.outs + input_mixin.outs; +#elif RTNEURAL_USE_XSIMD + for (int i = 0; i < std::size (outs); ++i) + outs[i] = conv.outs[i] + input_mixin.outs[i]; +#endif + + activation.forward (outs); + +#if RTNEURAL_USE_EIGEN + head_io.noalias() += activation.outs; +#elif RTNEURAL_USE_XSIMD + for (int i = 0; i < std::size (head_io); ++i) + head_io[i] += activation.outs[i]; +#endif + + _1x1.forward (activation.outs); + +#if RTNEURAL_USE_EIGEN + outs = ins + _1x1.outs; +#elif RTNEURAL_USE_XSIMD + for (int i = 0; i < std::size (outs); ++i) + outs[i] = ins[i] + _1x1.outs[i]; +#endif + } +}; +} // namespace wavenet diff --git a/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer_array.hpp b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer_array.hpp new file mode 100644 index 0000000000..ea3428bce0 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_layer_array.hpp @@ -0,0 +1,105 @@ +#pragma once + +#include "wavenet_layer.hpp" + +namespace wavenet +{ +template +using Dilations = std::integer_sequence; + +template +struct Layer_Array +{ + template + struct Layers_Helper + { + }; + + template + struct Layers_Helper> + { + using type = std::tuple...>; + }; + + using Layers = typename Layers_Helper::type; + + static constexpr auto n_channels = channels; + + RTNeural::DenseT rechannel; // no bias! 
+ Layers layers; + static constexpr auto num_layers = std::tuple_size_v; + RTNeural::DenseT head_rechannel; + + using Last_Layer_Type = std::remove_reference_t - 1> (layers))>; + decltype (Last_Layer_Type::outs)& layer_outputs { std::get - 1> (layers).outs }; + decltype (RTNeural::DenseT::outs)& head_outputs { head_rechannel.outs }; + + void reset() + { + RTNeural::modelt_detail::forEachInTuple ([] (auto& layer, size_t) + { layer.reset(); }, + layers); + } + + void load_weights (std::vector::iterator& weights) + { + std::vector> rechannel_weights (channels, std::vector (in_size)); + for (int i = 0; i < channels; i++) + for (int j = 0; j < in_size; j++) + rechannel_weights[i][j] = *(weights++); + rechannel.setWeights (rechannel_weights); + + RTNeural::modelt_detail::forEachInTuple ([&weights] (auto& layer, size_t) + { layer.load_weights (weights); }, + layers); + + std::vector> head_rechannel_weights (head_size, std::vector (channels)); + for (int i = 0; i < head_size; i++) + for (int j = 0; j < channels; j++) + head_rechannel_weights[i][j] = *(weights++); + head_rechannel.setWeights (head_rechannel_weights); + + if constexpr (has_head_bias) + { + std::vector head_rechannel_bias (head_size); + for (int i = 0; i < head_size; i++) + head_rechannel_bias[i] = *(weights++); + head_rechannel.setBias (head_rechannel_bias.data()); + } + } + +#if RTNEURAL_USE_EIGEN + void forward (const Eigen::Matrix& ins, + const Eigen::Matrix& condition, + Eigen::Map, RTNeural::RTNeuralEigenAlignment>& head_io) +#elif RTNEURAL_USE_XSIMD + void forward (const xsimd::batch (&ins)[RTNeural::ceil_div (in_size, (int) xsimd::batch::size)], + const xsimd::batch (&condition)[RTNeural::ceil_div (condition_size, (int) xsimd::batch::size)], + xsimd::batch (&head_io)[RTNeural::ceil_div (channels, (int) xsimd::batch::size)]) +#endif + { + rechannel.forward (ins); + + RTNeural::modelt_detail::forEachInTuple ( + [&] (auto& layer, auto index_t) + { + static constexpr size_t index = index_t; + if constexpr 
(index == 0) + layer.forward (rechannel.outs, condition, head_io); + else + layer.forward (std::get (layers).outs, condition, head_io); + }, + layers); + + head_rechannel.forward (head_io); + } +}; +} // namespace wavenet diff --git a/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_model.hpp b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_model.hpp new file mode 100644 index 0000000000..df840f96d1 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/RTNeural/wavenet/wavenet_model.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include "wavenet_layer_array.hpp" + +namespace wavenet +{ +template +struct Wavenet_Model +{ + std::tuple layer_arrays; + + static constexpr auto head_layer_n_channels = std::tuple_element_t<0, std::tuple>::n_channels; + +#if RTNEURAL_USE_EIGEN + Eigen::Matrix head_input {}; +#elif RTNEURAL_USE_XSIMD + xsimd::batch head_input[RTNeural::ceil_div (head_layer_n_channels, (int) xsimd::batch::size)]; +#endif + T head_scale = (T) 0; + + Wavenet_Model() = default; + + void prewarm() + { + RTNeural::modelt_detail::forEachInTuple ( + [] (auto& layer, size_t) + { + layer.reset(); + }, + layer_arrays); + for (int i = 0; i < 1 << 14; ++i) + forward (0.0f); + } + + void load_weights (const nlohmann::json& model_json) + { + std::vector model_weights = model_json.at ("weights"); + auto weights_iterator = model_weights.begin(); + RTNeural::modelt_detail::forEachInTuple ( + [&weights_iterator] (auto& layer, size_t) + { + layer.load_weights (weights_iterator); + }, + layer_arrays); + + head_scale = *weights_iterator++; + + // Make sure we use all of the weights exactly + assert (std::distance (model_weights.begin(), weights_iterator) == model_weights.size()); + } + + T forward (T input) noexcept + { +#if RTNEURAL_USE_EIGEN + const auto v_ins = Eigen::Matrix::Constant (input); +#elif RTNEURAL_USE_XSIMD + xsimd::batch v_ins[1]; + v_ins[0] = RTNeural::set_value (v_ins[0], 0, input); +#endif + RTNeural::modelt_detail::forEachInTuple ( + [this, v_ins] (auto& 
layer_array, auto index_t) + { + static constexpr size_t index = index_t; + if constexpr (index == 0) + { +#if RTNEURAL_USE_EIGEN + head_input.setZero(); + Eigen::Map, RTNeural::RTNeuralEigenAlignment> head_input_map { head_input.data() }; + std::get<0> (layer_arrays).forward (v_ins, v_ins, head_input_map); +#elif RTNEURAL_USE_XSIMD + std::fill (std::begin (head_input), std::end (head_input), xsimd::batch {}); + std::get<0> (layer_arrays).forward (v_ins, v_ins, head_input); +#endif + } + else + { + std::get (layer_arrays).forward (std::get (layer_arrays).layer_outputs, v_ins, std::get (layer_arrays).head_outputs); + } + }, + layer_arrays); + +#if RTNEURAL_USE_EIGEN + return std::get - 1> (layer_arrays).head_outputs[0] * head_scale; +#elif RTNEURAL_USE_XSIMD + return std::get - 1> (layer_arrays).head_outputs[0].get (0) * head_scale; + +#endif + } +}; +} // namespace wavenet diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/math_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/math_approx.hpp new file mode 100644 index 0000000000..bb3a47c03d --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/math_approx.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace math_approx +{ +} + +#include "src/basic_math.hpp" + +#include "src/trig_approx.hpp" +#include "src/inverse_trig_approx.hpp" +#include "src/pow_approx.hpp" +#include "src/log_approx.hpp" +#include "src/hyperbolic_trig_approx.hpp" +#include "src/inverse_hyperbolic_trig_approx.hpp" +#include "src/sigmoid_approx.hpp" +#include "src/wright_omega_approx.hpp" +#include "src/polylogarithm_approx.hpp" diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/basic_math.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/basic_math.hpp new file mode 100644 index 0000000000..de827d5173 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/basic_math.hpp @@ -0,0 +1,107 @@ +#pragma once + +// If MATH_APPROX_XSIMD_TARGET is not defined +// the user can still use 
XSIMD by manually including +// it before including the math_approx header. +#if MATH_APPROX_XSIMD_TARGET +#include +#endif + +#if ! defined(XSIMD_HPP) +#include +#endif + +#include +#include + +namespace math_approx +{ +template +struct scalar_of +{ + using type = T; +}; + +/** + * When T is a scalar floating-point type, scalar_of_t is T. + * When T is a SIMD floating-point type, scalar_of_t is the corresponding scalar type. + */ +template +using scalar_of_t = typename scalar_of::type; + +/** Inverse square root */ +template +T rsqrt (T x) +{ + // @TODO: figure out a way that we can make this method constexpr + + // sqrtss followed by divss... this seems to measure a bit faster than the rsqrtss plus NR iteration below + return (T) 1 / std::sqrt (x); + + // fast inverse square root (using rsqrtss hardware instruction), plus one Newton-Raphson iteration + // auto r = xsimd::rsqrt (xsimd::broadcast (x)).get (0); + // x *= r; + // x *= r; + // x += -3.0f; + // r *= -0.5f; + // return x * r; +} + +/** Function interface for the ternary operator. */ +template +T select (bool q, T t, T f) +{ + return q ? t : f; +} + +#if defined(XSIMD_HPP) +template +struct scalar_of> +{ + using type = T; +}; + +/** Inverse square root */ +template +xsimd::batch rsqrt (xsimd::batch x) +{ + using S = scalar_of_t; + auto r = xsimd::rsqrt (x); + x *= r; + x *= r; + x += (S) -3; + r *= (S) -0.5; + return x * r; +} + +/** Function interface for the ternary operator. */ +template +xsimd::batch select (xsimd::batch_bool q, xsimd::batch t, xsimd::batch f) +{ + return xsimd::select (q, t, f); +} +#endif + +#if ! __cpp_lib_bit_cast +// bit_cast requirement. +template +using is_bitwise_castable = std::integral_constant::value && std::is_trivially_copyable::value>; + +// compiler support is needed for bitwise copy with constexpr. 
+template +inline typename std::enable_if::value, To>::type bit_cast (const From& from) noexcept +{ + union U + { + U() {}; + char storage[sizeof (To)] {}; + typename std::remove_const::type dest; + } u; // instead of To dest; because To doesn't require DefaultConstructible. + std::memcpy (&u.dest, &from, sizeof from); + return u.dest; +} +#else +using std::bit_cast; +#endif +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/hyperbolic_trig_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/hyperbolic_trig_approx.hpp new file mode 100644 index 0000000000..0f1960ddfc --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/hyperbolic_trig_approx.hpp @@ -0,0 +1,138 @@ +#pragma once + +#include "pow_approx.hpp" + +namespace math_approx +{ +// ref: https://en.wikipedia.org/wiki/Hyperbolic_functions#Definitions +// sinh = (e^(2x) - 1) / (2e^x), cosh = (e^(2x) + 1) / (2e^x) +// let B = e^x, then sinh = (B^2 - 1) / (2B), cosh = (B^2 + 1) / (2B) +// simplifying, we get: sinh = 0.5 (B - 1/B), cosh = 0.5 (B + 1/B) + +/** Approximation of sinh(x), using exp(x) internally */ +template +constexpr T sinh (T x) +{ + using S = scalar_of_t; + auto B = exp (x); + auto Br = (S) 0.5 / B; + B *= (S) 0.5; + return B - Br; +} + +/** Approximation of cosh(x), using exp(x) internally */ +template +constexpr T cosh (T x) +{ + using S = scalar_of_t; + auto B = exp (x); + auto Br = (S) 0.5 / B; + B *= (S) 0.5; + return B + Br; +} + +/** + * Simultaneous approximation of sinh(x) and cosh(x), + * using exp(x) internally. + * + * For more information see the comments above. 
+ */ +template +constexpr auto sinh_cosh (T x) +{ + using S = scalar_of_t; + auto B = exp (x); + auto Br = (S) 0.5 / B; + B *= (S) 0.5; + + auto sinh = B - Br; + auto cosh = B + Br; + + return std::make_pair (sinh, cosh); +} + +namespace tanh_detail +{ + // See notebooks/tanh_approx.nb for the derivation of these polynomials + + template + constexpr T tanh_poly_11 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_9_11 = (S) 2.63661358122e-6 + (S) 3.33765558362e-8 * x_sq; + const auto y_7_9_11 = (S) 0.000199027336899 + y_9_11 * x_sq; + const auto y_5_7_9_11 = (S) 0.00833223857843 + y_7_9_11 * x_sq; + const auto y_3_5_7_9_11 = (S) 0.166667159320 + y_5_7_9_11 * x_sq; + const auto y_1_3_5_7_9_11 = (S) 1 + y_3_5_7_9_11 * x_sq; + return x * y_1_3_5_7_9_11; + } + + template + constexpr T tanh_poly_9 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_7_9 = (S) 0.000192218110330 + (S) 3.54808622170e-6 * x_sq; + const auto y_5_7_9 = (S) 0.00834777254865 + y_7_9 * x_sq; + const auto y_3_5_7_9 = (S) 0.166658873283 + y_5_7_9 * x_sq; + const auto y_1_3_5_7_9 = (S) 1 + y_3_5_7_9 * x_sq; + return x * y_1_3_5_7_9; + } + + template + constexpr T tanh_poly_7 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_5_7 = (S) 0.00818199927912 + (S) 0.000243153287690 * x_sq; + const auto y_3_5_7 = (S) 0.166769941467 + y_5_7 * x_sq; + const auto y_1_3_5_7 = (S) 1 + y_3_5_7 * x_sq; + return x * y_1_3_5_7; + } + + template + constexpr T tanh_poly_5 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_3_5 = (S) 0.165326984031 + (S) 0.00970240200826 * x_sq; + const auto y_1_3_5 = (S) 1 + y_3_5 * x_sq; + return x * y_1_3_5; + } + + template + constexpr T tanh_poly_3 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_1_3 = (S) 1 + (S) 0.183428244899 * x_sq; + return x * y_1_3; + } +} // namespace tanh_detail + +/** + * Approximation of tanh(x), using tanh(x) ≈ p(x) / 
sqrt(p(x)^2 + 1), + * where p(x) is an odd polynomial fit to minimize the maximum relative error. + */ +template +T tanh (T x) +{ + static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must e an odd number within [3, 11]"); + + T x_poly {}; + if constexpr (order == 11) + x_poly = tanh_detail::tanh_poly_11 (x); + else if constexpr (order == 9) + x_poly = tanh_detail::tanh_poly_9 (x); + else if constexpr (order == 7) + x_poly = tanh_detail::tanh_poly_7 (x); + else if constexpr (order == 5) + x_poly = tanh_detail::tanh_poly_5 (x); + else if constexpr (order == 3) + x_poly = tanh_detail::tanh_poly_3 (x); + + using S = scalar_of_t; + return x_poly * rsqrt (x_poly * x_poly + (S) 1); +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_hyperbolic_trig_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_hyperbolic_trig_approx.hpp new file mode 100644 index 0000000000..78df2edb08 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_hyperbolic_trig_approx.hpp @@ -0,0 +1,108 @@ +#pragma once + +#include "basic_math.hpp" +#include "log_approx.hpp" + +namespace math_approx +{ +struct AsinhLog2Provider +{ + // for polynomial derivations, see notebooks/asinh_approx.nb + + /** approximation for log2(x), optimized on the range [1, 2], to be used within an asinh(x) computation */ + template + static constexpr T log2_approx (T x) + { + static_assert (order >= 3 && order <= 5); + using S = scalar_of_t; + + const auto x_sq = x * x; + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.21535595794871 + (S) 0.194363894384581 * x; + const auto x_0_1 = (S) -2.26452854958994 + (S) 3.28552061315407 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.770443387059628 + (S) -0.102652345633016 * x; + const auto x_1_2 = (S) 4.33013912645867 + (S) -2.39448588379361 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) 
-2.60344428409168 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.511946284688366 + (S) 0.0578217518982235 * x; + const auto x_2_3 = (S) -3.94632584968643 + (S) 1.90796087279737 * x; + const auto x_0_1 = (S) -2.87748189127908 + (S) 5.36997140095829 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else + { + return {}; + } + } +}; + +/** + * Approximation of asinh(x) in the full range, using identity + * asinh(x) = log(x + sqrt(x^2 + 1)). + * + * Orders 6 and 7 use an additional Newton-Raphson iteration, + * but for most cases the accuracy improvement is not worth + * the additional cost (when compared to the performance and + * accuracy achieved by the STL implementation). + */ +template +constexpr T asinh (T x) +{ + using S = scalar_of_t; + using std::abs, std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs, xsimd::sqrt; +#endif + + const auto sign = select (x > (S) 0, (T) (S) 1, select (x < (S) 0, (T) (S) -1, (T) (S) 0)); + x = abs (x); + + const auto log_arg = x + sqrt (x * x + (S) 1); + auto y = log>, std::min (order, 5), false, AsinhLog2Provider> (log_arg); + + if constexpr (order > 5) + { + const auto exp_y = math_approx::exp (y); + y -= (exp_y - log_arg) / exp_y; + } + + return sign * y; +} + +/** + * Approximation of acosh(x) in the full range, using identity + * acosh(x) = log(x + sqrt(x^2 - 1)). + */ +template +constexpr T acosh (T x) +{ + using S = scalar_of_t; + using std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::sqrt; +#endif + + const auto z1 = x + sqrt (x * x - (S) 1); + return log (z1); +} + +/** + * Approximation of atanh(x), using identity + * atanh(x) = (1/2) log((1 + x) / (1 - x)). 
+ */ +template +constexpr T atanh (T x) +{ + using S = scalar_of_t; + return (S) 0.5 * log (((S) 1 + x) / ((S) 1 - x)); +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_trig_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_trig_approx.hpp new file mode 100644 index 0000000000..9ca85c7cb1 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/inverse_trig_approx.hpp @@ -0,0 +1,185 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +namespace inv_trig_detail +{ + // for polynomial derivations, see notebooks/asin_acos_approx.nb + + template + constexpr T asin_kernel (T x) + { + using S = scalar_of_t; + static_assert (order >= 1 && order <= 4); + + if constexpr (order == 1) + { + return (S) 0.16443531037029196495 + x * (S) 0.097419577664394046979; + } + else if constexpr (order == 2) + { + return (S) 0.16687742065041710759 + x * ((S) 0.070980446338571381859 + x * (S) 0.066682760821292624831); + } + else if constexpr (order == 3) + { + return (S) 0.16665080061757006624 + x * ((S) 0.075508850204912977833 + x * ((S) 0.039376231206556484843 + x * (S) 0.051275338699694958389)); + } + else if constexpr (order == 4) + { + return (S) 0.16666803275183153521 + x * ((S) 0.074936964020844071266 + x * ((S) 0.045640288439217274741 + x * ((S) 0.023435504410713306478 + x * (S) 0.043323710842752508055))); + } + else + { + return {}; + } + } + + template + constexpr T acos_kernel (T x) + { + using S = scalar_of_t; + static_assert (order >= 1 && order <= 5); + + if constexpr (order == 1) + { + return (S) 0.061454830783555181029 + x * (S) 0.50934149601134137697; + } + else if constexpr (order == 2) + { + return (S) 0.18188825560430002537 + x * ((S) -0.092825628092384385170 + x * (S) 0.48173369928298098719); + } + else if constexpr (order == 3) + { + return (S) 0.16480511788348814473 + x * ((S) 0.11286070199090997290 + x * ((S) -0.18795205899643871450 + x * (S) 
0.48108256591693704385)); + } + else if constexpr (order == 4) + { + return (S) 0.16687235373875186628 + x * ((S) 0.068412956842158992310 + x * ((S) 0.11466969910945928879 + x * ((S) -0.27433862418620241774 + x * (S) 0.49517994129072917531))); + } + else if constexpr (order == 5) + { + return (S) 0.16664924406383360700 + x * ((S) 0.075837825275592588015 + x * ((S) 0.030665158374004904823 + x * ((S) 0.13572846625592635550 + x * ((S) -0.34609357317006372856 + x * (S) 0.50800920599560273061)))); + } + else + { + return {}; + } + } + + // for polynomial derivations, see notebooks/arctan_approx.nb + + template + constexpr T atan_kernel (T x) + { + using S = scalar_of_t; + static_assert (order >= 4 && order <= 7); + + if constexpr (order == 4) + { + const auto x_sq = x * x; + const auto num = x + x_sq * (S) 0.498001992540; + const auto den = (S) 1 + x * (S) 0.481844539675 + x_sq * (S) 0.425470835319; + return num / den; + } + else if constexpr (order == 5 || order == 6) + { + const auto x_sq = x * x; + const auto num = (S) 0.177801521472 + x * (S) 0.116983970701; + const auto den = (S) 1 + x * (S) 0.174763903018 + x_sq * (S) 0.473808187566; + return (x + x_sq * num) / den; + } + else if constexpr (order == 7) + { + const auto x_sq = x * x; + const auto num = (S) 0.274959104817 + (S) 0.351814748865 * x + (S) -0.0395798531406 * x_sq; + const auto den = (S) 1 + x * ((S) 0.275079063405 + x * ((S) 0.683311392128 + x * (S) 0.0624877111229)); + return (x + x_sq * num) / den; + } + else + { + return {}; + } + } +} // namespace inv_trig_detail + +/** + * Approximation of asin(x) using asin(x) ≈ p(x^2) * x^3 + x for x in [0, 0.5], + * and asin(x) ≈ pi/2 - p((1-x)/2) * ((1-x)/2)^3/2 + ((1-x)/2)^1/2 for x in [0.5, 1], + * where p(x) is a polynomial fit to achieve the minimum absolute error. 
+ */ +template +T asin (T x) +{ + using S = scalar_of_t; + + using std::abs, std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs, xsimd::sqrt; +#endif + + const auto abs_x = abs (x); + + const auto reflect = abs_x > (S) 0.5; + auto z0 = select (reflect, (S) 0.5 * ((S) 1 - abs_x), abs_x * abs_x); + + auto x2 = select (reflect, sqrt (z0), abs_x); + auto z1 = inv_trig_detail::asin_kernel (z0); + + auto z2 = z1 * (z0 * x2) + x2; + auto res = select (reflect, (S) M_PI_2 - (z2 + z2), z2); + return select (x > (S) 0, res, -res); +} + +/** + * Approximation of acos(x) using the same approach as asin(x), + * but with a different polynomial fit. + */ +template +T acos (T x) +{ + using S = scalar_of_t; + + using std::abs, std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs, xsimd::sqrt; +#endif + + const auto abs_x = abs (x); + + const auto reflect = abs_x > (S) 0.5; + auto z0 = select (reflect, (S) 0.5 * ((S) 1 - abs_x), abs_x * abs_x); + + auto x2 = select (reflect, sqrt (z0), abs_x); + auto z1 = inv_trig_detail::acos_kernel (z0); + + auto z2 = z1 * (z0 * x2) + x2; + auto res = select (reflect, (S) M_PI_2 - (z2 + z2), z2); + return (S) M_PI_2 - select (x > (S) 0, res, -res); +} + +/** + * Approximation of atan(x) using a polynomial approximation of arctan(x) on [0, 1], + * and arctan(x) = pi/2 - arctan(1/x) for x > 1. 
+ */ +template +T atan (T x) +{ + using S = scalar_of_t; + + using std::abs, std::sqrt; +#if defined(XSIMD_HPP) + using xsimd::abs, xsimd::sqrt; +#endif + + const auto abs_x = abs (x); + const auto reflect = abs_x > (S) 1; + + const auto z = select (reflect, (S) 1 / abs_x, abs_x); + const auto atan_01 = inv_trig_detail::atan_kernel (z); + + const auto res = select (reflect, (S) M_PI_2 - atan_01, atan_01); + return select (x > (S) 0, res, -res); +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/log_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/log_approx.hpp new file mode 100644 index 0000000000..d58ec4e207 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/log_approx.hpp @@ -0,0 +1,204 @@ +#pragma once + +#include "basic_math.hpp" +#include "pow_approx.hpp" + +namespace math_approx +{ +namespace log_detail +{ + struct Log2Provider + { + // for polynomial derivations, see notebooks/log_approx.nb + + /** approximation for log2(x), optimized on the range [1, 2] */ + template + static constexpr T log2_approx (T x) + { + static_assert (order >= 3 && order <= 6); + using S = scalar_of_t; + + const auto x_sq = x * x; + if constexpr (C1_continuous) + { + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.09886528622 + (S) 0.164042561333 * x; + const auto x_0_1 = (S) -2.21347520444 + (S) 3.14829792933 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.671618567027 + (S) -0.0845960009489 * x; + const auto x_1_2 = (S) 4.16344994072 + (S) -2.19861329856 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) -2.55185920824 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.432338320780 + (S) 0.0464481811023 * x; + const auto x_2_3 = (S) -3.65368350361 + (S) 1.68976432066 * x; + const auto x_0_1 = (S) -2.82807214111 + (S) 5.17788146374 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + 
return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.284794437502 + (S) -0.0265448504094 * x; + const auto x_3_4 = (S) 3.38542517475 + (S) -1.31007090775 * x; + const auto x_1_2 = (S) 6.19242937536 + (S) -5.46521465640 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) -3.06081857306 + x_1_2_3_4_5_6 * x; + } + else + { + return {}; + } + } + else + { + if constexpr (order == 3) + { + const auto x_2_3 = (S) -1.05974531422 + (S) 0.159220010975 * x; + const auto x_0_1 = (S) -2.16417056258 + (S) 3.06469586582 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.649709537672 + (S) -0.0821303550902 * x; + const auto x_1_2 = (S) 4.08637809379 + (S) -2.13412984371 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) -2.51982743265 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) -0.419319345483 + (S) 0.0451488402558 * x; + const auto x_2_3 = (S) -3.56885211615 + (S) 1.64139451414 * x; + const auto x_0_1 = (S) -2.80534277658 + (S) 5.10697088382 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.276834061071 + (S) -0.0258400886535 * x; + const auto x_3_4 = (S) 3.30388341157 + (S) -1.27446900713 * x; + const auto x_1_2 = (S) 6.12708086513 + (S) -5.36371998242 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) -3.04376925958 + x_1_2_3_4_5_6 * x; + } + else + { + return {}; + } + } + } + }; +} + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" // these methods require some type-punning +#pragma GCC diagnostic ignored "-Wuninitialized" +#endif + +/** approximation for log(Base, x) (32-bit) */ +template +constexpr float log (float x) +{ + const 
auto vi = bit_cast (x); + const auto ex = vi & 0x7f800000; + const auto e = (ex >> 23) - 127; + const auto vfi = (vi - ex) | 0x3f800000; + const auto vf = bit_cast (vfi); + + constexpr auto log2_base_r = 1.0f / Base::log2_base; + return log2_base_r * ((float) e + Log2ProviderType::template log2_approx (vf)); +} + +/** approximation for log(Base, x) (64-bit) */ +template +constexpr double log (double x) +{ + const auto vi = bit_cast (x); + const auto ex = vi & 0x7ff0000000000000; + const auto e = (ex >> 52) - 1023; + const auto vfi = (vi - ex) | 0x3ff0000000000000; + const auto vf = bit_cast (vfi); + + constexpr auto log2_base_r = 1.0 / Base::log2_base; + return log2_base_r * ((double) e + Log2ProviderType::template log2_approx (vf)); +} + +#if defined(XSIMD_HPP) +/** approximation for log(Base, x) (32-bit SIMD) */ +template +xsimd::batch log (xsimd::batch x) +{ + const auto vi = xsimd::bit_cast> (x); + const auto ex = vi & 0x7f800000; + const auto e = (ex >> 23) - 127; + const auto vfi = (vi - ex) | 0x3f800000; + const auto vf = xsimd::bit_cast> (vfi); + + static constexpr auto log2_base_r = 1.0f / Base::log2_base; + return log2_base_r * (xsimd::to_float (e) + Log2ProviderType::template log2_approx, order, C1_continuous> (vf)); +} + +/** approximation for log(Base, x) (64-bit SIMD) */ +template +xsimd::batch log (xsimd::batch x) +{ + const auto vi = xsimd::bit_cast> (x); + const auto ex = vi & 0x7ff0000000000000; + const auto e = (ex >> 52) - 1023; + const auto vfi = (vi - ex) | 0x3ff0000000000000; + const auto vf = xsimd::bit_cast> (vfi); + + static constexpr auto log2_base_r = 1.0 / Base::log2_base; + return log2_base_r * (xsimd::to_float (e) + Log2ProviderType::template log2_approx, order, C1_continuous> (vf)); +} +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic pop // end ignore strict-aliasing warnings +#endif + +/** + * Approximation of log(x), using + * log(x) = (1 / log2(e)) * (Exponent(x) + log2(1 + Mantissa(x))) + */ +template +constexpr T log (T x) +{ + 
return log>, order, C1_continuous> (x); +} + +/** + * Approximation of log2(x), using + * log2(x) = Exponent(x) + log2(1 + Mantissa(x)) + */ +template +constexpr T log2 (T x) +{ + return log>, order, C1_continuous> (x); +} + +/** + * Approximation of log10(x), using + * log10(x) = (1 / log2(10)) * (Exponent(x) + log2(1 + Mantissa(x))) + */ +template +constexpr T log10 (T x) +{ + return log>, order, C1_continuous> (x); +} + +/** Approximation of log(1 + x), using math_approx::log(x) */ +template +constexpr T log1p (T x) +{ + return log>, order, C1_continuous> ((T) 1 + x); +} +} diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/polylogarithm_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/polylogarithm_approx.hpp new file mode 100644 index 0000000000..a58d796863 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/polylogarithm_approx.hpp @@ -0,0 +1,225 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +/** + * Approximation of the "dilogarithm" function for inputs + * in the range [0, 1/2]. This method does not do any + * bounds-checking. + * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. 
+ * + * For derivations, see notebooks/li2_approx.nb + */ +template +constexpr T li2_0_half (T x) +{ + static_assert (order >= 1 && order <= 6); + using S = scalar_of_t; + + if constexpr (order == 1) + { + const auto n_0 = (S) 0.996460629617; + const auto d_0_1 = (S) 1 + (S) -0.288575624121 * x; + return x * n_0 / d_0_1; + } + else if constexpr (order == 2) + { + const auto n_0_1 = (S) 0.999994847641 + (S) -0.546961998015 * x; + const auto d_1_2 = (S) -0.797206910618 + (S) 0.0899936224040 * x; + const auto d_0_1_2 = (S) 1 + d_1_2 * x; + return x * n_0_1 / d_0_1_2; + } + else if constexpr (order == 3) + { + const auto x_sq = x * x; + const auto n_0_2 = (S) 0.999999991192 + (S) 0.231155739205 * x_sq; + const auto n_0_1_2 = n_0_2 + (S) -1.07612533343 * x; + const auto d_2_3 = (S) 0.451592861555 + (S) -0.0281544399023 * x; + const auto d_0_1 = (S) 1 + (S) -1.32612627824 * x; + const auto d_0_1_2_3 = d_0_1 + d_2_3 * x_sq; + return x * n_0_1_2 / d_0_1_2_3; + } + else if constexpr (order == 4) + { + const auto x_sq = x * x; + const auto n_2_3 = (S) 0.74425269014090502911555775982556365472 + (S) -0.08749607277005140673532964399704145939 * x; + const auto n_0_1 = (S) 0.99999999998544094594795118478024862055 + (S) -1.6098648159028159794757437744309391591 * x; + const auto n_0_1_2_3 = n_0_1 + n_2_3 * x_sq; + const auto d_3_4 = (S) -0.21787247785577362691148412819704459614 + (S) 0.00870385570778120787932426702624346169 * x; + const auto d_1_2 = (S) -1.85986481869406218896935179306183665107 + (S) 1.09810787318601772062220747277929300408 * x; + const auto d_1_2_3_4 = d_1_2 + d_3_4 * x_sq; + const auto d_0_1_2_3_4 = (S) 1 + d_1_2_3_4 * x; + return x * n_0_1_2_3 / d_0_1_2_3_4; + } + else if constexpr (order == 5) + { + const auto x_sq = x * x; + + const auto n_3_4 = (S) -0.41945653857264507277532555842378439927 + (S) 0.03140351694981020435408321943912212079 * x; + const auto n_1_2 = (S) -2.14843104749890205674150618938194330623 + (S) 1.54956546570292751217524363072830456069 * x; + 
const auto n_1_2_3_4 = n_1_2 + n_3_4 * x_sq; + const auto n_0_1_2_3_4 = (S) 0.99999999999997312289180148636206726177 + n_1_2_3_4 * x; + + const auto d_4_5 = (S) 0.09609912057603552016206051904306797162 + (S) -0.00269129500193871901659324657805482418 * x; + const auto d_2_3 = (S) 2.03806211686824385201410542913121040892 + (S) -0.72497973694183708484311198715866984035 * x; + const auto d_0_1 = (S) 1 + (S) -2.398431047506893407956406025441134862 * x; + const auto d_2_3_4_5 = d_2_3 + d_4_5 * x_sq; + const auto d_0_1_2_3_4_5 = d_0_1 + d_2_3_4_5 * x_sq; + + return x * n_0_1_2_3_4 / d_0_1_2_3_4_5; + } + else if constexpr (order == 6) + { + const auto x_sq = x * x; + + const auto n_4_5 = (S) 0.20885966267164674441979654645138181067 + (S) -0.01085968986663512120143497781484214416 * x; + const auto n_2_3 = (S) 2.64771686149306717256638234054408732899 + (S) -1.15385196641292513334184445301529897694 * x; + const auto n_0_1 = (S) 0.99999999999999995022522902211061062582 + (S) -2.6883902117841251600624689886592808124 * x; + const auto n_2_3_4_5 = n_2_3 + n_4_5 * x_sq; + const auto n_0_1_2_3_4_5 = n_0_1 + n_2_3_4_5 * x_sq; + + const auto d_5_6 = (S) -0.03980108270103465616851961097089502921 + (S) 0.00082742905522813187941384917520432493 * x; + const auto d_3_4 = (S) -1.70766499097900947314107956633154245176 + (S) 0.41595826557420951684124942212799147948 * x; + const auto d_1_2 = (S) -2.93839021178414636324893816529360171731 + (S) 3.27120330332951521662427278605230451458 * x; + const auto d_3_4_5_6 = d_3_4 + d_5_6 * x_sq; + const auto d_0_1_2 = (S) 1 + d_1_2 * x; + const auto d_0_1_2_3_4_5_6 = d_0_1_2 + d_3_4_5_6 * x_sq * x; + + return x * n_0_1_2_3_4_5 / d_0_1_2_3_4_5_6; + } + else + { + return {}; + } +} + +/** + * Approximation of the "dilogarithm" function for all inputs. + * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. 
+ */ +template = 5), typename T> +constexpr T li2 (T x) +{ + const auto x_r = (T) 1 / x; + const auto x_r1 = (T) 1 / (x - (T) 1); + + constexpr auto pisq_o_6 = (T) M_PI * (T) M_PI / (T) 6; + constexpr auto pisq_o_3 = (T) M_PI * (T) M_PI / (T) 3; + + T y, r; + bool sign = true; + if (x < (T) -1) + { + y = -x_r1; + const auto l = log ((T) 1 - x); + r = -pisq_o_6 + l * ((T) 0.5 * l - log (-x)); + } + else if (x < (T) 0) + { + y = x * x_r1; + const auto l = log ((T) 1 - x); + r = (T) -0.5 * l * l; + sign = false; + } + else if (x < (T) 0.5) + { + y = x; + r = {}; + } + else if (x < (T) 1) + { + y = (T) 1 - x; + r = pisq_o_6 - log (x) * log (y); + sign = false; + } + else if (x < (T) 2) + { + y = (T) 1 - x_r; + const auto l = log (x); + r = pisq_o_6 - l * (log (y) + (T) 0.5 * l); + } + else + { + y = x_r; + const auto l = log (x); + r = pisq_o_3 - (T) 0.5 * l * l; + sign = false; + } + + const auto li2_reduce = li2_0_half (y); + return r + select (sign, li2_reduce, -li2_reduce); +} + +#if defined(XSIMD_HPP) +/** + * Approximation of the "dilogarithm" function for all inputs. + * + * Orders higher than 3 are generally not recommended for + * single-precision floating-point types, since they don't + * improve the accuracy very much. 
+ */ +template = 5), typename T> +xsimd::batch li2 (const xsimd::batch& x) +{ + // x < -1: + // - log(-x) -> [1, inf] + // - log(1-x) -> [2, inf] + // x < 0: + // - NOP + // - log(1-x) -> [1, 2] + // x < 1/2: + // - NOP + // - NOP + // x < 1: + // - log(x) -> [1/2, 1] + // - log(1-x) -> [0, 1/2] + // x < 2: + // - log(x) -> [1, 2] + // - log(1-1/x) -> [0, 1/2] + // x >= 2: + // - log(x) -> [2, inf] + // - NOP + + const auto x_r = (T) 1 / x; + const auto x_r1 = (T) 1 / (x - (T) 1); + const auto log_arg1 = select (x < (T) -1, -x, select (x < (T) 0.5, xsimd::broadcast ((T) 1), x)); + const auto log_arg2 = select (x < (T) 1, (T) 1 - x, (T) 1 - x_r); + + const auto log1 = log (log_arg1); + const auto log2 = log (log_arg2); + + // clang-format off + const auto y = select (x < (T) -1, (T) -1 * x_r1, + select (x < (T) 0, x * x_r1, + select (x < (T) 0.5, x, + select (x < (T) 1, (T) 1 - x, + select (x < (T) 2, (T) 1 - x_r, + x_r))))); + const auto sign = x < (T) -1 || (x >= (T) 0 && x < (T) 0.5) || (x >= (T) 1 && x < (T) 2); + + static constexpr auto pisq_o_6 = (T) M_PI * (T) M_PI / (T) 6; + static constexpr auto pisq_o_3 = (T) M_PI * (T) M_PI / (T) 3; + const auto log1_log2 = log1 * log2; + const auto half_log1_sq = (T) 0.5 * log1 * log1; + const auto half_log2_sq = (T) 0.5 * log2 * log2; + const auto r = select (x < (T) -1, -pisq_o_6 + half_log2_sq - log1_log2, + select (x < (T) 0, -half_log2_sq, + select (x < (T) 0.5, xsimd::broadcast ((T) 0), + select (x < (T) 1, pisq_o_6 - log1_log2, + select (x < (T) 2, pisq_o_6 - log1_log2 - half_log1_sq, + pisq_o_3 - half_log1_sq))))); + // clang-format on + + const auto li2_reduce = li2_0_half (y); + return r + select (sign, li2_reduce, -li2_reduce); +} +#endif +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/pow_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/pow_approx.hpp new file mode 100644 index 0000000000..51c5a3af7a --- /dev/null +++ 
b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/pow_approx.hpp @@ -0,0 +1,231 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +namespace pow_detail +{ + // for polynomial derivations, see notebooks/exp_approx.nb + + /** approximation for 2^x, optimized on the range [0, 1] */ + template + constexpr T pow2_approx (T x) + { + static_assert (order >= 3 && order <= 7); + using S = scalar_of_t; + + const auto x_sq = x * x; + if constexpr (C1_continuous) + { + if constexpr (order == 3) + { + const auto x_2_3 = (S) 0.227411277760 + (S) 0.0794415416798 * x; + const auto x_0_1 = (S) 1 + (S) 0.693147180560 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.0521277476109 + (S) 0.0136568970345 * x; + const auto x_1_2 = (S) 0.693147180560 + (S) 0.241068174795 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) 1 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) 0.00899838527231 + (S) 0.00188723482038 * x; + const auto x_2_3 = (S) 0.240184132673 + (S) 0.0557830666741 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + const auto x_0_1 = (S) 1 + (S) 0.693147180560 * x; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.00124453797252 + (S) 0.000217714753229 * x; + const auto x_3_4 = (S) 0.0554875633068 + (S) 0.00967475272129 * x; + const auto x_1_2 = (S) 0.693147180560 + (S) 0.240228250686 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) 1 + x_1_2_3_4_5_6 * x; + } + else if constexpr (order == 7) + { + // doesn't seem to help at single-precision + const auto x_6_7 = (S) 0.000133154170702612 + (S) 0.0000245778949916153 * x; + const auto x_4_5 = (S) 0.00960612128901630 + (S) 0.00135551454943593 * x; + const auto x_2_3 = (S) 0.240226202240181 + (S) 0.0555072492957270 * x; + const auto x_0_1 = (S) 1 + (S) 0.693147180559945 * x; + 
const auto x_4_5_6_7 = x_4_5 + x_6_7 * x_sq; + const auto x_0_1_2_3 = x_0_1 + x_2_3 * x_sq; + return x_0_1_2_3 + x_4_5_6_7 * x_sq * x_sq; + } + else + { + return {}; + } + } + else + { + if constexpr (order == 3) + { + const auto x_2_3 = (S) 0.226307586882 + (S) 0.0782680256330 * x; + const auto x_0_1 = (S) 1 + (S) 0.695424387485 * x; + return x_0_1 + x_2_3 * x_sq; + } + else if constexpr (order == 4) + { + const auto x_3_4 = (S) 0.0520324008177 + (S) 0.0135557244044 * x; + const auto x_1_2 = (S) 0.693032120001 + (S) 0.241379754777 * x; + const auto x_1_2_3_4 = x_1_2 + x_3_4 * x_sq; + return (S) 1 + x_1_2_3_4 * x; + } + else if constexpr (order == 5) + { + const auto x_4_5 = (S) 0.00899009909264 + (S) 0.00187839071291 * x; + const auto x_2_3 = (S) 0.240156326598 + (S) 0.0558229130202 * x; + const auto x_2_3_4_5 = x_2_3 + x_4_5 * x_sq; + const auto x_0_1 = (S) 1 + (S) 0.693152270576 * x; + return x_0_1 + x_2_3_4_5 * x_sq; + } + else if constexpr (order == 6) + { + const auto x_5_6 = (S) 0.00124359387839 + (S) 0.000217187820427 * x; + const auto x_3_4 = (S) 0.0554833098983 + (S) 0.00967911763840 * x; + const auto x_1_2 = (S) 0.693147003658 + (S) 0.240229787107 * x; + const auto x_3_4_5_6 = x_3_4 + x_5_6 * x_sq; + const auto x_1_2_3_4_5_6 = x_1_2 + x_3_4_5_6 * x_sq; + return (S) 1 + x_1_2_3_4_5_6 * x; + } + else if constexpr (order == 7) + { + // doesn't seem to help at single-precision + const auto x_6_7 = (S) 0.000136898688977877 + (S) 0.0000234440812713967 * x; + const auto x_4_5 = (S) 0.00960825566419915 + (S) 0.00135107295099880 * x; + const auto x_2_3 = (S) 0.240226092549669 + (S) 0.0555070350342468 * x; + const auto x_0_1 = (S) 1 + (S) 0.693147201030637 * x; + const auto x_4_5_6_7 = x_4_5 + x_6_7 * x_sq; + const auto x_0_1_2_3 = x_0_1 + x_2_3 * x_sq; + return x_0_1_2_3 + x_4_5_6_7 * x_sq * x_sq; + } + else + { + return {}; + } + } + } + + template + struct BaseE + { + static constexpr auto log2_base = (T) 1.4426950408889634074; + }; + + template + struct Base2 
+ { + static constexpr auto log2_base = (T) 1; + }; + + template + struct Base10 + { + static constexpr auto log2_base = (T) 3.3219280948873623479; + }; +} + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" // these methods require some type-punning +#pragma GCC diagnostic ignored "-Wuninitialized" +#endif + +/** approximation for pow(Base, x) (32-bit) */ +template +constexpr float pow (float x) +{ + x = std::max (-126.0f, Base::log2_base * x); + + const auto xi = (int32_t) x; + const auto l = x < 0.0f ? xi - 1 : xi; + const auto f = x - (float) l; + const auto vi = (l + 127) << 23; + + return bit_cast (vi) * pow_detail::pow2_approx (f); +} + +/** approximation for pow(Base, x) (64-bit) */ +template +constexpr double pow (double x) +{ + x = std::max (-1022.0, Base::log2_base * x); + + const auto xi = (int64_t) x; + const auto l = x < 0.0 ? xi - 1 : xi; + const auto d = x - (double) l; + const auto vi = (l + 1023) << 52; + + return bit_cast (vi) * pow_detail::pow2_approx (d); +} + +#if defined(XSIMD_HPP) +/** approximation for pow(Base, x) (32-bit SIMD) */ +template +xsimd::batch pow (xsimd::batch x) +{ + x = xsimd::max (xsimd::broadcast (-126.0f), Base::log2_base * x); + + const auto xi = xsimd::to_int (x); + const auto l = xsimd::select (xsimd::batch_bool_cast (x < 0.0f), xi - 1, xi); + const auto f = x - xsimd::to_float (l); + const auto vi = (l + 127) << 23; + + return xsimd::bit_cast> (vi) * pow_detail::pow2_approx, order, C1_continuous> (f); +} + +/** approximation for pow(Base, x) (64-bit SIMD) */ +template +xsimd::batch pow (xsimd::batch x) +{ + x = xsimd::max (-1022.0, Base::log2_base * x); + + const auto xi = xsimd::to_int (x); + const auto l = xsimd::select (xsimd::batch_bool_cast (x < 0.0), xi - 1, xi); + const auto d = x - xsimd::to_float (l); + const auto vi = (l + 1023) << 52; + + return xsimd::bit_cast> (vi) * pow_detail::pow2_approx, order, C1_continuous> (d); +} +#endif + +#if defined(__GNUC__) 
+#pragma GCC diagnostic pop // end ignore strict-aliasing warnings +#endif + +/** Approximation of exp(x), using exp(x) = 2^floor(x * log2(e)) * 2^frac(x * log2(e)) */ +template +constexpr T exp (T x) +{ + return pow>, order, C1_continuous> (x); +} + +/** Approximation of exp2(x), using exp2(x) = 2^floor(x) * 2^frac(x) */ +template +constexpr T exp2 (T x) +{ + return pow>, order, C1_continuous> (x); +} + +/** Approximation of exp10(x), using exp10(x) = 2^floor(x * log2(10)) * 2^frac(x * log2(10)) */ +template +constexpr T exp10 (T x) +{ + return pow>, order, C1_continuous> (x); +} + +/** Approximation of exp(x) - 1, using math_approx::exp(x) */ +template +constexpr T expm1 (T x) +{ + return pow>, order, C1_continuous> (x) - (T) 1; +} +} diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/sigmoid_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/sigmoid_approx.hpp new file mode 100644 index 0000000000..68a9ffe499 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/sigmoid_approx.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +namespace sigmoid_detail +{ + // for polynomial derivations, see notebooks/sigmoid_approx.nb + + template + constexpr T sig_poly_9 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_7_9 = (S) 1.50024356624e-6 + (S) 6.92468584642e-9 * x_sq; + const auto y_5_7_9 = (S) 0.000260923534301 + y_7_9 * x_sq; + const auto y_3_5_7_9 = (S) 0.0208320229264 + y_5_7_9 * x_sq; + const auto y_1_3_5_7_9 = (S) 0.5 + y_3_5_7_9 * x_sq; + return x * y_1_3_5_7_9; + } + + template + constexpr T sig_poly_7 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_5_7 = (S) 0.000255174491559 + (S) 1.90805380557e-6 * x_sq; + const auto y_3_5_7 = (S) 0.0208503675870 + y_5_7 * x_sq; + const auto y_1_3_5_7 = (S) 0.5 + y_3_5_7 * x_sq; + return x * y_1_3_5_7; + } + + template + constexpr T sig_poly_5 (T x) + { + using S = scalar_of_t; + const 
auto x_sq = x * x; + const auto y_3_5 = (S) 0.0206108521251 + (S) 0.000307906311109 * x_sq; + const auto y_1_3_5 = (S) 0.5 + y_3_5 * x_sq; + return x * y_1_3_5; + } + + template + constexpr T sig_poly_3 (T x) + { + using S = scalar_of_t; + const auto x_sq = x * x; + const auto y_1_3 = (S) 0.5 + (S) 0.0233402955195 * x_sq; + return x * y_1_3; + } +} // namespace sigmoid_detail + +/** + * Approximation of sigmoid(x) := 1 / (1 + e^-x), + * using sigmoid(x) ≈ (1/2) p(x) / sqrt(p(x)^2 + 1) + (1/2), + * where p(x) is an odd polynomial fit to minimize the maximum relative error. + */ +template +T sigmoid (T x) +{ + static_assert (order % 2 == 1 && order <= 9 && order >= 3, "Order must e an odd number within [3, 9]"); + + T x_poly {}; + if constexpr (order == 9) + x_poly = sigmoid_detail::sig_poly_9 (x); + else if constexpr (order == 7) + x_poly = sigmoid_detail::sig_poly_7 (x); + else if constexpr (order == 5) + x_poly = sigmoid_detail::sig_poly_5 (x); + else if constexpr (order == 3) + x_poly = sigmoid_detail::sig_poly_3 (x); + + using S = scalar_of_t; + return (S) 0.5 * x_poly * rsqrt (x_poly * x_poly + (S) 1) + (S) 0.5; +} + + +/** + * Approximation of sigmoid(x) := 1 / (1 + e^-x), + * using math_approx::exp (x). + * + * So far this has tested slower than the above approximation + * for similar absolute error, but has better relative error + * characteristics. 
+ */ +template +T sigmoid_exp (T x) +{ + return (T) 1 / ((T) 1 + math_approx::exp (-x)); +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/trig_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/trig_approx.hpp new file mode 100644 index 0000000000..8dba67a016 --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/trig_approx.hpp @@ -0,0 +1,256 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +namespace trig_detail +{ + template + constexpr T truncate (T x) + { + return static_cast (static_cast (x)); + } + +#if defined(XSIMD_HPP) + template + xsimd::batch truncate (xsimd::batch x) + { + return xsimd::to_float (xsimd::to_int (x)); + } +#endif + + /** Fast method to wrap a value into the range [-pi, pi] */ + template + constexpr T fast_mod_mpi_pi (T x) + { + using S = scalar_of_t; + constexpr auto pi = static_cast (M_PI); + constexpr auto two_pi = static_cast (2.0 * M_PI); + constexpr auto recip_two_pi = static_cast (1) / two_pi; + + x += pi; + const auto mod = x - two_pi * truncate (x * recip_two_pi); + return select (x >= (T) 0, mod, mod + two_pi) - pi; + } + + /** Fast method to wrap a value into the range [-pi/2, pi/2] */ + template + constexpr T fast_mod_mhalfpi_halfpi (T x) + { + using S = scalar_of_t; + constexpr auto half_pi = static_cast (M_PI) * (S) 0.5; + constexpr auto pi = static_cast (M_PI); + constexpr auto recip_pi = (S) 1 / pi; + + x += half_pi; + const auto mod = x - pi * truncate (x * recip_pi); + return select (x >= (T) 0, mod, mod + pi) - half_pi; + } + + // Polynomials were derived using the method presented in + // https://mooooo.ooo/chebyshev-sine-approximation/ + // and then adapted for various (odd) orders. 
+ + template + constexpr T sin_poly_9 (T x, T x_sq) + { + using S = scalar_of_t; + const auto x_7_9 = (S) -2.49397084313e-6 + (S) 2.00382818811e-8 * x_sq; + const auto x_5_7_9 = (S) 0.000173405228576 + x_7_9 * x_sq; + const auto x_3_5_7_9 = (S) -0.00662075636230 + x_5_7_9 * x_sq; + const auto x_1_3_5_7_9 = (S) 0.101321159036 + x_3_5_7_9 * x_sq; + return x * x_1_3_5_7_9; + } + + template + constexpr T sin_poly_7 (T x, T x_sq) + { + using S = scalar_of_t; + const auto x_5_7 = (S) 0.000170965340046 + (S) -2.09843101304e-6 * x_sq; + const auto x_3_5_7 = (S) -0.00661594021539 + x_5_7 * x_sq; + const auto x_1_3_5_7 = (S) 0.101319673615 + x_3_5_7 * x_sq; + return x * x_1_3_5_7; + } + + template + constexpr T sin_poly_5 (T x, T x_sq) + { + using S = scalar_of_t; + const auto x_3_5 = (S) -0.00650096169550 + (S) 0.000139899314103 * x_sq; + const auto x_1_3_5 = (S) 0.101256629587 + x_3_5 * x_sq; + return x * x_1_3_5; + } +} // namespace trig_detail + +/** Polynomial approximation of sin(x) on the range [-pi, pi] */ +template +constexpr T sin_mpi_pi (T x) +{ + static_assert (order % 2 == 1 && order <= 9 && order >= 5, "Order must be an odd number within [5, 9]"); + + using S = scalar_of_t; + constexpr auto pi = static_cast (M_PI); + constexpr auto pi_sq = pi * pi; + const auto x_sq = x * x; + + T x_poly {}; + if constexpr (order == 9) + x_poly = trig_detail::sin_poly_9 (x, x_sq); + else if constexpr (order == 7) + x_poly = trig_detail::sin_poly_7 (x, x_sq); + else if constexpr (order == 5) + x_poly = trig_detail::sin_poly_5 (x, x_sq); + + return (pi_sq - x_sq) * x_poly; +} + +/** Full range approximation of sin(x) */ +template +constexpr T sin (T x) +{ + return sin_mpi_pi (trig_detail::fast_mod_mpi_pi (x)); +} + +/** + * Polynomial approximation of cos(x) on the range [-pi, pi], + * using a range-shifted approximation of sin(x). 
+ */ +template +constexpr T cos_mpi_pi (T x) +{ + static_assert (order % 2 == 1 && order <= 9 && order >= 5, "Order must be an odd number within [5, 9]"); + + using S = scalar_of_t; + constexpr auto pi = static_cast (M_PI); + constexpr auto pi_sq = pi * pi; + constexpr auto pi_o_2 = pi * (S) 0.5; + + using std::abs; +#if defined(XSIMD_HPP) + using xsimd::abs; +#endif + x = abs (x); /* cos is even, so only |x| is needed from here on */ + + const auto hpmx = pi_o_2 - x; /* "half pi minus x": cos(x) == sin(pi/2 - x) */ + const auto hpmx_sq = hpmx * hpmx; + + T x_poly {}; + if constexpr (order == 9) + x_poly = trig_detail::sin_poly_9 (hpmx, hpmx_sq); + else if constexpr (order == 7) + x_poly = trig_detail::sin_poly_7 (hpmx, hpmx_sq); + else if constexpr (order == 5) + x_poly = trig_detail::sin_poly_5 (hpmx, hpmx_sq); + + return (pi_sq - hpmx_sq) * x_poly; /* NOTE(review): sin_poly_N is defined outside this view; presumably it yields the factor that (pi^2 - t^2) completes into sin(t) - confirm */ +} + +/** Full range approximation of cos(x) */ +template +constexpr T cos (T x) +{ + return cos_mpi_pi (trig_detail::fast_mod_mpi_pi (x)); +} + +/** Polynomial approximation of tan(x) on the range [-pi/4, pi/4] */ +template +constexpr T tan_mquarterpi_quarterpi (T x) +{ + static_assert (order % 2 == 1 && order >= 3 && order <= 15, "Order must be an odd number within [3, 15]"); + + // for polynomial derivation, see notebooks/tan_approx.nb + + using S = scalar_of_t; + const auto x_sq = x * x; /* every branch below returns x * P(x_sq); a name like x_5_7 marks the partial sum that supplies the x^5 and x^7 terms */ + if constexpr (order == 3) + { + const auto x_1_3 = (S) 1 + (S) 0.442959265447 * x_sq; + return x * x_1_3; + } + else if constexpr (order == 5) + { + const auto x_3_5 = (S) 0.317574684334 + (S) 0.203265826702 * x_sq; + const auto x_1_3_5 = (S) 1 + x_3_5 * x_sq; + return x * x_1_3_5; + } + else if constexpr (order == 7) + { + const auto x_5_7 = (S) 0.116406244996 + (S) 0.0944480566104 * x_sq; + const auto x_1_3 = (S) 1 + (S) 0.335216153138 * x_sq; + const auto x_1_3_5_7 = x_1_3 + x_5_7 * x_sq * x_sq; + return x * x_1_3_5_7; + } + else if constexpr (order == 9) + { + const auto x_7_9 = (S) 0.0405232529373 + (S) 0.0439292071029 * x_sq; + const auto x_3_5 = (S) 0.333131667276 + (S) 0.136333765649 * x_sq; + const auto x_3_5_7_9 = x_3_5 
+ x_7_9 * x_sq * x_sq; + return x * ((S) 1 + x_3_5_7_9 * x_sq); + } + else if constexpr (order == 11) + { + const auto x_q = x_sq * x_sq; + const auto x_9_11 = (S) 0.0126603694551 + (S) 0.0203633469693 * x_sq; + const auto x_5_7 = (S) 0.132897195017 + (S) 0.0570525279731 * x_sq; + const auto x_1_3 = (S) 1 + (S) 0.333353019629 * x_sq; + const auto x_5_7_9_11 = x_5_7 + x_9_11 * x_q; + const auto x_1_3_5_7_9_11 = x_1_3 + x_5_7_9_11 * x_q; + return x * x_1_3_5_7_9_11; + } + else if constexpr (order == 13) + { + const auto x_q = x_sq * x_sq; + const auto x_6 = x_q * x_sq; + const auto x_11_13 = (S) 0.00343732283737 + (S) 0.00921082294855 * x_sq; + const auto x_7_9 = (S) 0.0534743904687 + (S) 0.0242183751709 * x_sq; + const auto x_3_5 = (S) 0.333331890901 + (S) 0.133379954680 * x_sq; + const auto x_7_9_11_13 = x_7_9 + x_11_13 * x_q; + const auto x_1_3_5 = (S) 1 + x_3_5 * x_sq; + return x * (x_1_3_5 + x_7_9_11_13 * x_6); + } + else if constexpr (order == 15) + { + // doesn't seem to help much at single-precision, but here it is: + const auto x_q = x_sq * x_sq; + const auto x_8 = x_q * x_q; + const auto x_13_15 = (S) 0.000292958045126 + (S) 0.00427933470414 * x_sq; + const auto x_9_11 = (S) 0.0213477960960 + (S) 0.0106702896251 * x_sq; + const auto x_5_7 = (S) 0.133327796402 + (S) 0.0540469276103* x_sq; + const auto x_1_3 = (S) 1 + (S) 0.333333463757 * x_sq; + const auto x_9_11_13_15 = x_9_11 + x_13_15 * x_q; + const auto x_1_3_5_7 = x_1_3 + x_5_7 * x_q; + const auto x_1_3_5_7_9_11_13_15 = x_1_3_5_7 + x_9_11_13_15 * x_8; + return x * x_1_3_5_7_9_11_13_15; + } + else + { + return {}; /* not reachable: the static_assert at the top only admits the orders handled above */ + } +} + +/** + * Approximation of tan(x) on the range [-pi/2, pi/2], + * using the tangent half-angle formula. + * + * Accuracy may suffer as x approaches ±pi/2. 
+ */ +template +constexpr T tan_mhalfpi_halfpi (T x) +{ + using S = scalar_of_t; + const auto h_x = tan_mquarterpi_quarterpi ((S) 0.5 * x); /* h_x approximates tan(x / 2) */ + return (S) 2 * h_x / ((S) 1 - h_x * h_x); /* half-angle identity: tan(x) = 2 t / (1 - t^2) with t = tan(x / 2) */ +} + +/** + * Full-range approximation of tan(x) + * + * Accuracy may suffer as x approaches values for which tan(x) approaches ±Inf. + */ +template +constexpr T tan (T x) +{ + return tan_mhalfpi_halfpi (trig_detail::fast_mod_mhalfpi_halfpi (x)); +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/RTNeural/modules/math_approx/src/wright_omega_approx.hpp b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/wright_omega_approx.hpp new file mode 100644 index 0000000000..eecb199c2b --- /dev/null +++ b/hi_tools/hi_neural/RTNeural/modules/math_approx/src/wright_omega_approx.hpp @@ -0,0 +1,94 @@ +#pragma once + +#include "basic_math.hpp" + +namespace math_approx +{ +/** + * Approximation of the Wright-Omega function, using + * w(x) ≈ 0 for x < -3 + * w(x) ≈ p(x) for -3 <= x < e + * w(x) ≈ x - log(x) + alpha * exp(-beta * x) for x >= e, + * where p(x) is a polynomial, and alpha and beta are coefficients, + * all fit to minimize the maximum absolute error. + * + * The above fit is optionally followed by some number of Newton-Raphson iterations. 
+ */ +template +constexpr T wright_omega (T x) +{ + static_assert (poly_order == 3 || poly_order == 5); + + using S = scalar_of_t; + constexpr auto E = (S) 2.7182818284590452354; /* Euler's number e: the threshold between the x1 and x2 segments below */ + + const auto x1 = [] (T _x) + { + const auto x_sq = _x * _x; + if constexpr (poly_order == 3) + { + const auto y_2_3 = (S) 0.0534379648805832 + (S) -0.00251076420630778 * _x; + const auto y_0_1 = (S) 0.616522951065868 + (S) 0.388418422853809 * _x; + return y_0_1 + y_2_3 * x_sq; + } + else if constexpr (poly_order == 5) + { + const auto y_4_5 = (S) -0.00156418794118294 + (S) -0.00151562297325209 * _x; + const auto y_2_3 = (S) 0.0719291313363515 + (S) 0.0216881206167543 * _x; + const auto y_0_1 = (S) 0.569291529016010 + (S) 0.290890537885083 * _x; + const auto y_2_3_4_5 = y_2_3 + y_4_5 * x_sq; + return y_0_1 + y_2_3_4_5 * x_sq; + } + else + { + return T {}; + } + }(x); + const auto x2 = x - log (x) + (S) 0.32352057096397160124 * exp ((S) -0.029614177658043381316 * x); /* large-x form: x - log(x) + alpha * exp(-beta * x) */ + + auto y = select (x < (S) -3, T {}, select (x < (S) E, x1, x2)); /* presumably select(c, a, b) yields a where c holds: 0 below -3, x1 up to e, x2 from e upwards - confirm against select's definition */ + + const auto nr_update = [] (T _x, T _y) + { + return _y - (_y - exp (_x - _y)) / (_y + (S) 1); /* one refinement step; the loop below applies it num_nr_iters times */ + }; + + for (int i = 0; i < num_nr_iters; ++i) + y = nr_update (x, y); + + return y; +} + +/** + * Wright-Omega function using Stefano D'Angelo's derivation (https://www.dafx.de/paper-archive/2019/DAFx2019_paper_5.pdf) + * With `num_nr_iters == 0`, this is the fastest implementation, but the least accurate. + * With `num_nr_iters == 1`, this is faster than the other implementation with 0 iterations, and a little bit more accurate. + * For more accuracy, use the other implementation with at least 1 NR iteration. 
+ */ +template +constexpr T wright_omega_dangelo (T x) +{ + using S = scalar_of_t; + + const auto x1 = [] (T _x) + { + const auto x_sq = _x * _x; + const auto y_2_3 = (S) 4.775931364975583e-2 + (S) -1.314293149877800e-3 * _x; + const auto y_0_1 = (S) 6.313183464296682e-1 + (S) 3.631952663804445e-1 * _x; + return y_0_1 + y_2_3 * x_sq; + }(x); + const auto x2 = x - log (x); + + auto y = select (x < (S) -3.341459552768620, T {}, select (x < (S) 8, x1, x2)); + + const auto nr_update = [] (T _x, T _y) + { + return _y - (_y - exp (_x - _y)) / (_y + (S) 1); + }; + + for (int i = 0; i < num_nr_iters; ++i) + y = nr_update (x, y); + + return y; +} +} // namespace math_approx diff --git a/hi_tools/hi_neural/hi_neural.cpp b/hi_tools/hi_neural/hi_neural.cpp index 4c36a60371..695728fc0d 100644 --- a/hi_tools/hi_neural/hi_neural.cpp +++ b/hi_tools/hi_neural/hi_neural.cpp @@ -1,9 +1,11 @@ #define RTNEURAL_DEFAULT_ALIGNMENT 16 #define RTNEURAL_USE_XSIMD 1 #include "hi_neural.h" #include "RTNeural/RTNeural/RTNeural.h" +#include "RTNeural/modules/math_approx/math_approx.hpp" +#include "RTNeural/RTNeural/wavenet/wavenet_model.hpp" namespace hise { @@ -488,6 +490,115 @@ Result NeuralNetwork::loadPytorchModel(const var& fullJson) return loadWeights(weightData); } + + +struct NAMModel: public NeuralNetwork::ModelBase +{ + using Dilations = wavenet::Dilations<1, 2, 4, 8, 16, 32, 64, 128, 256, 512>; /* declared once only: a member alias must not be redeclared inside the class body */ + + struct NAMMathsProvider + { + #if RTNEURAL_USE_EIGEN + template + static auto tanh (const Matrix& x) + { + // See: math_approx::tanh<3> + const auto x_poly = x.array() * (1.0f + 0.183428244899f * x.array().square()); + return x_poly.array() * (x_poly.array().square() + 1.0f).array().rsqrt(); + } + #elif RTNEURAL_USE_XSIMD + template + static T tanh (const T& x) + { + return math_approx::tanh<3> (x); + } + #endif + }; + + /** Keeps a copy of the model JSON, which clone() uses to rebuild the model, + * and loads the weights from its serialised text. */ + NAMModel(const var& data_): + ModelBase(), + jsonData(data_) + { + 
const auto loadResult = loadWeights(JSON::toString(jsonData, true)); + if (loadResult.failed()) throw loadResult; /* a failed load aborts construction; NeuralNetwork::loadNAMModel() catches the Result */ + }; + + void reset() final + { + obj.prewarm(); + }; + + void process(const float* input, float* output) final + { + *output = obj.forward(*input); + }; + + int getNumInputs() const final + { + return 1; + }; + + int getNumOutputs() const final + { + return 1; + }; + + ModelBase* clone() final + { + return new NAMModel(jsonData); + }; + /** Parses the NAM JSON text and loads it into the model; parser and loader exceptions come back as a failed Result. */ + Result loadWeights(const String& jsonData) final + { + try + { + auto j = nlohmann::json::parse(jsonData.toStdString()); + obj.load_weights(j); + } + catch(const std::exception& e) + { + return Result::fail(e.what()); + } + + return Result::ok(); + }; + + wavenet::Wavenet_Model, + wavenet::Layer_Array> + + obj; + + var jsonData; +}; +/** Builds one NAMModel per network from jsonData and swaps the set in under the write lock; a Result thrown while building is returned to the caller. */ +Result NeuralNetwork::loadNAMModel(const var& jsonData) +{ + OwnedArray nm; + + try + { + nm.add(new NAMModel(jsonData)); + + for(int i = 1; i < getNumNetworks(); i++) + nm.add(nm.getFirst()->clone()); + } + catch(Result& r) + { + return r; + } + + { + SimpleReadWriteLock::ScopedMultiWriteLock sl(lock); + currentModels.swapWith(nm); + } + + return Result::ok(); +} + void NeuralNetwork::clearModel() { OwnedArray nm; diff --git a/hi_tools/hi_neural/hi_neural.h b/hi_tools/hi_neural/hi_neural.h index e7d8e79543..879bed2649 100644 --- a/hi_tools/hi_neural/hi_neural.h +++ b/hi_tools/hi_neural/hi_neural.h @@ -173,6 +173,9 @@ struct NeuralNetwork: public ReferenceCountedObject, /** Loads a model with trained weights from Pytorch. */ Result loadPytorchModel(const var& jsonData); + /** Loads a model from a NAM file. */ + Result loadNAMModel(const var& jsonData); + /** Build a model from the JSON layout. 
*/ Result build(const var& modelJSON); diff --git a/tools/onnx_lib/Source/Main.cpp b/tools/onnx_lib/Source/Main.cpp index 22cac4a184..94cecb3296 100644 --- a/tools/onnx_lib/Source/Main.cpp +++ b/tools/onnx_lib/Source/Main.cpp @@ -13,11 +13,7 @@ #define DLL_EXPORT extern "C" __attribute__((visibility("default"))) #endif -#if JUCE_DEBUG #include "include/onnxruntime_cxx_api.h" -#else -#include "include_rel/onnxruntime_cxx_api.h" -#endif using namespace juce; diff --git a/tools/onnx_lib/Source/include_rel/cpu_provider_factory.h b/tools/onnx_lib/Source/include_rel/cpu_provider_factory.h deleted file mode 100644 index 292678692b..0000000000 --- a/tools/onnx_lib/Source/include_rel/cpu_provider_factory.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * \param use_arena zero: false. non-zero: true. - */ -ORT_EXPORT -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) -ORT_ALL_ARGS_NONNULL; - -#ifdef __cplusplus -} -#endif diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_c_api.h b/tools/onnx_lib/Source/include_rel/onnxruntime_c_api.h deleted file mode 100644 index 5aafdd149e..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_c_api.h +++ /dev/null @@ -1,4832 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// See docs\c_cxx\README.md on generating the Doxygen documentation from this file - -/** \mainpage ONNX Runtime - * - * ONNX Runtime is a high-performance inference and training graph execution engine for deep learning models. - * - * ONNX Runtime's C, C++ APIs offer an easy to use interface to onboard and execute onnx models. 
- * - \subpage c_cpp_api "Core C, C++ APIs" - * - \subpage training_c_cpp_api "Training C, C++ APIs for on-device training" - * - * \page c_cpp_api Core C, C++ APIs - *

C

- * - * ::OrtApi - Click here to go to the structure with all C API functions. - * - *

C++

- * - * ::Ort - Click here to go to the namespace holding all of the C++ wrapper classes - * - * It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to - * automate memory management through standard C++ RAII principles. - * - * \addtogroup Global - * ONNX Runtime C API - * @{ - */ - -#pragma once -#include -#include -#include -#include - -/** \brief The API version defined in this header - * - * This value is used by some API functions to behave as this version of the header expects. - */ -#define ORT_API_VERSION 19 - -#ifdef __cplusplus -extern "C" { -#endif - -//! @} -// SAL2 Definitions -#ifndef _WIN32 -#define _In_ -#define _In_z_ -#define _In_opt_ -#define _In_opt_z_ -#define _Out_ -#define _Outptr_ -#define _Out_opt_ -#define _Inout_ -#define _Inout_opt_ -#define _Frees_ptr_opt_ -#define _Ret_maybenull_ -#define _Ret_notnull_ -#define _Check_return_ -#define _Outptr_result_maybenull_ -#define _In_reads_(X) -#define _Inout_updates_(X) -#define _Out_writes_(X) -#define _Inout_updates_all_(X) -#define _Out_writes_bytes_all_(X) -#define _Out_writes_all_(X) -#define _Success_(X) -#define _Outptr_result_buffer_maybenull_(X) -#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) -#else -#include -#define ORT_ALL_ARGS_NONNULL -#endif - -#ifdef _WIN32 -// Define ORT_DLL_IMPORT if your program is dynamically linked to Ort. -// dllexport is not used, we use a .def file. 
-#ifdef ORT_DLL_IMPORT -#define ORT_EXPORT __declspec(dllimport) -#else -#define ORT_EXPORT -#endif -#define ORT_API_CALL _stdcall -#define ORT_MUST_USE_RESULT -#define ORTCHAR_T wchar_t -#else -// To make symbols visible on macOS/iOS -#ifdef __APPLE__ -#define ORT_EXPORT __attribute__((visibility("default"))) -#else -#define ORT_EXPORT -#endif -#define ORT_API_CALL -#define ORT_MUST_USE_RESULT __attribute__((warn_unused_result)) -#define ORTCHAR_T char -#endif - -/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. -/// All other strings are UTF-8 encoded, use char and std::string -#ifndef ORT_TSTR -#ifdef _WIN32 -#define ORT_TSTR(X) L##X -// When X is a macro, L##X is not defined. In this case, we need to use ORT_TSTR_ON_MACRO. -#define ORT_TSTR_ON_MACRO(X) L"" X -#else -#define ORT_TSTR(X) X -#define ORT_TSTR_ON_MACRO(X) X -#endif -#endif - -// On Windows, ORT_FILE is a wchar_t version of the __FILE__ macro. -// Otherwise, ORT_FILE is equivalent to __FILE__. -#ifndef ORT_FILE -#define ORT_FILE_INTERNAL(x) ORT_TSTR(x) -#define ORT_FILE ORT_FILE_INTERNAL(__FILE__) -#endif - -// Any pointer marked with _In_ or _Out_, cannot be NULL. - -// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation -// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that. -// for ReleaseXXX(...) functions, they can accept NULL pointer. - -#ifdef __cplusplus -// For any compiler with C++11 support, MSVC 2015 and greater, or Clang version supporting noexcept. -// Such complex condition is needed because compilers set __cplusplus value differently. 
-#ifndef __has_feature -#define __has_feature(x) 0 -#endif -#if ((__cplusplus >= 201103L) || (_MSC_VER >= 1900) || (defined(__has_feature) && __has_feature(cxx_noexcept))) -#define NO_EXCEPTION noexcept -#else -#define NO_EXCEPTION throw() -#endif -#else -#define NO_EXCEPTION -#endif - -// __VA_ARGS__ on Windows and Linux are different -#define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION - -#define ORT_API_STATUS(NAME, ...) \ - _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ - NO_EXCEPTION ORT_MUST_USE_RESULT - -// XXX: Unfortunately, SAL annotations are known to not work with function pointers -#define ORT_API2_STATUS(NAME, ...) \ - _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT - -// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT -#define ORT_API_STATUS_IMPL(NAME, ...) \ - _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION - -#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) - -#ifdef __DOXYGEN__ -#undef ORT_API_STATUS -#define ORT_API_STATUS(NAME, ...) OrtStatus* NAME(__VA_ARGS__) -#undef ORT_API2_STATUS -#define ORT_API2_STATUS(NAME, ...) 
OrtStatus* NAME(__VA_ARGS__) -#undef ORT_CLASS_RELEASE -#define ORT_CLASS_RELEASE(X) void Release##X(Ort##X* input) -#undef NO_EXCEPTION -#define NO_EXCEPTION -#endif -/** \addtogroup Global - * ONNX Runtime C API - * @{ - */ - -/** Copied from TensorProto::DataType - * Currently, Ort doesn't support complex64, complex128 - */ -typedef enum ONNXTensorElementDataType { - ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string - ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, - ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, // Non-IEEE floating-point format based on IEEE754 single-precision - // float 8 types were introduced in onnx 1.14, see https://onnx.ai/onnx/technical/float8.html - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision - 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - // Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) -} ONNXTensorElementDataType; - -// Synced with onnx TypeProto oneof -typedef enum ONNXType { - ONNX_TYPE_UNKNOWN, - ONNX_TYPE_TENSOR, - ONNX_TYPE_SEQUENCE, - ONNX_TYPE_MAP, - ONNX_TYPE_OPAQUE, - ONNX_TYPE_SPARSETENSOR, - ONNX_TYPE_OPTIONAL -} ONNXType; - -// These types are synced with internal -// SparseFormatFlags -typedef enum OrtSparseFormat { - ORT_SPARSE_UNDEFINED = 0, - ORT_SPARSE_COO = 0x1, - ORT_SPARSE_CSRC = 0x2, - ORT_SPARSE_BLOCK_SPARSE = 0x4 -} OrtSparseFormat; - -// Enum allows to query sparse tensor indices -enum OrtSparseIndicesFormat { - ORT_SPARSE_COO_INDICES, - ORT_SPARSE_CSR_INNER_INDICES, - ORT_SPARSE_CSR_OUTER_INDICES, - ORT_SPARSE_BLOCK_SPARSE_INDICES -}; - -/** \brief Logging severity levels - * - * In typical API usage, specifying a logging severity level specifies the minimum severity of log messages to show. - */ -typedef enum OrtLoggingLevel { - ORT_LOGGING_LEVEL_VERBOSE, ///< Verbose informational messages (least severe). - ORT_LOGGING_LEVEL_INFO, ///< Informational messages. - ORT_LOGGING_LEVEL_WARNING, ///< Warning messages. - ORT_LOGGING_LEVEL_ERROR, ///< Error messages. - ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe). 
-} OrtLoggingLevel; - -typedef enum OrtErrorCode { - ORT_OK, - ORT_FAIL, - ORT_INVALID_ARGUMENT, - ORT_NO_SUCHFILE, - ORT_NO_MODEL, - ORT_ENGINE_ERROR, - ORT_RUNTIME_EXCEPTION, - ORT_INVALID_PROTOBUF, - ORT_MODEL_LOADED, - ORT_NOT_IMPLEMENTED, - ORT_INVALID_GRAPH, - ORT_EP_FAIL, -} OrtErrorCode; - -typedef enum OrtOpAttrType { - ORT_OP_ATTR_UNDEFINED = 0, - ORT_OP_ATTR_INT, - ORT_OP_ATTR_INTS, - ORT_OP_ATTR_FLOAT, - ORT_OP_ATTR_FLOATS, - ORT_OP_ATTR_STRING, - ORT_OP_ATTR_STRINGS, -} OrtOpAttrType; - -//! @} -#define ORT_RUNTIME_CLASS(X) \ - struct Ort##X; \ - typedef struct Ort##X Ort##X - -/** \addtogroup Global - * ONNX Runtime C API - * @{ - */ -// The actual types defined have an Ort prefix -ORT_RUNTIME_CLASS(Env); -ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success -ORT_RUNTIME_CLASS(MemoryInfo); -ORT_RUNTIME_CLASS(IoBinding); -ORT_RUNTIME_CLASS(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) -ORT_RUNTIME_CLASS(Value); -ORT_RUNTIME_CLASS(RunOptions); -ORT_RUNTIME_CLASS(TypeInfo); -ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); -ORT_RUNTIME_CLASS(MapTypeInfo); -ORT_RUNTIME_CLASS(SequenceTypeInfo); -ORT_RUNTIME_CLASS(OptionalTypeInfo); -ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(CustomOpDomain); -ORT_RUNTIME_CLASS(ModelMetadata); -ORT_RUNTIME_CLASS(ThreadPoolParams); -ORT_RUNTIME_CLASS(ThreadingOptions); -ORT_RUNTIME_CLASS(ArenaCfg); -ORT_RUNTIME_CLASS(PrepackedWeightsContainer); -ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2); -ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); -ORT_RUNTIME_CLASS(CANNProviderOptions); -ORT_RUNTIME_CLASS(DnnlProviderOptions); -ORT_RUNTIME_CLASS(Op); -ORT_RUNTIME_CLASS(OpAttr); -ORT_RUNTIME_CLASS(Logger); -ORT_RUNTIME_CLASS(ShapeInferContext); - -#ifdef _WIN32 -typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; -#else -typedef OrtStatus* OrtStatusPtr; -#endif - -/** \brief Memory allocation interface - * - * Structure of function pointers that defines a memory 
allocator. This can be created and filled in by the user for custom allocators. - * - * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. - */ -typedef struct OrtAllocator { - uint32_t version; ///< Must be initialized to ORT_API_VERSION - void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes - void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc - const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator - /** - * @brief Optional allocation function to use for memory allocations made during session initialization. - * Use this function if you want to separate allocations made by ORT during Run() calls from - * those made during session initialization. This allows for separate memory management strategies for these allocations. - */ - void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes -} OrtAllocator; - -typedef void(ORT_API_CALL* OrtLoggingFunction)( - void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, - const char* message); - -/** \brief Graph optimization level - * - * Refer to https://www.onnxruntime.ai/docs/performance/graph-optimizations.html#graph-optimization-levels - * for an in-depth understanding of the Graph Optimization Levels. 
- */ -typedef enum GraphOptimizationLevel { - ORT_DISABLE_ALL = 0, - ORT_ENABLE_BASIC = 1, - ORT_ENABLE_EXTENDED = 2, - ORT_ENABLE_ALL = 99 -} GraphOptimizationLevel; - -typedef enum ExecutionMode { - ORT_SEQUENTIAL = 0, - ORT_PARALLEL = 1, -} ExecutionMode; - -/** \brief Language projection identifiers - * /see OrtApi::SetLanguageProjection - */ -typedef enum OrtLanguageProjection { - ORT_PROJECTION_C = 0, - ORT_PROJECTION_CPLUSPLUS = 1, - ORT_PROJECTION_CSHARP = 2, - ORT_PROJECTION_PYTHON = 3, - ORT_PROJECTION_JAVA = 4, - ORT_PROJECTION_WINML = 5, - ORT_PROJECTION_NODEJS = 6, -} OrtLanguageProjection; - -struct OrtKernelInfo; -typedef struct OrtKernelInfo OrtKernelInfo; -struct OrtKernelContext; -typedef struct OrtKernelContext OrtKernelContext; -struct OrtCustomOp; -typedef struct OrtCustomOp OrtCustomOp; - -typedef enum OrtAllocatorType { - OrtInvalidAllocator = -1, - OrtDeviceAllocator = 0, - OrtArenaAllocator = 1 -} OrtAllocatorType; - -/** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. - */ -// Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc -typedef enum OrtMemType { - OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider - OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. 
CUDA_PINNED - OrtMemTypeDefault = 0, ///< The default allocator for execution provider -} OrtMemType; - -/** \brief This mimics OrtDevice type constants so they can be returned in the API - */ -typedef enum OrtMemoryInfoDeviceType { - OrtMemoryInfoDeviceType_CPU = 0, - OrtMemoryInfoDeviceType_GPU = 1, - OrtMemoryInfoDeviceType_FPGA = 2 -} OrtMemoryInfoDeviceType; - -/** \brief Algorithm to use for cuDNN Convolution Op - */ -typedef enum OrtCudnnConvAlgoSearch { - OrtCudnnConvAlgoSearchExhaustive, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx - OrtCudnnConvAlgoSearchHeuristic, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 - OrtCudnnConvAlgoSearchDefault, // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM -} OrtCudnnConvAlgoSearch; - -/** \brief CUDA Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA - */ -typedef struct OrtCUDAProviderOptions { -#ifdef __cplusplus - OrtCUDAProviderOptions() - : device_id{}, - cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, - gpu_mem_limit{SIZE_MAX}, - arena_extend_strategy{}, - do_copy_in_default_stream{1}, - has_user_compute_stream{}, - user_compute_stream{}, - default_memory_arena_cfg{}, - tunable_op_enable{false}, - tunable_op_tuning_enable{false}, - tunable_op_max_tuning_duration_ms{} {} -#endif - - /** \brief CUDA device Id - * Defaults to 0. - */ - int device_id; - - /** \brief CUDA Convolution algorithm search configuration. - * See enum OrtCudnnConvAlgoSearch for more details. - * Defaults to OrtCudnnConvAlgoSearchExhaustive. - */ - OrtCudnnConvAlgoSearch cudnn_conv_algo_search; - - /** \brief CUDA memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - size_t gpu_mem_limit; - - /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - int arena_extend_strategy; - - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ - int do_copy_in_default_stream; - - /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ - int has_user_compute_stream; - - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ - void* user_compute_stream; - - /** \brief CUDA memory arena configuration parameters - */ - OrtArenaCfg* default_memory_arena_cfg; - - /** \brief Enable TunableOp for using. - * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. - */ - int tunable_op_enable; - - /** \brief Enable TunableOp for tuning. - * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. - */ - int tunable_op_tuning_enable; - - /** \brief Max tuning duration time limit for each instance of TunableOp. - * Defaults to 0 to disable the limit. 
- */ - int tunable_op_max_tuning_duration_ms; - -} OrtCUDAProviderOptions; - -/** \brief ROCM Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM - */ -typedef struct OrtROCMProviderOptions { -#ifdef __cplusplus - OrtROCMProviderOptions() - : device_id{}, - miopen_conv_exhaustive_search{0}, - gpu_mem_limit{SIZE_MAX}, - arena_extend_strategy{}, - do_copy_in_default_stream{1}, - has_user_compute_stream{}, - user_compute_stream{}, - default_memory_arena_cfg{}, - enable_hip_graph{false}, - tunable_op_enable{false}, - tunable_op_tuning_enable{false}, - tunable_op_max_tuning_duration_ms{} {} -#endif - - /** \brief ROCM device Id - * Defaults to 0. - */ - int device_id; - - /** \brief ROCM MIOpen Convolution algorithm exaustive search option. - * Defaults to 0 (false). - */ - int miopen_conv_exhaustive_search; - - /** \brief ROCM memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - size_t gpu_mem_limit; - - /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - int arena_extend_strategy; - - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ - int do_copy_in_default_stream; - - /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ - int has_user_compute_stream; - - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ - void* user_compute_stream; - - /** \brief ROCM memory arena configuration parameters - */ - OrtArenaCfg* default_memory_arena_cfg; - - int enable_hip_graph; - - /** \brief Enable TunableOp for using. - * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. - */ - int tunable_op_enable; - - /** \brief Enable TunableOp for tuning. - * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. - */ - int tunable_op_tuning_enable; - - /** \brief Max tuning duration time limit for each instance of TunableOp. - * Defaults to 0 to disable the limit. - */ - int tunable_op_max_tuning_duration_ms; - -} OrtROCMProviderOptions; - -/** \brief TensorRT Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT - */ -typedef struct OrtTensorRTProviderOptions { - int device_id; ///< CUDA device id (0 = default device) - int has_user_compute_stream; // indicator of user specified CUDA compute stream. - void* user_compute_stream; // user specified CUDA compute stream. 
- int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability - int trt_min_subgraph_size; // minimum size of TensorRT subgraphs - size_t trt_max_workspace_size; // maximum workspace size for TensorRT. - int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true - int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true - const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. - int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true - int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true - int trt_dla_core; // DLA core number. Default 0 - int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true - int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path; // specify engine cache path - int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true - const char* trt_engine_decryption_lib_path; // specify engine decryption library path - int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true - // This is the legacy struct and don't add new fields here. - // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h - // For non-string field, need to create a new separate api to handle it. -} OrtTensorRTProviderOptions; - -/** \brief MIGraphX Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - */ -typedef struct OrtMIGraphXProviderOptions { - int device_id; // hip device id. - int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true - int migraphx_int8_enable; // MIGraphX INT8 precision. 
Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true - const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true - const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true - const char* migraphx_load_model_path; // migraphx model path name -} OrtMIGraphXProviderOptions; - -/** \brief OpenVINO Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO - */ -typedef struct OrtOpenVINOProviderOptions { -#ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, - enable_npu_fast_compile{}, - device_id{}, - num_of_threads{}, - cache_dir{}, - context{}, - enable_opencl_throttling{}, - enable_dynamic_shapes{} {} -#endif - /** \brief Device type string - * - * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" - */ - const char* device_type; - unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled - const char* device_id; - size_t num_of_threads; ///< 0 = Use default number of threads - const char* cache_dir; // path is set to empty by default - void* context; - unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled - unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled -} OrtOpenVINOProviderOptions; - -struct OrtApi; -typedef struct OrtApi OrtApi; - -struct OrtTrainingApi; -typedef struct OrtTrainingApi OrtTrainingApi; - -/** \brief The helper interface to get the right version of OrtApi - * - * Get a pointer to this structure through ::OrtGetApiBase - */ -struct OrtApiBase { - /** \brief Get a pointer to the requested version of the ::OrtApi - * - * \param[in] version Must be ::ORT_API_VERSION - * \return The ::OrtApi for the 
version requested, nullptr will be returned if this version is unsupported, for example when using a runtime - * older than the version created with this header file. - * - * One can call GetVersionString() to get the version of the Onnxruntime library for logging - * and error reporting purposes. - */ - const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; - - /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") - * - * \return UTF-8 encoded version string. Do not deallocate the returned buffer. - */ - const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; -}; - -typedef struct OrtApiBase OrtApiBase; - -/** \brief The Onnxruntime library's entry point to access the C API - * - * Call this to get the a pointer to an ::OrtApiBase - */ -ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; - -/** \brief Thread work loop function - * - * Onnxruntime will provide the working loop on custom thread creation - * Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn - */ -typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param); - -typedef const struct OrtCustomHandleType { - char __place_holder; -}* OrtCustomThreadHandle; - -/** \brief Ort custom thread creation function - * - * The function should return a thread handle to be used in onnxruntime thread pools - * Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread - */ -typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param); - -/** \brief Custom thread join function - * - * Onnxruntime thread pool destructor will call the function to join a custom thread. 
- * Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn - */ -typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle); - -typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api); - -/** \brief Callback function for RunAsync - * - * \param[in] user_data User specific data that passed back to the callback - * \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr - * \param[out] num_outputs Number of outputs, on error, the value will be zero - * \param[out] status On error, status will provide details - */ -typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); - -/** \brief The C API - * - * All C API functions are defined inside this structure as pointers to functions. - * Call OrtApiBase::GetApi to get a pointer to it - * - * \nosubgrouping - */ -struct OrtApi { - /// \name OrtStatus - /// @{ - - /** - * \brief Create an OrtStatus from a null terminated string - * - * \param[in] code - * \param[in] msg A null-terminated string. Its contents will be copied. - * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus - */ - OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Get OrtErrorCode from OrtStatus - * - * \param[in] status - * \return OrtErrorCode that \p status was created with - */ - OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Get error string from OrtStatus - * - * \param[in] status - * \return The error message inside the `status`. Do not free the returned value. 
- */ - const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an OrtEnv - * - * \note Invoking this function will return the same instance of the environment as that returned by a previous call - * to another env creation function; all arguments to this function will be ignored. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnv, OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); - - /** \brief Create an OrtEnv - * - * \note Invoking this function will return the same instance of the environment as that returned by a previous call - * to another env creation function; all arguments to this function will be ignored. If you want to provide your - * own logging function, consider setting it using the SetUserLoggingFunction API instead. - * \param[in] logging_function A pointer to a logging function. - * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. This parameter is optional. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); - - /** \brief Enable Telemetry - * - * \note Telemetry events are on by default since they are lightweight - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); - /** \brief Disable Telemetry - * - * \see OrtApi::EnableTelemetryEvents - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Create an OrtSession from a model file - * - * \param[in] env - * \param[in] model_path - * \param[in] options - * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - // TODO: document the path separator convention? '/' vs '\' - // TODO: should specify the access characteristics of model_path. Is this read only during the - // execution of CreateSession, or does the OrtSession retain a handle to the file/directory - // and continue to access throughout the OrtSession lifetime? - // What sort of access is needed to model_path : read or read/write? - ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession from memory - * - * \param[in] env - * \param[in] model_data - * \param[in] model_data_length - * \param[in] options - * \param[out] out Returned newly created OrtSession. 
Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Run the model in an ::OrtSession - * - * Will not return until the model run has completed. Multiple threads might be used to run the model based on - * the options in the ::OrtSession and settings used when creating the ::OrtEnv - * - * \param[in] session - * \param[in] run_options If nullptr, will use a default ::OrtRunOptions - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] inputs Array of ::OrtValue%s of the input values - * \param[in] input_len Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be - * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers - * to them will be set into the `outputs` array. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(Run, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, - _In_reads_(input_len) const char* const* input_names, - _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _Inout_updates_all_(output_names_len) OrtValue** outputs); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Create an ::OrtSessionOptions object - * - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session:
- * OrtSessionOptionsAppendExecutionProvider_CPU
- * OrtSessionOptionsAppendExecutionProvider_CUDA
- * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)
- * The order they are called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. - * - * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); - - /** \brief Set filepath to save optimized model after graph level transformations - * - * \param[in] options - * \param[in] optimized_model_filepath - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, - _In_ const ORTCHAR_T* optimized_model_filepath); - - /** \brief Create a copy of an existing ::OrtSessionOptions - * - * \param[in] in_options OrtSessionOptions to copy - * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, - _Outptr_ OrtSessionOptions** out_options); - - /** \brief Set execution mode - * - * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model - * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. - * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. 
- * - * \param[in] options - * \param[in] execution_mode - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); - - /** \brief Enable profiling for a session - * - * \param[in] options - * \param[in] profile_file_prefix - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); - - /** \brief Disable profiling for a session - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); - - /** \brief Enable the memory pattern optimization - * - * The idea is if the input shapes are the same, we could trace the internal memory allocation - * and generate a memory pattern for future request. So next time we could just do one allocation - * with a big chunk for all the internal memory allocation. - * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) - * - * \see OrtApi::DisableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); - - /** \brief Disable the memory pattern optimization - * - * \see OrtApi::EnableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); - - /** \brief Enable the memory arena on CPU - * - * Arena may pre-allocate memory for future usage. 
- * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); - - /** \brief Disable the memory arena on CPU - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); - - /** \brief Set session log id - * - * \param[in] options - * \param[in] logid The log identifier. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); - - /** \brief Set session log verbosity level - * - * Applies to session load, initialization, etc - * - * \param[in] options - * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); - - /** \brief Set session log severity level - * - * \param[in] options - * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); - - /** \brief Set the optimization level to apply when loading a graph - * - * Please see https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for an in-depth explanation - * \param[in,out] options The session options object - * \param[in] graph_optimization_level The optimization level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, - GraphOptimizationLevel graph_optimization_level); - - /** \brief Sets the number of threads used to parallelize the execution within nodes - * - * When running a single node operation, ex. add, this sets the maximum number of threads to use. - * - * \note If built with OpenMP, this has no effect on the number of threads used. In this case - * use the OpenMP env variables to configure the number of intra op num threads. - * - * \param[in] options - * \param[in] intra_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); - - /** \brief Sets the number of threads used to parallelize the execution of the graph - * - * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. - * - * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. - * - * \param[in] options - * \param[in] inter_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); - - /// @} - /// \name OrtCustomOpDomain - /// @{ - - /** \brief Create a custom op domain - * - * \param[in] domain - * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); - - /** \brief Add a custom op to a custom op domain - * - * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released - * - * \param[in] custom_op_domain - * \param[in] op - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Add custom op domain to a session options - * - * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released - * - * \param[in] options - * \param[in] custom_op_domain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); - - /** \deprecated Use OrtApi::RegisterCustomOpsLibrary_V2. - * - * Registers custom ops from a shared library. - * - * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in - * session options are destroyed, or if an error occurs and it is non null. 
- * - * \param[in] options - * \param[in] library_path - * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. to unload) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Get input count for a session - * - * This number must also match the number of inputs passed to OrtApi::Run - * - * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of inputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get output count for a session - * - * This number must also match the number of outputs returned by OrtApi::Run - * - * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of outputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get overridable initializer count - * - * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName - * - * \param[in] session - * \param[in] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get input type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - 
*/ - ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get output type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get overridable initializer type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get input name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get output name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get overridable initializer name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* session, size_t index, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /// @} - /// \name OrtRunOptions - /// @{ - - /** \brief Create an OrtRunOptions - * - * \param[out] out Returned newly created ::OrtRunOptions. Must be freed with OrtApi::ReleaseRunOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); - - /** \brief Set per-run log verbosity level - * - * \see OrtApi::RunOptionsGetRunLogVerbosityLevel - * - * \param[in] options - * \param[in] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int log_verbosity_level); - - /** \brief Set per-run log severity level - * - * \see OrtApi::RunOptionsGetRunLogSeverityLevel - * - * \param[in] options - * \param[in] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). - */ - ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int log_severity_level); - - /** \brief Set per-run tag - * - * This is used in a per-run log identifier. 
- * - * \see OrtApi::RunOptionsGetRunTag - * - * \param[in] options - * \param[in] run_tag The run tag. - */ - ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions* options, _In_ const char* run_tag); - - /** \brief Get per-run log verbosity level - * - * \see OrtApi::RunOptionsSetRunLogVerbosityLevel - * - * \param[in] options - * \param[out] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, - _Out_ int* log_verbosity_level); - - /** \brief Get per-run log severity level - * - * \see OrtApi::RunOptionsSetRunLogSeverityLevel - * - * \param[in] options - * \param[out] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). - */ - ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* log_severity_level); - - /** \brief Get per-run tag - * - * This is used in a per-run log identifier. - * - * \see OrtApi::RunOptionsSetRunTag - * - * \param[in] options - * \param[out] run_tag The run tag. - * Do not free this value, it is owned by `options`. It will be invalidated if the run tag - * changes (i.e., with OrtApi::RunOptionsSetRunTag) or `options` is freed. - */ - ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** run_tag); - - /** \brief Set terminate flag - * - * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. 
- * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); - - /** \brief Clears the terminate flag - * - * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Create a tensor - * - * Create a tensor using a supplied ::OrtAllocator - * - * \param[in] allocator - * \param[in] shape Pointer to the tensor shape dimensions. - * \param[in] shape_len The number of tensor shape dimensions. - * \param[in] type - * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** \brief Create a tensor backed by a user supplied buffer - * - * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. - * p_data is owned by caller. ReleaseValue won't release p_data. - * - * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). - * \param[in] p_data Pointer to the data buffer. - * \param[in] p_data_len The number of bytes in the data buffer. - * \param[in] shape Pointer to the tensor shape dimensions. - * \param[in] shape_len The number of tensor shape dimensions. - * \param[in] type The data type. - * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, - size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, - _Outptr_ OrtValue** out); - - /** \brief Return if an ::OrtValue is a tensor type - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); - - /** \brief Get a pointer to the raw data inside a tensor - * - * Used to read/write/modify the internal tensor data directly. - * \note The returned pointer is valid until the \p value is destroyed. - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Filled in with a pointer to the internal storage - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorMutableData, _In_ OrtValue* value, _Outptr_ void** out); - - /** \brief Set all strings at once in a string tensor - * - * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s An array of strings. Each string in this array must be null terminated. 
- * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); - - /** \brief Get total byte length for all strings in a string tensor - * - * Typically used with OrtApi::GetStringTensorContent - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[out] len Total byte length of all strings (does not include trailing nulls) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); - - /** \brief Get all strings from a string tensor - * - * An example of the results:
- * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
- * \p s must have a size of 11 bytes
- * \p offsets must have 4 elements
- * After the call, these values will be filled in:
- * \p s will contain "Thisisatest"
- * \p offsets will contain { 0, 4, 6, 7 }
- * The length of the last string is just s_len - offsets[last] - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. - * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) - * \param[out] offsets Array of start offsets into the strings written to \p s - * \param[in] offsets_len Number of elements in offsets - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, - size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); - - /// @} - /// \name OrtTypeInfo - /// @{ - - /** \brief Get ::OrtTensorTypeAndShapeInfo from an ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out Do not free this value, it will be valid until type_info is freed. - * If type_info does not represent tensor, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); - - /** \brief Get ::ONNXType from ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ enum ONNXType* out); - - /// @} - /// \name OrtTensorTypeAndShapeInfo - /// @{ - - /** \brief Create an ::OrtTensorTypeAndShapeInfo object - * - * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. 
Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Set element type in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] type - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* info, enum ONNXTensorElementDataType type); - - /** \brief Set shape information in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. - * \param[in] dim_count Number of elements in `dim_values` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); - - /** \brief Get element type in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::SetTensorElementType - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo* info, - _Out_ enum ONNXTensorElementDataType* out); - - /** \brief Get dimension count in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::GetDimensions - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); - - /** \brief Get dimensions in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo - * \param[in] dim_values_length Number of elements in `dim_values`. 
Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, - size_t dim_values_length); - - /** \brief Get symbolic dimension names in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names - * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, - _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); - - /** \brief Get total number of elements in a tensor shape from an ::OrtTensorTypeAndShapeInfo - * - * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). - * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. - * - * Examples:
- * [] = 1
- * [1,3,4] = 12
- * [2,0,4] = 0
- * [-1,3,4] = -1
- * - * \param[in] info - * \param[out] out Number of elements - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Get type and shape information from a tensor ::OrtValue - * - * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure - * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Get type information of an OrtValue - * - * \param[in] value - * \param[out] out Newly created ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); - - /** \brief Get ONNXType of an ::OrtValue - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); - - /// @} - /// \name OrtMemoryInfo - /// @{ - - /** \brief Create an ::OrtMemoryInfo - * - * \param[in] name - * \param[in] type - * \param[in] id - * \param[in] mem_type - * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id, - enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); - - /** \brief Create an ::OrtMemoryInfo for CPU memory - * - * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. 
- * - * \param[in] type - * \param[in] mem_type - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, - _Outptr_ OrtMemoryInfo** out); - - /** \brief Compare ::OrtMemoryInfo objects for equality - * - * Compares all settings of each ::OrtMemoryInfo for equality - * - * \param[in] info1 - * \param[in] info2 - * \param[out] out Set to 0 if equal, -1 if not equal - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); - - /** \brief Get name from ::OrtMemoryInfo - * - * \param[in] ptr - * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); - - /** \brief Get the id from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); - - /** \brief Get the ::OrtMemType from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); - - /** \brief Get the ::OrtAllocatorType from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); - - /// @} - /// \name OrtAllocator - /// @{ - - /// \brief Calls OrtAllocator::Alloc function - ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ort_allocator, size_t size, _Outptr_ void** out); - /// \brief Calls OrtAllocator::Free function - ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ort_allocator, void* p); - /// \brief Calls OrtAllocator::Info function - ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ort_allocator, _Outptr_ const struct OrtMemoryInfo** out); - 
- /** \brief Get the default allocator - * - * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. - * - * \param[out] out Returned value should NOT be freed - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Override session symbolic dimensions - * - * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable - * optimizations that can take advantage of fixed values (such as memory planning, etc) - * - * \param[in] options - * \param[in] dim_denotation - * \param[in] dim_value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, - _In_ int64_t dim_value); - - /// @} - /// \name OrtValue - /// @{ - - /* Internal information (not seen in Doxygen) - * - * APIs to support non-tensor types - map and sequence. - * Currently only the following types are supported - * Note: the following types should be kept in sync with data_types.h - * Map types - * ========= - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * - * Sequence types - * ============== - * std::vector - * std::vector - * std::vector - * std::vector - * std::vector> - * std::vector - */ - - /** \brief Get non tensor data from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values - * separately. Use index=0 to retrieve keys and index=1 to retrieve values. - * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element - * of the sequence. 
- * - * \param[in] value - * \param[in] index See above for usage based on `value` type - * \param[in] allocator Allocator used to allocate ::OrtValue - * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, - _Outptr_ OrtValue** out); - - /** \brief Get non tensor value count from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. For ONNX_TYPE_SEQUENCE - * the number of elements in the sequence will be returned - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); - - /** \brief Create a map or sequence ::OrtValue - * - * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s - * representing keys and values.
- * - * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the - * sequence. 'in' should be an array of N ::OrtValue%s. - * - * \param[in] in See above for details - * \param[in] num_values - * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE - * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, - enum ONNXType value_type, _Outptr_ OrtValue** out); - - /** \brief Create an opaque (custom user defined type) ::OrtValue - * - * Constructs an ::OrtValue that contains a value of non-standard type created for - * experiments or while awaiting standardization. ::OrtValue in this case would contain - * an internal representation of the Opaque type. Opaque types are distinguished from - * each other by two strings 1) domain and 2) type name. The combination of the two - * must be unique, so the type representation is properly identified internally. The combination - * must be properly registered from within ORT at both compile/run time or by another API. - * - * To construct the ::OrtValue pass domain and type names, also a pointer to a data container - * the type of which must be known to both ORT and the client program. That data container may or may - * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for - * verification purposes. - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] data_container User pointer Data to populate ::OrtValue - * \param[in] data_container_size Size in bytes of what `data_container` points to - * \param[out] out Newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, - _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); - - /** \brief Get internal data from an opaque (custom user defined type) ::OrtValue - * - * Copies internal data from an opaque value into a user provided buffer - * - * \see OrtApi::CreateOpaqueValue - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] in The opaque ::OrtValue - * \param[out] data_container Buffer to copy data into - * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, - _Out_ void* data_container, size_t data_container_size); - - /// @} - /// \name OrtKernelInfo - /// Custom operator APIs. 
- /// @{ - - /** \brief Get a float stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ float* out); - - /** \brief Fetch a 64-bit int stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ int64_t* out); - - /** \brief Fetch a string stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the string - * attribute, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual string attribute's size, - * the value of `size` is set to the true size of the string attribute, the provided memory - * is filled with the attribute's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string attribute's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string attribute - * and a failure status is returned.) 
- * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * \param[in,out] size See above comments for details - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, - _Inout_ size_t* size); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. - /// @{ - - /** \brief Used for custom operators, get the input count of a kernel - * - * \see ::OrtCustomOp - */ - ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); - - /** \brief Used for custom operators, get the output count of a kernel - * - * \see ::OrtCustomOp - */ - ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); - - /** \brief Used for custom operators, get an input of a kernel - * - * The function attempts fetches the input of the kernel. If the input is optional - * and not present, the function returns success and out is set to nullptr. - * - * \param[in] context ::OrtKernelContext instance - * \param[in] index See KernelContext_GetInputCount for boundaries check. - * \param[out] out OrtValue if the input is present otherwise is set nullptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, - _Out_ const OrtValue** out); - - /** \brief Used for custom operators, get an output of a kernel - * - * The function attempts fetches the output of the kernel. If the output is optional - * and not present, the function returns success and out is set to nullptr. - * - * \param[in] context ::OrtKernelContext instance - * \param[in] index See KernelContext_GetOutputCount for boundaries check. 
- * \param[in] dim_values output dimensions - * \param[in] dim_count number of dimensions - * \param[out] out a ptr to OrtValue to output otherwise set to nullptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, - _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); - - /// @} - /// \name OrtEnv - /// @{ - ORT_CLASS_RELEASE(Env); - /// @} - /// \name OrtStatus - /// @{ - ORT_CLASS_RELEASE(Status); - /// @} - /// \name OrtMemoryInfo - /// @{ - ORT_CLASS_RELEASE(MemoryInfo); - /// @} - /// \name OrtSession - /// @{ - ORT_CLASS_RELEASE(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) - /// @} - /// \name OrtValue - /// @{ - ORT_CLASS_RELEASE(Value); - /// @} - /// \name OrtRunOptions - /// @{ - ORT_CLASS_RELEASE(RunOptions); - /// @} - /// \name OrtTypeInfo - /// @{ - ORT_CLASS_RELEASE(TypeInfo); - /// @} - /// \name OrtTensorTypeAndShapeInfo - /// @{ - ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); - /// @} - /// \name OrtSessionOptions - /// @{ - ORT_CLASS_RELEASE(SessionOptions); - /// @} - /// \name OrtCustomOpDomain - /// @{ - ORT_CLASS_RELEASE(CustomOpDomain); - - /// @} - /// \name OrtTypeInfo - /// @{ - - /** \brief Get denotation from type information - * - * Augments ::OrtTypeInfo to return denotations on the type. - * - * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. - * - * \param[in] type_info - * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. 
- * \param[out] len Length in bytes of the string returned in `denotation` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const denotation, - _Out_ size_t* len); - - /** \brief Get detailed map information from an ::OrtTypeInfo - * - * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map. - * The OrtMapTypeInfo has additional information about the map's key type and value type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[out] type_info - * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info - * does not contain a map, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtMapTypeInfo** out); - - /** \brief Cast ::OrtTypeInfo to an ::OrtSequenceTypeInfo - * - * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence. - * The ::OrtSequenceTypeInfo has additional information about the sequence's element type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] type_info - * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info - * doesn not contain a sequence, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); - - /// @} - /// \name OrtMapTypeInfo - /// @{ - - /** \brief Get key type from an ::OrtMapTypeInfo - * - * Key types are restricted to being scalar types. - * - * This is used by WinML to support model reflection APIs. 
- * - * \param[in] map_type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); - - /** \brief Get the value type from an ::OrtMapTypeInfo - * - * \param[in] map_type_info - * \param[out] type_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); - - /// @} - /// \name OrtSequenceTypeInfo - /// @{ - - /** \brief Get element type from an ::OrtSequenceTypeInfo - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] sequence_type_info - * \param[out] type_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, - _Outptr_ OrtTypeInfo** type_info); - - /// @} - /// \name OrtMapTypeInfo - /// @{ - ORT_CLASS_RELEASE(MapTypeInfo); - /// @} - /// \name OrtSequenceTypeInfo - /// @{ - ORT_CLASS_RELEASE(SequenceTypeInfo); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief End profiling and return filename of the profile data - * - * Profiling is turned on through OrtApi::EnableProfiling - * - * \param[in] session - * \param[in] allocator - * \param[out] out Null terminated string of the filename, allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* session, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); - - /** \brief Get ::OrtModelMetadata from an ::OrtSession - * - * \param[in] session - * \param[out] out Newly created ::OrtModelMetadata. 
Must be freed using OrtApi::ReleaseModelMetadata - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* session, _Outptr_ OrtModelMetadata** out); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** \brief Get `producer name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get `graph name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get `domain` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, - _Outptr_ char** value); - - /** \brief Get `description` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. 
Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Return data for a key in the custom metadata map in an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[in] key Null terminated string - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * `value` will be set to nullptr if the given key is not found in the custom metadata map. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); - - /** \brief Get version number from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[out] value Set to the version number - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); - - ORT_CLASS_RELEASE(ModelMetadata); - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an OrtEnv - * - * Create an environment with global threadpools that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. - * - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel log_severity_level, _In_ const char* logid, - _In_ const OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Use global thread pool on a session - * - * Disable using per session thread pool and use the shared global threadpool. - * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Create an ::OrtThreadingOptions - * - * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); - - ORT_CLASS_RELEASE(ThreadingOptions); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. - * The strings and the pointer array must be freed using `allocator` - * `keys` will be set to nullptr if the custom metadata map is empty. 
- * \param[out] num_keys Set to the number of elements in the `keys` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** - * - * Override symbolic dimensions (by specific name strings) with actual values - * if known at session initialization time to enable optimizations that can - * take advantage of fixed values (such as memory planning, etc) - * - */ - ORT_API2_STATUS(AddFreeDimensionOverrideByName, - _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, - _In_ int64_t dim_value); - - /// @} - /// \name Misc - /// @{ - - /** \brief Get the names of all available providers - * - * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. - * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. - * - * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. The entries and the - * array itself must be freed using OrtApi::ReleaseAvailableProviders - * \param[out] provider_length Set to the number of entries in the `out_ptr` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length); - - /** \brief Release data from OrtApi::GetAvailableProviders. This API will never fail - * so you can rely on it in a noexcept code. - * - * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. 
- * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, - _In_ int providers_length); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Get the length of a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] index Index of the string in the tensor - * \param[out] out Set to number of bytes of the string element - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); - - /** \brief Get a single string from a string tensor - * - * \param[in] value A string tensor - * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. - * \param[in] index Index of the string in the tensor - * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); - - /** \brief Set a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] s A null terminated UTF-8 encoded string - * \param[in] index Index of the string in the tensor to set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Set a session configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value. 
- * - * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, - _In_z_ const char* config_key, _In_z_ const char* config_value); - - /// @} - /// \name OrtAllocator - /// @{ - - /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo - * - * \param[in] session - * \param[in] mem_info valid ::OrtMemoryInfo instance - * \param[out] out Newly created ::OrtAllocator. Must be freed with OrtApi::ReleaseAllocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* session, _In_ const OrtMemoryInfo* mem_info, - _Outptr_ OrtAllocator** out); - - /** \brief Release an ::OrtAllocator obtained from OrtApi::CreateAllocator - */ - ORT_CLASS_RELEASE(Allocator); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Run a model using Io Bindings for the inputs & outputs - * - * \see OrtApi::Run - * - * \param[in] session - * \param[in] run_options - * \param[in] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* session, _In_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); - - /** \brief Create an ::OrtIoBinding instance - * - * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. - * Thus if you want to use a raw on device buffer as input or output you can avoid - * extra copy during runtime. - * - * \param[in] session - * \param[out] out Newly created ::OrtIoBinding. 
Must be freed with OrtApi::ReleaseIoBinding - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* session, _Outptr_ OrtIoBinding** out); - - /// @} - /// \name OrtIoBinding - /// @{ - - /** \brief Release an ::OrtIoBinding obtained from OrtApi::CreateIoBinding - */ - ORT_CLASS_RELEASE(IoBinding); - - /** \brief Bind an ::OrtValue to an ::OrtIoBinding input - * - * When using OrtApi::RunWithBinding this value is used for the named input - * - * \param[in] binding_ptr - * \param[in] name Name for the model input - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); - - /** \brief Bind an ::OrtValue to an ::OrtIoBinding output - * - * When using OrtApi::RunWithBinding this value is used for the named output - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the model output name - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); - - /** \brief Bind an ::OrtIoBinding output to a device - * - * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. - * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using - * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of - * memory within ::OrtValue ahead of time. 
- * - * \see OrtApi::RunWithBinding - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the device name - * \param[in] mem_info_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr); - - /** \brief Get the names of an ::OrtIoBinding's outputs - * - * Returns the names of the outputs in the order they were bound. This is useful after running the model - * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if - * the order of outputs and their names is not known. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. - * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. - * This buffer is allocated using `allocator` and must be freed using it. - * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` - * This buffer is allocated using `allocator` and must be freed using it. - * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, - * no memory allocation is performed and buffer and lengths are set to nullptr. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, - _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); - - /** \brief Get the output ::OrtValue objects from an ::OrtIoBinding - * - * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding - * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput - * or OrtApi::BindOutputToDevice. - * - * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. - * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after - * all the ::OrtValue%s contained therein are individually released. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate output array - * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. 
- * \param[out] output_count Set to number of ::OrtValue%s returned - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, - _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); - - /** \brief Clears any previously set Inputs for an ::OrtIoBinding - */ - void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Clears any previously set Outputs for an ::OrtIoBinding - */ - void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Direct memory access to a specified tensor element - * - * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved - * - * This function only works for numeric type tensors (No strings, etc). - * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. - * - * \param[in] value - * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape - * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. - * \param[out] out Set to a pointer to the element specified - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an allocator and register it with the ::OrtEnv - * - * Enables sharing the allocator between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. 
- * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * See https://onnxruntime.ai/docs/get-started/with-c.html for details. - * - * \param[in] env ::OrtEnv instance - * \param[in] mem_info - * \param[in] arena_cfg Pass nullptr for defaults - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, - _In_ const OrtArenaCfg* arena_cfg); - - /** \brief Set language projection - * - * Set the language projection for collecting telemetry data when Env is created. - * - * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. - * - * \param[in] ort_env - * \param[in] projection - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Return the time that profiling was started - * - * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns - * - * \param[in] session - * \param[out] out nanoseconds of profiling's start time - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* session, _Outptr_ uint64_t* out); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Set global intra-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] intra_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); - - /** \brief Set global inter-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] inter_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); - - /** \brief Set global spin control options - * - * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Allow spinning of thread pools when their queues are empty. This will set the value for both - * inter_op and intra_op threadpools. - * - * \param[in] tp_options - * \param[in] allow_spinning Valid values are 0 or 1.
- * 0 = It won't spin (recommended if CPU usage is high)
- * 1 = Threadpool will spin to wait for queue to become non-empty - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Add a pre-allocated initializer to a session - * - * If a model contains an initializer with a name that is same as the name passed to this call, - * ORT will use this initializer instance instead of deserializing one from the model file. This - * is useful when you want to share the same initializer across sessions. - * - * \param[in] options - * \param[in] name Null terminated string of the initializer name - * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be - * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object - * to which it is added. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, - _In_ const OrtValue* val); - - /// @} - /// \name OrtEnv - /// @{ - - /** - * Create a custom environment with global threadpools and logger that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. - * - * \param[in] logging_function A pointer to a logging function. - * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel log_severity_level, - _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append CUDA provider to session options - * - * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] cuda_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, - _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); - - /** \brief Append ROCM execution provider to the session options - * - * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] rocm_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, - _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); - - /** \brief Append OpenVINO execution provider to the session options - * - * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. 
- * - * \param[in] options - * \param[in] provider_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO, - _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Set threading flush-to-zero and denormal-as-zero - * - * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. - * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. - * - * \param[in] tp_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options); - - /// @} - /// \name OrtArenaCfg - /// @{ - - /** \deprecated Use OrtApi::CreateArenaCfgV2 - * - * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior - * - * \param[in] max_mem Use 0 to allow ORT to choose the default - * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default - * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default - * \param[in] out A pointer to an OrtArenaCfg instance - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, - int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); - - ORT_CLASS_RELEASE(ArenaCfg); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** - * Use this to obtain the description of the graph present in the model - * (doc_string field of the 
GraphProto message within the ModelProto message). - * If it doesn't exist, an empty string will be returned. - * - * \param[in] model_metadata An instance of ::OrtModelMetadata - * \param[in] allocator Allocator used to allocate the string that will be returned back - * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append TensorRT provider to session options - * - * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, - _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); - - /// @} - /// \name Misc - /// @{ - - /** \brief Set current GPU device ID - * - * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less - * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is - * required to restrict execution to a single GPU. - * - * \param[in] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetCurrentGpuDeviceId, _In_ int device_id); - - /** \brief Get current GPU device ID - * - * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). 
- * - * \see OrtApi::SetCurrentGpuDeviceId - * - * \param[out] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id); - - /// @} - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Fetch an array of float values stored as an attribute in the graph node - * - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned. - * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned. - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ float* out, _Inout_ size_t* size); - - /** \brief Fetch an array of int64_t values stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned.
- * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned.) - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ int64_t* out, _Inout_ size_t* size); - - /// @} - /// \name OrtArenaCfg - /// @{ - - /** \brief Create an ::OrtArenaCfg - * - * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. - * - * Supported keys are (See https://onnxruntime.ai/docs/get-started/with-c.html for details on what the - * following parameters mean and how to choose these values.): - * "max_mem": Maximum memory that can be allocated by the arena based allocator. - * Use 0 for ORT to pick the best value. Default is 0. - * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. - * Use -1 to allow ORT to choose the default. - * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. - * Ultimately, the first allocation size is determined by the allocation memory request. - * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after - * crossing which the current chunk is chunked into 2. - * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. 
- * "max_power_of_two_extend_bytes": The maximum extend size if arena strategy is `kNextPowerOfTwo`. - * It is not an allocation limit, it is only a limit for extension when requested byte is less than the limit. - * When requested bytes is more than the limit, allocator will still return as requested. - * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes. - * Ultimately, the allocation size is determined by the allocation memory request. - * Further allocation sizes are governed by the arena extend strategy. - * - * \param[in] arena_config_keys Keys to configure the arena - * \param[in] arena_config_values Values to configure the arena - * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values` - * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys, - _In_reads_(num_keys) const size_t* arena_config_values, _In_ size_t num_keys, - _Outptr_ OrtArenaCfg** out); - - /// @} - /// \name OrtRunOptions - /// @{ - - /** \brief Set a single run configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value - * - * The config_key and the format of config_value are defined in onnxruntime_run_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddRunConfigEntry, _Inout_ OrtRunOptions* options, - _In_z_ const char* config_key, _In_z_ const char* config_value); - - /// @} - /// \name OrtPrepackedWeightsContainer - /// @{ - - /** \brief Create an ::OrtPrepackedWeightsContainer - * - * This
container will hold pre-packed buffers of shared initializers for sharing between sessions - * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers - * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container - * to sessions that you would like to share pre-packed buffers of shared initializers at session - * creation time. - * - * \param[out] out Newly created ::OrtPrepackedWeightsContainer. Must be freed with OrtApi::ReleasePrepackedWeightsContainer - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out); - - /** \brief Release OrtPrepackedWeightsContainer instance - * - * \note instance must not be released until the sessions using it are released - */ - ORT_CLASS_RELEASE(PrepackedWeightsContainer); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Create session with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSession except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env OrtEnv instance - * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession.
Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, - _Outptr_ OrtSession** out); - - /** \brief Create session from memory with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env - * \param[in] model_data Array of bytes holding the model - * \param[in] model_data_length Number of bytes in `model_data` - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, - _Outptr_ OrtSession** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append TensorRT execution provider to the session options - * - * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure.
- * - * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, which takes an - * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2 - * which must be created with OrtApi::CreateTensorRTProviderOptions. - * - * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions - * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions. - * Here, OrtApi::CreateTensorRTProviderOptions and OrtApi::ReleaseTensorRTProviderOptions will do the memory management for you. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT_V2, - _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options); - - /// @} - /// \name OrtTensorRTProviderOptionsV2 - /// @{ - - /** \brief Create an OrtTensorRTProviderOptionsV2 - * - * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out); - - /** \brief Set options in a TensorRT Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc - * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. Recreates the options and only sets the supplied values.
- * - * For example, key="trt_max_workspace_size" and value="2147483648" - * - * \param[in] tensorrt_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated strings for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated strings for provider options values - * \param[in] num_keys Number of elements in the `provider_options_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UpdateTensorRTProviderOptions, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get serialized TensorRT provider options string. - * - * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" - * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it.
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtTensorRTProviderOptionsV2 - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - */ - void(ORT_API_CALL* ReleaseTensorRTProviderOptions)(_Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* input); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Enable custom operators - * - * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); - - /// @} - /// \name OrtAllocator - /// @{ - - /** \brief Register a custom allocator - * - * Enables sharing between multiple sessions that use the same env instance. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except - * instead of ORT creating an allocator based on provided info, in this case - * ORT uses the user-provided custom allocator. - * See https://onnxruntime.ai/docs/get-started/with-c.html for details. - * - * \param[in] env - * \param[in] allocator User provided allocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator); - - /** \brief Unregister a custom allocator - * - * It is an error if you provide an ::OrtMemoryInfo not corresponding to any - * registered allocators for sharing. 
- * - * \param[in] env - * \param[in] mem_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env, - _In_ const OrtMemoryInfo* mem_info); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Sets *out to 1 iff an ::OrtValue is a SparseTensor, and 0 otherwise - * - * \param[in] value existing ::OrtValue - * \param[out] out unless an error occurs, contains 1 iff the value contains an instance - * of sparse tensor or 0 otherwise. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out); - - /** \brief Create an ::OrtValue with a sparse tensor that is empty. - * - * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and - * format specific indices data. - * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value - * if any was allocated. - * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed - * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan of - * this sparse tensor instance as the same allocator will be used to free memory. - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, - size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** - * This populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue.
- * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values pointer to an array of values. For strings, pass const char**. - * \param[in] indices_data pointer to a location of COO indices - * \param[in] indices_num number of COO indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* indices_data, size_t indices_num); - - /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. 
- * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values - pointer to an array of values. For strings, pass const char**. - * \param[in] inner_indices_data pointer to a location of CSR inner indices - * \param[in] inner_indices_num number of CSR inner indices - * \param[in] outer_indices_data pointer to a location of CSR outer indices - * \param[in] outer_indices_num number of CSR outer indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* inner_indices_data, size_t inner_indices_num, - _In_ const int64_t* outer_indices_data, size_t outer_indices_num); - - /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. 
- * \param[in] values_shape - * \param[in] values_shape_len - * \param[in] values structure with values information - * \param[in] indices_shape_data pointer to a location of indices shape - * \param[in] indices_shape_len length of the block sparse indices shape - * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* indices_shape_data, size_t indices_shape_len, - _In_ const int32_t* indices_data); - - /** - * Create an ::OrtValue with a sparse tensor. This is the first step. - * Next, use UseIndices() functions to supply sparse tensor with - * format specific indices data and set its sparse format to a specific enum value. - * This will not perform memory allocations. It will - * use supplied user buffer which should outlive the created sparse tensor. - * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. - * This function can not be used to map strings from the user allocated memory. Strings must always be copied - * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data - * using appropriate Make*() function. - * - * \param[in] info memory info where sparse values reside. - * \param[in,out] p_data pointer to a user allocated buffer with values. To create a full sparse tensor with no non-zero - * values, pass nullptr - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, - * pass {0} shape. 
- * \param[in] values_shape_len number of values shape dimensions - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, - _In_ const int64_t* dense_shape, size_t dense_shape_len, - _In_ const int64_t* values_shape, size_t values_shape_len, - ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** - * This assigns Coo format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_COO. This will not allocate any additional memory for data. The life span of - * indices_data buffer should eclipse the life span of this ::OrtValue. - * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or - * be twice as number of nnz values for a 2-D indices {nnz, 2} - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num); - - /** - * The assigns CSR format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of - * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue. 
- * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue. - * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or - * equal to rows + 1 of the dense shape. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num, - _Inout_ int64_t* outer_data, size_t outer_num); - - /** - * The assigns BlockSparse format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of - * indices_data buffer must eclipse the lifespan of this ::OrtValue. - * - * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors - * \param[in] indices_shape_len length of the indices shape - * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data); - - /** \brief Returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor. 
- * - * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor - * \param[out] out pointer to out parameter - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out); - - /** \brief Returns data type and shape of sparse tensor values (nnz) iff ::OrtValue contains a SparseTensor. - * - * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor - * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*(). - * - * \param[in] ort_value an instance of ::OrtValue containing sparse tensor - * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out); - - /** \brief Returns data type, shape for the type of indices specified by indices_format. - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse - * tensor does not contain. - * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. 
Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Returns indices data for the type of the indices specified by indices_format - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. - * \param[out] num_indices Pointer to where the number of indices entries is returned - * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); - /// @} - /// \name OrtSessionOptions - /// @{ - - /** - * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) - * Use this API to find if the optional type OrtValue is None or not. - * If the optional type OrtValue is not None, use the OrtValue just like any other OrtValue. - * For example, if you get an OrtValue that corresponds to Optional(tensor) and - * if HasValue() returns true, use it as tensor and so on. - - * \param[in] value Input OrtValue. - * \param[out] out indicating if the input OrtValue contains data (1) or if it is a None (0) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. 
- /// @{ - - /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel - * \see ::OrtCustomOp - * \param[in] context OrtKernelContext instance - * \param[out] out Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel. - * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to - * some other EP), then a nullptr is returned as the output param. - * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session. - * Only use it for custom kernel launching. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); - - /// @} - /// \name GetTensorMemoryInfo - /// @{ - /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor - * \param[in] value ::OrtValue containing tensor. - * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info); - - /// @} - /// \name GetExecutionProviderApi - /// @{ - /** \brief Get a pointer to the requested version of the Execution Provider specific - * API extensions to the OrtApi - * \param[in] provider_name The name of the execution provider name. Currently only the following - * values are supported: "DML". - * \param[in] version Must be ::ORT_API_VERSION. - * \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure. - * For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML". 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); - - /// @} - - /// \name SessionOptions - /// @{ - /** \brief Set custom thread creation function - * - * \param[in] options Session options - * \param[in] ort_custom_create_thread_fn Custom thread creation function - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); - - /** \brief Set creation options for custom thread - * - * \param[in] options Session options - * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); - - /** \brief Set custom thread join function - * - * \param[in] options Session options - * \param[in] ort_custom_join_thread_fn Custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); - /// @} - - /// \name OrtThreadingOptions - /// @{ - /** \brief Set custom thread creation function for global thread pools - * - * \param[inout] tp_options - * \param[in] ort_custom_create_thread_fn Custom thread creation function - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); - - /** \brief Set custom thread creation options for global thread pools - * - * 
\param[inout] tp_options - * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); - - /** \brief Set custom thread join function for global thread pools - * - * \param[inout] tp_options - * \param[in] ort_custom_join_thread_fn Custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); - /// @} - - /** \brief Synchronize bound inputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr); - - /** \brief Synchronize bound outputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr); - - /// \name OrtSessionOptions - /// @{ - - /** \brief Append CUDA execution provider to the session options - * - * If CUDA is not available (due to a non CUDA enabled build), this function will return failure. 
- * - * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_CUDA, it takes an - * ::OrtCUDAProviderOptions which is publicly defined. This takes an opaque ::OrtCUDAProviderOptionsV2 - * which must be created with OrtApi::CreateCUDAProviderOptions. - * - * For OrtApi::SessionOptionsAppendExecutionProvider_CUDA, the user needs to instantiate ::OrtCUDAProviderOptions - * as well as allocate/release buffers for some members of ::OrtCUDAProviderOptions. - * Here, OrtApi::CreateCUDAProviderOptions and Ortapi::ReleaseCUDAProviderOptions will do the memory management for you. - * - * \param[in] options - * \param[in] cuda_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA_V2, - _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptionsV2* cuda_options); - - /// @} - /// \name OrtCUDAProviderOptionsV2 - /// @{ - - /** \brief Create an OrtCUDAProviderOptionsV2 - * - * \param[out] out Newly created ::OrtCUDAProviderOptionsV2. Must be released with OrtApi::ReleaseCudaProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out); - - /** \brief Set options in a CUDA Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options - * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. Recreates the options and only sets the supplied values. 
- * - * For example, key="device_id" and value="0" - * - * \param[in] cuda_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(UpdateCUDAProviderOptions, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized CUDA provider options string. - * - * For example, "device_id=0;arena_extend_strategy=0;......" - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtCUDAProviderOptionsV2 - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - * - * \since Version 1.11. 
- */ - void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); - - /// @} - - /** \brief Append MIGraphX provider to session options - * - * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] migraphx_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, - _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); - - /** \brief Replace initialized Tensors with external data with the data provided in initializers. - * - * The function will find the initialized TensorProtos with external data in the graph with the provided names and - * replace them with the provided tensors. The API verifies that the TensorProto being replaced - * has an external data reference and has the same name, dimensions and data type as its replacement. The replacement - * will occur before any of the optimizations take place. The data will be copied into the graph - * since TensorProto can't refer to the user provided buffers. - * - * Once the model has been loaded, the OrtValue(s) added to SessionOptions instance will be removed - * from the internal SessionOptions copy to save memory, the user provided buffers can then be deallocated - * and the SessionOptions instance that refers to them can be destroyed. - * - * \param[in] options - * \param[in] initializer_names Array of null terminated UTF-8 encoded strings of the initializers names. - * \param[in] initializers Array of ::OrtValue type - * \param[in] num_initializers Number of elements in the initializer_names and initializers - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(AddExternalInitializers, _In_ OrtSessionOptions* options, - _In_reads_(num_initializers) const char* const* initializer_names, - _In_reads_(num_initializers) const OrtValue* const* initializers, size_t num_initializers); - - /** \brief: Create attribute of onnxruntime operator - * - * \param[in] name Name of the attribute - * \param[in] data Data content of the attribute - * \param[in] len Number of bytes stored in data - * \param[in] type Data type - * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr - * - * \since Version 1.12. - */ - ORT_API2_STATUS(CreateOpAttr, - _In_ const char* name, - _In_ const void* data, - _In_ int len, - _In_ OrtOpAttrType type, - _Outptr_ OrtOpAttr** op_attr); - - /* \brief: Release op attribute - * - * \param[in] opAttr Attribute created by OrtApi::CreateOpAttr - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(OpAttr); - - /** \brief: Create onnxruntime native operator - * - * \param[in] info Kernel info - * \param[in] op_name Operator name - * \param[in] domain Operator domain - * \param[in] version Operator opset version - * \param[in] type_constraint_names Name of the type contraints, such as "T" or "T1" - * \param[in] type_constraint_values Type of each contraints - * \param[in] type_constraint_count Number of contraints - * \param[in] attr_values Attributes used to initialize the operator - * \param[in] attr_count Number of the attributes - * \param[in] input_count Number of inputs - * \param[in] output_count Number of outputs - * \param[out] ort_op Operator that has been created - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(CreateOp, - _In_ const OrtKernelInfo* info, - _In_z_ const char* op_name, - _In_z_ const char* domain, - int version, - _In_reads_(type_constraint_count) const char** type_constraint_names, - _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values, - int type_constraint_count, - _In_reads_(attr_count) const OrtOpAttr* const* attr_values, - int attr_count, - int input_count, - int output_count, - _Outptr_ OrtOp** ort_op); - - /** \brief: Invoke the operator created by OrtApi::CreateOp - * The inputs must follow the order as specified in onnx specification - * - * \param[in] context Kernel context - * \param[in] ort_op Operator that has been created - * \param[in] input_values Array of inputs - * \param[in] input_count Number of inputs - * \param[in] output_values Array of outputs - * \param[in] output_count Number of outputs - * - * \since Version 1.12. - */ - ORT_API2_STATUS(InvokeOp, - _In_ const OrtKernelContext* context, - _In_ const OrtOp* ort_op, - _In_ const OrtValue* const* input_values, - _In_ int input_count, - _Inout_ OrtValue* const* output_values, - _In_ int output_count); - - /* \brief: Release an onnxruntime operator - * - * \param[in] Op Operator created by OrtApi::CreateOp - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(Op); - - /** \brief: Append execution provider to the session options. - * \param[in] options - * \param[in] provider_name - provider to add. - * \param[in] provider_options_keys - keys to configure the provider options - * \param[in] provider_options_values - values to configure the provider options - * \param[in] num_keys - number of keys passed in - * - * Currently supported providers: - * QNN - * SNPE - * XNNPACK - * - * Note: If an execution provider has a dedicated SessionOptionsAppendExecutionProvider_ function - * that should be used to add it. - * - * QNN supported keys: - * "backend_path": file path to QNN backend library. 
- * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. - * "profiling_file_path": QNN profiling file path if ETW not enabled. - * "rpc_control_latency": QNN RPC control latency. - * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). - * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", - * "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". - * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will - * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and - * may alter model/EP partitioning. Use only for debugging. - * "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal". - * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options: - * - "0": Default. - * - "1": Faster preparation time, less optimal graph. - * - "2": Longer preparation time, more optimal graph. - * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. - * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). - * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: - * - "0": Default (none). - * - "68" - * - "69" - * - "73" - * - "75" - * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Only used for float32 model. - Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": Default. With fp32 precision. - - "1": With fp16 precision. 
- * - * SNPE supported keys: - * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", - * "DSP", "DSP_FIXED8_TF", "AIP_FIXED_TF", "AIP_FIXED8_TF". - * Mapping to SNPE Runtime_t definition: CPU, CPU_FLOAT32 => zdl::DlSystem::Runtime_t::CPU; - * GPU, GPU_FLOAT32_16_HYBRID => zdl::DlSystem::Runtime_t::GPU; - * GPU_FLOAT16 => zdl::DlSystem::Runtime_t::GPU_FLOAT16; - * DSP, DSP_FIXED8_TF => zdl::DlSystem::Runtime_t::DSP. - * AIP_FIXED_TF, AIP_FIXED8_TF => zdl::DlSystem::Runtime_t::AIP_FIXED_TF. - * "priority": execution priority, options: "low", "normal". - * "buffer_type": ITensor or user buffers, options: "ITENSOR", user buffer with different types - "TF8", "TF16", "UINT8", "FLOAT". - * "ITENSOR" -- default, ITensor which is float only. - * "TF8" -- quantized model required, "FLOAT" -- for both quantized or non-quantized model - * "enable_init_cache": enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. - * If SNPE is not available (due to a non Snpe enabled build or its dependencies not being installed), this function will fail. - * - * XNNPACK supported keys: - * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. - * default value is 0, which means to use the session thread-pool size. - * - * \since Version 1.12. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, - _In_ const char* provider_name, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /* \brief: Get a copy of kernel info - * - * \param[in] info Kernel info - * \param[out] info_copy Copy of kernel info - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(CopyKernelInfo, - _In_ const OrtKernelInfo* info, - _Outptr_ OrtKernelInfo** info_copy); - - /* \brief: Release kernel info - * - * \param[in] KernelInfo A copy of kernel info returned by CopyKernelInfo - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(KernelInfo); - - /// \name Ort Training - /// @{ - /** \brief Gets the Training C Api struct - * - * Call this function to access the ::OrtTrainingApi structure that holds pointers to functions that enable - * training with onnxruntime. - * \note A NULL pointer will be returned and no error message will be printed if the training api - * is not supported with this build. A NULL pointer will be returned and an error message will be - * printed if the provided version is unsupported, for example when using a runtime older than the - * version created with this header file. - * - * \param[in] version Must be ::ORT_API_VERSION - * \return The ::OrtTrainingApi struct for the version requested. - * - * \since Version 1.13 - */ - const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; - - /// @} - - /** \brief Append CANN provider to session options - * - * If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] cann_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN, - _In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options); - - /** \brief Create an OrtCANNProviderOptions - * - * \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out); - - /** \brief Set options in a CANN Execution Provider. 
- * - * \param[in] cann_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get serialized CANN provider options string. - * - * \param[in] cann_options OrtCANNProviderOptions instance - * \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator() - * or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate - * continuous buffers for output strings and lengths. - * \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'. - * The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options, - _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an OrtCANNProviderOptions - * - * \param[in] the pointer of OrtCANNProviderOptions which will been deleted - * - * \since Version 1.13. 
- */ - void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input); - - /* \brief Get OrtDevice type from MemoryInfo - * - * \since Version 1.14 - */ - void(ORT_API_CALL* MemoryInfoGetDeviceType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemoryInfoDeviceType* out); - - /* \brief Update the OrtEnv instance with custom log severity level - * - * \param[in] ort_env The OrtEnv instance being used - * \param[in] log_severity_level The log severity level. - * - * \since Version 1.14. - */ - ORT_API2_STATUS(UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, OrtLoggingLevel log_severity_level); - - /* \brief Set affinities for intra op threads - * - * Affinity string follows format: - * logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id - * Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. - * e.g. 1,2,3;4,5 - * specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. - * To ease the configuration, an "interval" is also allowed: - * e.g. 1-8;8-16;17-24 - * orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. - * Note: - * 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, - * ort does not set affinity on the main thread which is started and managed by the calling app; - * 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, - * an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. - * Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. 
- * - * \since Version 1.14 - */ - ORT_API2_STATUS(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* tp_options, const char* affinity_string); - - /** \brief Register custom ops from a shared library. - * - * Loads a shared library (.dll on windows, .so on linux, etc) named 'library_name' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * - * The handle to the loaded library is automatically released by ORT when the last OrtSession that references the - * library handle is released. If no OrtSession is created, then the library handle is released when the provided - * OrtSessionOptions is released. - * - * \param[in] options The session options. - * \param[in] library_name The name of the shared library to load and register. Refer to OS-specific dynamic library - * loading utilities (e.g., LoadLibraryEx on Windows or dlopen on Linux/MacOS) for information - * on the format of library names and search paths. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); - - /** \brief Register custom ops by calling a RegisterCustomOpsFn function. - * - * Searches for registration_func_name and if found calls it. - * - * The library containing the function must either be linked against or previously loaded by the executable. - * - * If you want ONNX Runtime to load the library and manage its lifetime, use RegisterCustomOpsLibrary_V2. - * - * RegisterCustomOpsUsingFunction can be used in scenarios where it may not be possible for ONNX Runtime to load - * the library from a path. e.g. mobile platforms where the library must be linked into the app. 
- * - * The registration function must have the signature of RegisterCustomOpsFn: - * OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); - * - * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for details on how the registration - * function should be implemented. - * - * \param[in] options OrtSessionOptions that is passed through as the first argument in the call to the - * registration function. - * \param[in] registration_func_name Name of registration function to use. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, - _In_ const char* registration_func_name); - - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Get the number of inputs from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the number of inputs - * during kernel/session creation. - * - * \param[in] info Instance of ::OrtKernelInfo. - * \param[out] out Pointer to variable assigned with the result on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); - - /** \brief Get the number of outputs from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the number of outputs - * during kernel/session creation. - * - * \param[in] info Instance of ::OrtKernelInfo. - * \param[out] out Pointer to variable assigned with the result on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); - - /** \brief Get the name of a ::OrtKernelInfo's input. - * - * Used in the CreateKernel callback of an OrtCustomOp to query an input's name - * during kernel/session creation. 
- * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index The index of the input name to get. Returns a failure status if out-of-bounds. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the input's name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size); - - /** \brief Get the name of a ::OrtKernelInfo's output. - * - * Used in the CreateKernel callback of an OrtCustomOp to query an output's name - * during kernel/session creation. - * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. 
- * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index The index of the output name to get. Returns a failure status if out-of-bounds. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the output's - * name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size); - - /** \brief Get the type information for a ::OrtKernelInfo's input. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information - * of an input during kernel/session creation. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index Which input to get the type information for - * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get the type information for a ::OrtKernelInfo's output. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information - * of an output during kernel/session creation. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index Which input to get the type information for - * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get a ::OrtValue tensor stored as an attribute in the graph node. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a tensor attribute. - * - * \param[in] info ::OrtKernelInfo instance. - * \param[in] name UTF-8 null-terminated string representing the attribute's name. - * \param[in] allocator Allocator used to allocate the internal tensor state. - * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue, - * which will also free internal tensor state allocated with the provided allocator. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, - _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); - - /// @} - /// \name OrtSessionOptions - /// Custom operator APIs - /// @{ - - /** \brief Checks if the given session configuration entry exists. - * - * The config_key formats are defined in onnxruntime_session_options_config_keys.h - * - * Can be used in a custom operator library to check for session configuration entries - * that target one or more custom operators in the library. Example: The config entry - * custom_op.myop.some_key targets a custom op named "myop". - * - * \param[in] options The ::OrtSessionOptions instance. - * \param[in] config_key A null-terminated UTF-8 string representation of the configuration key. - * \param[out] out Pointer set to 1 if the entry exists and 0 otherwise. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(HasSessionConfigEntry, _In_ const OrtSessionOptions* options, - _In_z_ const char* config_key, _Out_ int* out); - - /** \brief Get a session configuration value. 
- * - * Returns a failure status if the configuration key does not exist. - * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h - * - * If `config_value` is nullptr, the value of `size` is set to the true size of the string - * value (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual string value's size, - * the value of `size` is set to the true size of the string value, the provided memory - * is filled with the value's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string value's size and `config_value` - * is not nullptr, the value of `size` is set to the true size of the string value - * and a failure status is returned. - * - * Can be used in a custom operator library to get session configuration entries - * that target one or more custom operators in the library. Example: The config entry - * custom_op.myop.some_key targets a custom op named "myop". - * - * \param[in] options The session options. - * \param[in] config_key A null-terminated UTF-8 string representation of the config key. - * \param[in] config_value Pointer to memory where the null-terminated UTF-8 string value will be stored. - * \param[in,out] size Pointer to the size of the `config_value` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(GetSessionConfigEntry, _In_ const OrtSessionOptions* options, - _In_z_ const char* config_key, _Out_ char* config_value, _Inout_ size_t* size); - - /// @} - - /** \brief Append dnnl provider to session options - * - * If oneDNN is not available, this function will return failure. - * - * \param[in] options - * \param[in] dnnl_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. 
- */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, - _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); - - /** \brief Create an OrtDnnlProviderOptions - * - * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out); - - /** \brief Set options in a oneDNN Execution Provider. - * - * Key should be in null terminated string format of the member of ::OrtDnnlProviderOptions - * and value should be its related range. - * - * For example, key="use_arena" and value="1" - * - * \param[in] dnnl_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized oneDNN provider options string. - * - * For example, "use_arena=1;......" - * - * \param dnnl_options - OrtDnnlProviderOptions instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtDnnlProviderOptions - * - * \since Version 1.15. - */ - void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input); - - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Get the graph node name from ::OrtKernelInfo. - * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * Can be used in a custom operator's CreateKernel callback to get the name of the operator's node name in the graph. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); - - /** \brief Get the session logger from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a logger that can be used to log - * messages. - * - * \param[in] info An instance of ::OrtKernelInfo. 
- * \param[out] logger Pointer set to the session's ::OrtLogger. Owned by ONNX Runtime, so do not free. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. - /// @{ - - /** \brief Get the runtime logger from ::OrtKernelContext. - * - * Used in the KernelCompute callback of an OrtCustomOp to get a logger that can be used to log - * messages during inference. - * - * \param[in] context An instance of ::OrtKernelContext. - * \param[out] logger Pointer set to the kernel context's ::OrtLogger. Owned by ONNX Runtime, so do not free. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger); - - /// @} - /// \name OrtLogger - /// Custom operator APIs. - /// @{ - - /** \brief Logs a message at the given severity level using the provided ::OrtLogger. - * - * Only messages with a severity level equal or greater than the ::OrtLogger's logging severity level - * are logged. Use OrtApi::Logger_GetLoggingSeverityLevel to get the ::OrtLogger's logging severity - * level. - * - * Can be used in custom operators to log messages with the logger retrieved via OrtApi::KernelInfo_GetLogger. - * - * \param[in] logger The ::OrtLogger instance. - * \param[in] log_severity_level The message's severity level. - * \param[in] message The message to log. - * \param[in] file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. - * \param[in] line_number The file line number in which the message is logged. Usually the value of __LINE__. - * \param[in] func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, - _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, - _In_z_ const char* func_name); - - /** \brief Get the logging severity level of the ::OrtLogger. - * - * Can be used in a custom operator to get the logging serverity level of the ::OrtLogger associated with - * the ::OrtKernelInfo. - * - * \param[in] logger The ::OrtLogger instance. - * \param[out] out Pointer to variable assigned with the logging severity level on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out); - - /// @} - - /** \brief Get a ::OrtValue tensor stored as a constant initializer in the graph node. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a tensor value. - * - * \param[in] info ::OrtKernelInfo instance. - * \param[in] index The node index. - * \param[out] is_constant Is it a constant node input or not. - * \param[out] out The OrtValue tensor value. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); - - /** \brief Get Optional Type information from an ::OrtTypeInfo - * - * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. - * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. - * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. - * The actual OrtValues that are supplied in place of optional type inputs should contain - * specific type that is described by ::OrtOptionalTypeInfo. 
- * - * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied - * in place of the optional type when creating the actual ::OrtValue). - * - * \param[in] type_info - * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, - * it is owned by OrtTypeInfo instance. When the type_info does not represent - * optional type, nullptr is returned in out. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); - - /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. - * - * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. - * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. - * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. - * The actual OrtValues that are supplied in place of optional type inputs should contain - * specific type that is described by the returned ::OrtTypeInfo. - * - * \param[in] optional_type_info - * \param[out] out A pointer to the ::OrtTypeInfo for what the optional value could be. - * it is owned by OrtOptionalTypeInfo instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, - _Outptr_ OrtTypeInfo** out); - - /** \brief Set a single string in a string tensor - * Do not zero terminate the string data. 
- * - * \param[in] value A string tensor - * \param[in] index - flat index of the element - * \param[in] length_in_bytes length of the buffer in utf-8 bytes (without the null terminator) - * \param[inout] buffer - address of return value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetResizedStringTensorElementBuffer, _Inout_ OrtValue* value, _In_ size_t index, _In_ size_t length_in_bytes, _Inout_ char** buffer); - - /** \brief Get Allocator from KernelContext for a specific memoryInfo. Please use C API ReleaseAllocator to release out object - * - * \param[in] context OrtKernelContext instance - * \param[in] mem_info OrtMemoryInfo instance - * \param[out] out A pointer to OrtAllocator. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); - - /** \brief Returns a null terminated string of the build info including git info and cxx flags - * - * \return UTF-8 encoded version string. Do not deallocate the returned buffer. - * - * \since Version 1.15. - */ - const char*(ORT_API_CALL* GetBuildInfoString)(void); - - /// \name OrtROCMProviderOptions - /// @{ - - /** \brief Create an OrtROCMProviderOptions - * - * \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); - - /** \brief Set options in a ROCm Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html - * to know the available keys and values. Key should be in null terminated string format of the member of - * ::OrtROCMProviderOptions and value should be its related range. 
- * - * For example, key="device_id" and value="0" - * - * \param[in] rocm_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized ROCm provider options string. - * - * For example, "device_id=0;arena_extend_strategy=0;......" - * - * \param rocm_options - OrtROCMProviderOptions instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtROCMProviderOptions - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - * - * \since Version 1.16. 
- */ - void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input); - - /** \brief Create an allocator with specific type and register it with the ::OrtEnv - * This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator - * Enables sharing the allocator between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * \param[in] env OrtEnv instance - * \param[in] provider_type ExecutionProvider type - * \param[in] mem_info OrtMemoryInfo instance - * \param[in] arena_cfg Arena configuration - * \param[in] provider_options_keys key of the provider options map - * \param[in] provider_options_values value of the provider options map - * \param[in] num_keys Length of the provider options map - */ - ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, - _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - - /** \brief Run the model asynchronously in a thread owned by intra op thread pool - * - * \param[in] session - * \param[in] run_options If nullptr, will use a default ::OrtRunOptions - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input Array of ::OrtValue%s of the input values - * \param[in] input_len Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] output OrtValue* array of size output_names_len. 
- * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. - * Later, the output array will be passed to run_async_callback with all null(s) filled with valid - * OrtValue pointer(s) allocated by onnxruntime. - * NOTE: it is customer's duty to finally release the output array and each of its member, - * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. - * \param[in] run_async_callback Callback function on model run completion - * \param[in] user_data User data that pass back to run_async_callback - */ - ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, - _In_reads_(input_len) const char* const* input_names, - _In_reads_(input_len) const OrtValue* const* input, size_t input_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _Inout_updates_all_(output_names_len) OrtValue** output, - _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); - - /** - * Update TensorRT EP provider option where its data type is pointer, for example 'user_compute_stream'. - * If the data type of the provider option can be represented by string please use UpdateTensorRTProviderOptions. - * - * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. - * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param key - Name of the provider option - * \param value - A pointer to the instance that will be assigned to this provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value); - - /** - * Get TensorRT EP provider option where its data type is pointer. - * If the data type of the provider option can be represented by string please use GetTensorRTProviderOptionsAsString. 
- * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param key - Name of the provider option - * \param ptr - A pointer to the instance that is kept by the provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); - - /** - * Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'. - * If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions. - * - * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param key - Name of the provider option - * \param value - A pointer to the instance that will be assigned to this provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); - - /** - * Get CUDA EP provider option where its data type is pointer. - * If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString. - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param key - Name of the provider option - * \param ptr - A pointer to the instance that is kept by the provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); - - /** - * Get a EP resource. - * E.g. a cuda stream or a cublas handle - * - * \param context - Kernel context - * \param resource_version - Version of the resource - * \param resource_id - Type of resource - * \param resource - A pointer to returned resource - * - * \since Version 1.16. 
- */ - ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, - _In_ int resource_id, _Outptr_ void** resource); - - /** \brief Set user logging function - * - * By default the logger created by the CreateEnv* functions is used to create the session logger as well. - * This function allows a user to override this default session logger with a logger of their own choosing. This way - * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when - * the user already created an env but now wants to use a different logger for a specific session (for debugging or - * other reasons). - * - * \param[in] options - * \param[in] user_logging_function A pointer to a logging function. - * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `user_logging_function`. This parameter is optional. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, - _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); - - /** - * Get number of input from OrtShapeInferContext - * - * \param[in] context - * \param[out] out The number of inputs - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); - - /** - * Get type and shape info of an input - * - * \param[in] context - * \param[in] index The index of the input - * \param[out] info Type shape info of the input - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); - - /** - * Get attribute from OrtShapeInferContext. 
Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node. - * - * \param[in] context - * \param[in] attr_name Name of the attribute - * \param[out] attr Handle of the attribute fetched - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); - - /** - * Set type and shape info of an output - * - * \param[in] context - * \param[in] index The index of the output - * \param[out] info Type shape info of the output - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); - - /** - * Set symbolic shape to type shape info - * - * \param[in] info Type shape info - * \param[in] dim_params Symbolic strings - * \param[in] dim_params_length Number of strings - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); - - /** - * Read contents of an attribute to data - * - * \param[in] op_attr - * \param[in] type Attribute type - * \param[out] data Memory address to save raw content of the attribute - * \param[in] len Number of bytes allowed to store in data - * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); - - /** \brief Set whether to use deterministic compute. - * - * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible. - * Note that this most likely will have a performance cost. 
- * - * \param[in] options - * \param[in] value - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); - - /** - * Run fn in parallel - * - * \param[in] context - * \param[in] fn Function accepting usr_data and an integer as iterator - * \param[in] total The number of times fn is to be invoked - * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit - * \param[in] usr_data User data to be passed back to fn - * - * \since Version 1.17. - */ - ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); - - /** \brief Append OpenVINO execution provider to the session options - * - * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. - * - * \param[in] options - * \param[in] provider_options_keys - * \param[in] provider_options_values - * \param[in] num_keys - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, - _In_ OrtSessionOptions* options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Append VitisAI provider to session options - * - * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. 
- * - * \param[in] options - * \param[in] provider_options_keys - * \param[in] provider_options_values - * \param[in] num_keys - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, - _In_ OrtSessionOptions* options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object. - * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator - * \param[in] context OrtKernelContext instance - * \param[in] mem_info OrtMemoryInfo instance - * \param[in] count_or_bytes How many bytes is this scratch buffer - * \param[out] out A pointer to the scrach buffer - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); - - /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object - * - * \param[in] info OrtKernelInfo instance - * \param[in] mem_type OrtMemType object - * \param[out] out A pointer to OrtAllocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); - - /** \brief Replace initialized Tensors with external data with the provided files in memory - * - * The function will find the initialized TensorProtos with external data in the graph with the provided - * external file names and the file content in memory. The API gets the external file name, offset, data length - * from TensorProto, and locate the tensor data from the file in memory buffer. 
- * It creates a Tensor to replace the existing Tensor in graph. The replacement - * will occur before any of the optimizations take place. The data will be copied into the graph - * since TensorProto can't refer to the user provided buffers. - * - * \param[in] options - * \param[in] external_initializer_file_names Array of null terminated UTF-8 encoded strings of the file names - * which holds the external initializers. - * \param[in] external_initializer_file_buffer_array Array of pointers to the buffer of the file content. - * The buffer can be freed after session creation. - * \param[in] external_initializer_file_lengths Array of size_t to indicate the length of file content - * \param[in] num_external_initializer_files Number of external files - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, - _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, - _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, - _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, - size_t num_external_initializer_files); -}; - -/* - * Steps to use a custom op: - * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops - * 2 Create an OrtCustomOp structure for each op and add them to the domain - * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options - */ - -// Specifies some characteristics of inputs/outputs of custom ops: -// Specify if the inputs/outputs are one of: -// 1) Non-optional (input/output must be present in the node) -// 2) Optional (input/output may be absent in the node) -// 3) Variadic: A variadic input or output specifies N (i.e., the minimum arity) or more operands. -// Only the last input or output of a custom op may be marked as variadic. 
-// The homogeneity of the variadic input or output determines whether all operands must be of the same -// tensor element type. -typedef enum OrtCustomOpInputOutputCharacteristic { - INPUT_OUTPUT_REQUIRED = 0, - INPUT_OUTPUT_OPTIONAL, - INPUT_OUTPUT_VARIADIC, -} OrtCustomOpInputOutputCharacteristic; - -/* - * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by - * the implementor of the custom op. - */ -struct OrtCustomOp { - uint32_t version; // Must be initialized to ORT_API_VERSION - - // This callback creates the kernel, which is a user defined - // parameter that is passed to the Kernel* callbacks below. It is - // recommended to use CreateKernelV2 which allows for a safe error - // propagation by returning an OrtStatusPtr. - void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, - _In_ const OrtKernelInfo* info); - - // Returns the name of the op - const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); - - // Returns the type of the execution provider, return nullptr to use CPU execution provider - const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); - - // Returns the count and types of the input & output tensors - ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); - ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); - - // Perform a computation step. It is recommended to use - // KernelComputeV2 which allows for a safe error propagation by - // returning an OrtStatusPtr. 
- void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); - void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); - - // Returns the characteristics of the input & output tensors - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - - // Returns the memory type of the input tensors. This API allows the custom op - // to place the inputs on specific devices. By default, it returns - // OrtMemTypeDefault, which means the input is placed on the default device for - // the execution provider. If the inputs need to be with different memory tyeps, - // this function can be overridden to return the specific memory types. - OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - - // Returns the minimum number of input arguments expected for the variadic input. - // Applicable only for custom ops that have a variadic input. - int(ORT_API_CALL* GetVariadicInputMinArity)(_In_ const struct OrtCustomOp* op); - - // Returns true (non-zero) if all arguments of a variadic input have to be of the same type (homogeneous), - // and false (zero) otherwise. - // Applicable only for custom ops that have a variadic input. - int(ORT_API_CALL* GetVariadicInputHomogeneity)(_In_ const struct OrtCustomOp* op); - - // Returns the minimum number of output values expected for the variadic output. - // Applicable only for custom ops that have a variadic output. - int(ORT_API_CALL* GetVariadicOutputMinArity)(_In_ const struct OrtCustomOp* op); - - // Returns true (non-zero) if all outputs values of a variadic output have to be of the same type (homogeneous), - // and false (zero) otherwise. - // Applicable only for custom ops that have a variadic output. 
- int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); - - // Create the kernel state which is passed to each compute call. - OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, - _In_ const OrtKernelInfo* info, - _Out_ void** kernel); - - // Perform the computation step. - OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); - - OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); - - // Get start range - int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); - int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); - - // Get the inplace_map that defines which output can reuse which input - // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays - // when return, output (*output_index)[i] may reuse the input (*input_index[i]). - // The return value is the size of these 2 arrays. - // Callers are responsible to delete these 2 arrays after use by calling OrtCustomOp::ReleaseMayInplace(). - size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); - - // Release the pointer input_index and output_index allocated from GetMayInplace() function. - // If GetMayInplace() is defined, this function MUST be defined as well. 
- void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); - - // Same as GetMayInplace() and ReleaseMayInplace() - size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); - void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); -}; - -/* - * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists - * - * \param device_id CUDA device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the ROCm provider to the session, please use - * SessionOptionsAppendExecutionProvider_ROCM above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * HIP support and the ROCm provider shared library exists - * - * \param device_id HIP device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the MIGraphX provider to the session, please use - * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * HIP support and the MIGraphX provider shared library exists - * - * \param device_id HIP device id, starts from zero. 
- */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the oneDNN provider to the session, please use - * SessionOptionsAppendExecutionProvider_oneDNN above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * oneDNN support and the oneDNN provider shared library exists - * - * \param use_arena zero: false. non-zero: true. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena); - -/* - * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists - * - * \param device_id CUDA device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif -/// @} diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_api.h b/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_api.h deleted file mode 100644 index 29a229f427..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_api.h +++ /dev/null @@ -1,2387 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Summary: The Ort C++ API is a header only wrapper around the Ort C API. -// -// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors -// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so -// all the resources follow RAII and do not leak memory. -// -// Each of the C++ wrapper classes holds only a pointer to the C internal object. 
Treat them like smart pointers. -// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them -// until you assign an instance that actually holds an underlying object. -// -// For Ort objects only move assignment between objects is allowed, there are no copy constructors. -// Some objects have explicit 'Clone' methods for this purpose. -// -// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments -// by value or by reference. ConstXXXX types are restricted to const only interfaces. -// -// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces. -// -// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exists so you do not -// have to fallback to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code. - -#pragma once -#include "onnxruntime_c_api.h" -#include "onnxruntime_float16.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef ORT_NO_EXCEPTIONS -#include -#endif - -/** \brief All C++ Onnxruntime APIs are defined inside this namespace - * - */ -namespace Ort { - -/** \brief All C++ methods that can fail will throw an exception of this type - * - * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() - */ -struct Exception : std::exception { - Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} - - OrtErrorCode GetOrtErrorCode() const { return code_; } - const char* what() const noexcept override { return message_.c_str(); } - - private: - std::string message_; - OrtErrorCode code_; -}; - -#ifdef ORT_NO_EXCEPTIONS -// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. 
-// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW -#ifndef ORT_CXX_API_THROW -#define ORT_CXX_API_THROW(string, code) \ - do { \ - std::cerr << Ort::Exception(string, code) \ - .what() \ - << std::endl; \ - abort(); \ - } while (false) -#endif -#else -#define ORT_CXX_API_THROW(string, code) \ - throw Ort::Exception(string, code) -#endif - -// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, -// it's in a template so that we can define a global variable in a header and make -// it transparent to the users of the API. -template -struct Global { - static const OrtApi* api_; -}; - -// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. -template -#ifdef ORT_API_MANUAL_INIT -const OrtApi* Global::api_{}; -inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } - -// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is -// required by C++ APIs. -// -// Example mycustomop.cc: -// -// #define ORT_API_MANUAL_INIT -// #include -// #undef ORT_API_MANUAL_INIT -// -// OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { -// Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); -// // ... -// } -// -inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } -#else -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. -// Please define ORT_API_MANUAL_INIT if it conerns you. 
-#pragma warning(disable : 26426) -#endif -const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif -#endif - -/// This returns a reference to the OrtApi interface in use -inline const OrtApi& GetApi() noexcept { return *Global::api_; } - -/// -/// This function returns the onnxruntime version string -/// -/// version string major.minor.rev -std::string GetVersionString(); - -/// -/// This function returns the onnxruntime build information: including git branch, -/// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. -/// -/// string -std::string GetBuildInfoString(); - -/// -/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and -/// returns a vector of strings representing the available execution providers. -/// -/// vector of strings -std::vector GetAvailableProviders(); - -/** \brief IEEE 754 half-precision floating point data type - * - * \details This struct is used for converting float to float16 and back - * so the user could feed inputs and fetch outputs using these type. - * - * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. - * - * \code{.unparsed} - * // This example demonstrates converion from float to float16 - * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; - * std::vector fp16_values; - * fp16_values.reserve(std::size(values)); - * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), - * [](float value) { return Ort::Float16_t(value); }); - * - * \endcode - */ -struct Float16_t : onnxruntime_float16::Float16Impl { - private: - /// - /// Constructor from a 16-bit representation of a float16 value - /// No conversion is done here. 
- /// - /// 16-bit representation - constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } - - public: - using Base = onnxruntime_float16::Float16Impl; - - /// - /// Default constructor - /// - Float16_t() = default; - - /// - /// Explicit conversion to uint16_t representation of float16. - /// - /// uint16_t bit representation of float16 - /// new instance of Float16_t - constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } - - /// - /// __ctor from float. Float is converted into float16 16-bit representation. - /// - /// float value - explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } - - /// - /// Converts float16 to float - /// - /// float representation of float16 value - float ToFloat() const noexcept { return Base::ToFloatImpl(); } - - /// - /// Checks if the value is negative - /// - /// true if negative - using Base::IsNegative; - - /// - /// Tests if the value is NaN - /// - /// true if NaN - using Base::IsNaN; - - /// - /// Tests if the value is finite - /// - /// true if finite - using Base::IsFinite; - - /// - /// Tests if the value represents positive infinity. - /// - /// true if positive infinity - using Base::IsPositiveInfinity; - - /// - /// Tests if the value represents negative infinity - /// - /// true if negative infinity - using Base::IsNegativeInfinity; - - /// - /// Tests if the value is either positive or negative infinity. - /// - /// True if absolute value is infinity - using Base::IsInfinity; - - /// - /// Tests if the value is NaN or zero. Useful for comparisons. - /// - /// True if NaN or zero. - using Base::IsNaNOrZero; - - /// - /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). - /// - /// True if so - using Base::IsNormal; - - /// - /// Tests if the value is subnormal (denormal). - /// - /// True if so - using Base::IsSubnormal; - - /// - /// Creates an instance that represents absolute value. 
- /// - /// Absolute value - using Base::Abs; - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - using Base::Negate; - - /// - /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check - /// for two values by or'ing the private bits together and stripping the sign. They are both zero, - /// and therefore equivalent, if the resulting value is still zero. - /// - /// first value - /// second value - /// True if both arguments represent zero - using Base::AreZero; - - /// - /// User defined conversion operator. Converts Float16_t to float. - /// - explicit operator float() const noexcept { return ToFloat(); } - - using Base::operator==; - using Base::operator!=; - using Base::operator<; -}; - -static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); - -/** \brief bfloat16 (Brain Floating Point) data type - * - * \details This struct is used for converting float to bfloat16 and back - * so the user could feed inputs and fetch outputs using these type. - * - * The size of the structure should align with uint16_t and one can freely cast - * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. - * - * \code{.unparsed} - * // This example demonstrates converion from float to float16 - * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; - * std::vector bfp16_values; - * bfp16_values.reserve(std::size(values)); - * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), - * [](float value) { return Ort::BFloat16_t(value); }); - * - * \endcode - */ -struct BFloat16_t : onnxruntime_float16::BFloat16Impl { - private: - /// - /// Constructor from a uint16_t representation of bfloat16 - /// used in FromBits() to escape overload resolution issue with - /// constructor from float. - /// No conversion is done. 
- /// - /// 16-bit bfloat16 value - constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } - - public: - using Base = onnxruntime_float16::BFloat16Impl; - - BFloat16_t() = default; - - /// - /// Explicit conversion to uint16_t representation of bfloat16. - /// - /// uint16_t bit representation of bfloat16 - /// new instance of BFloat16_t - static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } - - /// - /// __ctor from float. Float is converted into bfloat16 16-bit representation. - /// - /// float value - explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } - - /// - /// Converts bfloat16 to float - /// - /// float representation of bfloat16 value - float ToFloat() const noexcept { return Base::ToFloatImpl(); } - - /// - /// Checks if the value is negative - /// - /// true if negative - using Base::IsNegative; - - /// - /// Tests if the value is NaN - /// - /// true if NaN - using Base::IsNaN; - - /// - /// Tests if the value is finite - /// - /// true if finite - using Base::IsFinite; - - /// - /// Tests if the value represents positive infinity. - /// - /// true if positive infinity - using Base::IsPositiveInfinity; - - /// - /// Tests if the value represents negative infinity - /// - /// true if negative infinity - using Base::IsNegativeInfinity; - - /// - /// Tests if the value is either positive or negative infinity. - /// - /// True if absolute value is infinity - using Base::IsInfinity; - - /// - /// Tests if the value is NaN or zero. Useful for comparisons. - /// - /// True if NaN or zero. - using Base::IsNaNOrZero; - - /// - /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). - /// - /// True if so - using Base::IsNormal; - - /// - /// Tests if the value is subnormal (denormal). - /// - /// True if so - using Base::IsSubnormal; - - /// - /// Creates an instance that represents absolute value. 
- /// - /// Absolute value - using Base::Abs; - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - using Base::Negate; - - /// - /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check - /// for two values by or'ing the private bits together and stripping the sign. They are both zero, - /// and therefore equivalent, if the resulting value is still zero. - /// - /// first value - /// second value - /// True if both arguments represent zero - using Base::AreZero; - - /// - /// User defined conversion operator. Converts BFloat16_t to float. - /// - explicit operator float() const noexcept { return ToFloat(); } - - // We do not have an inherited impl for the below operators - // as the internal class implements them a little differently - bool operator==(const BFloat16_t& rhs) const noexcept; - bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } - bool operator<(const BFloat16_t& rhs) const noexcept; -}; - -static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); - -/** \brief float8e4m3fn (Float8 Floating Point) data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint8_t. - * See https://onnx.ai/onnx/technical/float8.html for further details. 
- */ -struct Float8E4M3FN_t { - uint8_t value; - constexpr Float8E4M3FN_t() noexcept : value(0) {} - constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {} - constexpr operator uint8_t() const noexcept { return value; } - // nan values are treated like any other value for operator ==, != - constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match"); - -/** \brief float8e4m3fnuz (Float8 Floating Point) data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint8_t. - * See https://onnx.ai/onnx/technical/float8.html for further details. - */ -struct Float8E4M3FNUZ_t { - uint8_t value; - constexpr Float8E4M3FNUZ_t() noexcept : value(0) {} - constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {} - constexpr operator uint8_t() const noexcept { return value; } - // nan values are treated like any other value for operator ==, != - constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match"); - -/** \brief float8e5m2 (Float8 Floating Point) data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint8_t. - * See https://onnx.ai/onnx/technical/float8.html for further details. 
- */ -struct Float8E5M2_t { - uint8_t value; - constexpr Float8E5M2_t() noexcept : value(0) {} - constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {} - constexpr operator uint8_t() const noexcept { return value; } - // nan values are treated like any other value for operator ==, != - constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match"); - -/** \brief float8e5m2fnuz (Float8 Floating Point) data type - * \details It is necessary for type dispatching to make use of C++ API - * The type is implicitly convertible to/from uint8_t. - * See https://onnx.ai/onnx/technical/float8.html for further details. - */ -struct Float8E5M2FNUZ_t { - uint8_t value; - constexpr Float8E5M2FNUZ_t() noexcept : value(0) {} - constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {} - constexpr operator uint8_t() const noexcept { return value; } - // nan values are treated like any other value for operator ==, != - constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }; - constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }; -}; - -static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match"); - -namespace detail { -// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type -// This can't be done in the C API since C doesn't have function overloading. 
-#define ORT_DEFINE_RELEASE(NAME) \ - inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } - -ORT_DEFINE_RELEASE(Allocator); -ORT_DEFINE_RELEASE(MemoryInfo); -ORT_DEFINE_RELEASE(CustomOpDomain); -ORT_DEFINE_RELEASE(ThreadingOptions); -ORT_DEFINE_RELEASE(Env); -ORT_DEFINE_RELEASE(RunOptions); -ORT_DEFINE_RELEASE(Session); -ORT_DEFINE_RELEASE(SessionOptions); -ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); -ORT_DEFINE_RELEASE(SequenceTypeInfo); -ORT_DEFINE_RELEASE(MapTypeInfo); -ORT_DEFINE_RELEASE(TypeInfo); -ORT_DEFINE_RELEASE(Value); -ORT_DEFINE_RELEASE(ModelMetadata); -ORT_DEFINE_RELEASE(IoBinding); -ORT_DEFINE_RELEASE(ArenaCfg); -ORT_DEFINE_RELEASE(Status); -ORT_DEFINE_RELEASE(OpAttr); -ORT_DEFINE_RELEASE(Op); -ORT_DEFINE_RELEASE(KernelInfo); - -#undef ORT_DEFINE_RELEASE - -/** \brief This is a tagging template type. Use it with Base to indicate that the C++ interface object - * has no ownership of the underlying C object. - */ -template -struct Unowned { - using Type = T; -}; - -/** \brief Used internally by the C++ API. C++ wrapper types inherit from this. - * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. - * - * All of the C++ classes - * a) serve as containers for pointers to objects that are created by the underlying C API. - * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. - * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. - * they would release objects owned automatically when going out of scope, they are move-only. - * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. - * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else - * such as Onnxruntime or instances of XXXX classes. 
- * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used - * in C++ code. - * - */ - -/// -/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction. -/// -template -struct Base { - using contained_type = T; - - constexpr Base() = default; - constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - ~Base() { OrtRelease(p_); } - - Base(const Base&) = delete; - Base& operator=(const Base&) = delete; - - Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } - Base& operator=(Base&& v) noexcept { - OrtRelease(p_); - p_ = v.release(); - return *this; - } - - constexpr operator contained_type*() const noexcept { return p_; } - - /// \brief Relinquishes ownership of the contained C object pointer - /// The underlying object is not destroyed - contained_type* release() { - T* p = p_; - p_ = nullptr; - return p; - } - - protected: - contained_type* p_{}; -}; - -// Undefined. For const types use Base> -template -struct Base; - -/// -/// Covers unowned pointers owned by either the ORT -/// or some other instance of CPP wrappers. -/// Used for ConstXXX and UnownedXXXX types that are copyable. -/// Also convenient to wrap raw OrtXX pointers . 
-/// -/// -template -struct Base> { - using contained_type = typename Unowned::Type; - - constexpr Base() = default; - constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - - ~Base() = default; - - Base(const Base&) = default; - Base& operator=(const Base&) = default; - - Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } - Base& operator=(Base&& v) noexcept { - p_ = nullptr; - std::swap(p_, v.p_); - return *this; - } - - constexpr operator contained_type*() const noexcept { return p_; } - - protected: - contained_type* p_{}; -}; - -// Light functor to release memory with OrtAllocator -struct AllocatedFree { - OrtAllocator* allocator_; - explicit AllocatedFree(OrtAllocator* allocator) - : allocator_(allocator) {} - void operator()(void* ptr) const { - if (ptr) allocator_->Free(allocator_, ptr); - } -}; - -} // namespace detail - -struct AllocatorWithDefaultOptions; -struct Env; -struct TypeInfo; -struct Value; -struct ModelMetadata; - -/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators - * and release them at the end of the scope. The lifespan of the given allocator - * must eclipse the lifespan of AllocatedStringPtr instance - */ -using AllocatedStringPtr = std::unique_ptr; - -/** \brief The Status that holds ownership of OrtStatus received from C API - * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate - * constructors to construct an instance of a Status object from exceptions. - */ -struct Status : detail::Base { - explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used - explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. 
- explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception - explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception - Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. - std::string GetErrorMessage() const; - OrtErrorCode GetErrorCode() const; - bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. -}; - -/** \brief The ThreadingOptions - * - * The ThreadingOptions used for set global threadpools' options of The Env. - */ -struct ThreadingOptions : detail::Base { - /// \brief Wraps OrtApi::CreateThreadingOptions - ThreadingOptions(); - - /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads - ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); - - /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads - ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); - - /// \brief Wraps OrtApi::SetGlobalSpinControl - ThreadingOptions& SetGlobalSpinControl(int allow_spinning); - - /// \brief Wraps OrtApi::SetGlobalDenormalAsZero - ThreadingOptions& SetGlobalDenormalAsZero(); - - /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn - ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); - - /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions - ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); - - /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn - ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); -}; - -/** \brief The Env (Environment) - * - * The Env holds the logging state used by all other objects. 
- * Note: One Env must be created before using any other Onnxruntime functionality - */ -struct Env : detail::Base { - explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used - - /// \brief Wraps OrtApi::CreateEnv - Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - - /// \brief Wraps OrtApi::CreateEnvWithCustomLogger - Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); - - /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools - Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - - /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools - Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, - OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); - - /// \brief C Interop Helper - explicit Env(OrtEnv* p) : Base{p} {} - - Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents - Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents - - Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel - - Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator - - Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 -}; - -/** \brief Custom Op Domain - * - */ -struct CustomOpDomain : detail::Base { - explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used - - /// \brief Wraps OrtApi::CreateCustomOpDomain - explicit CustomOpDomain(const char* 
domain); - - // This does not take ownership of the op, simply registers it. - void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add -}; - -/** \brief RunOptions - * - */ -struct RunOptions : detail::Base { - explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used - RunOptions(); ///< Wraps OrtApi::CreateRunOptions - - RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel - int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel - - RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel - int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel - - RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag - const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag - - RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry - - /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance - * - * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error - * Wraps OrtApi::RunOptionsSetTerminate - */ - RunOptions& SetTerminate(); - - /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating - * - * Wraps OrtApi::RunOptionsUnsetTerminate - */ - RunOptions& UnsetTerminate(); -}; - -namespace detail { -// Utility function that returns a SessionOption config entry key for a specific custom operator. -// Ex: custom_op.[custom_op_name].[config] -std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); -} // namespace detail - -/// -/// Class that represents session configuration entries for one or more custom operators. 
-/// -/// Example: -/// Ort::CustomOpConfigs op_configs; -/// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); -/// -/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. -/// -struct CustomOpConfigs { - CustomOpConfigs() = default; - ~CustomOpConfigs() = default; - CustomOpConfigs(const CustomOpConfigs&) = default; - CustomOpConfigs& operator=(const CustomOpConfigs&) = default; - CustomOpConfigs(CustomOpConfigs&& o) = default; - CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; - - /** \brief Adds a session configuration entry/value for a specific custom operator. - * - * \param custom_op_name The name of the custom operator for which to add a configuration entry. - * Must match the name returned by the CustomOp's GetName() method. - * \param config_key The name of the configuration entry. - * \param config_value The value of the configuration entry. - * \return A reference to this object to enable call chaining. - */ - CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); - - /** \brief Returns a flattened map of custom operator configuration entries and their values. - * - * The keys has been flattened to include both the custom operator name and the configuration entry key name. - * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair - * {"my_op.key", "value"}. - * - * \return An unordered map of flattened configurations. 
- */ - const std::unordered_map& GetFlattenedConfigs() const; - - private: - std::unordered_map flat_configs_; -}; - -/** \brief Options object used when creating a new Session object - * - * Wraps ::OrtSessionOptions object and methods - */ - -struct SessionOptions; - -namespace detail { -// we separate const-only methods because passing const ptr to non-const methods -// is only discovered when inline methods are compiled which is counter-intuitive -template -struct ConstSessionOptionsImpl : Base { - using B = Base; - using B::B; - - SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions - - std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry - bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry - std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); -}; - -template -struct SessionOptionsImpl : ConstSessionOptionsImpl { - using B = ConstSessionOptionsImpl; - using B::B; - - SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads - SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads - SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel - SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute - - SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena - SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena - - SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath - - SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling - SessionOptionsImpl& 
DisableProfiling(); ///< Wraps OrtApi::DisableProfiling - - SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps - - SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern - SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern - - SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode - - SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId - SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel - - SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain - - SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads - - SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry - - SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer - SessionOptionsImpl& AddExternalInitializers(const std::vector& names, const std::vector& ort_values); ///< Wraps OrtApi::AddExternalInitializers - SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector>& external_initializer_file_names, - const std::vector& external_initializer_file_buffer_array, - const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory - - SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA - SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 - SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM - 
SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 - SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options = {}); - SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT - SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT - SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN - SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl - SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); - /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. 
- SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, - const std::unordered_map& provider_options = {}); - - SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn - SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions - SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn - - ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. - ///< The custom operator configurations are optional. If provided, custom operator configs are set via - ///< OrtApi::AddSessionConfigEntry. - SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); - - SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction - - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI - SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); -}; -} // namespace detail - -using UnownedSessionOptions = detail::SessionOptionsImpl>; -using ConstSessionOptions = detail::ConstSessionOptionsImpl>; - -/** \brief Wrapper around ::OrtSessionOptions - * - */ -struct SessionOptions : detail::SessionOptionsImpl { - explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used - SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions - explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl{p} {} ///< Used for interop with the C API - UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } - ConstSessionOptions GetConst() const { return 
ConstSessionOptions{this->p_}; } -}; - -/** \brief Wrapper around ::OrtModelMetadata - * - */ -struct ModelMetadata : detail::Base { - explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used - explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API - - /** \brief Returns a copy of the producer name. - * - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName - - /** \brief Returns a copy of the graph name. - * - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName - - /** \brief Returns a copy of the domain name. - * - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain - - /** \brief Returns a copy of the description. - * - * \param allocator to allocate memory for the copy of the string returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. 
- */ - AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription - - /** \brief Returns a copy of the graph description. - * - * \param allocator to allocate memory for the copy of the string returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription - - /** \brief Returns a vector of copies of the custom metadata keys. - * - * \param allocator to allocate memory for the copy of the string returned - * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope. - * The OrtAllocator instance must be valid at the point of memory release. - */ - std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys - - /** \brief Looks up a value by a key in the Custom Metadata map - * - * \param key zero terminated string key to lookup - * \param allocator to allocate memory for the copy of the string returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * maybe nullptr if key is not found. - * - * The OrtAllocator instances must be valid at the point of memory release. 
- */ - AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap - - int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion -}; - -struct IoBinding; - -namespace detail { - -// we separate const-only methods because passing const ptr to non-const methods -// is only discovered when inline methods are compiled which is counter-intuitive -template -struct ConstSessionImpl : Base { - using B = Base; - using B::B; - - size_t GetInputCount() const; ///< Returns the number of model inputs - size_t GetOutputCount() const; ///< Returns the number of model outputs - size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden - - /** \brief Returns a copy of input name at the specified index. - * - * \param index must less than the value returned by GetInputCount() - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; - - /** \brief Returns a copy of output name at then specified index. - * - * \param index must less than the value returned by GetOutputCount() - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; - - /** \brief Returns a copy of the overridable initializer name at then specified index. 
- * - * \param index must less than the value returned by GetOverridableInitializerCount() - * \param allocator to allocate memory for the copy of the name returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. - */ - AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName - - uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs - ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata - - TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo - TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo - TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo -}; - -template -struct SessionImpl : ConstSessionImpl { - using B = ConstSessionImpl; - using B::B; - - /** \brief Run the model returning results in an Ort allocated vector. - * - * Wraps OrtApi::Run - * - * The caller provides a list of inputs and a list of the desired outputs to return. - * - * See the output logs for more information on warnings/errors that occur while processing the model. - * Common errors are.. 
(TODO) - * - * \param[in] run_options - * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names - * \param[in] input_values Array of Value objects of length input_count that is the list of input values - * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) - * \param[in] output_names Array of C style strings of length output_count that is the list of output names - * \param[in] output_count Number of outputs (the size of the output_names array) - * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) - */ - std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_count); - - /** \brief Run the model returning results in user provided outputs - * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) - */ - void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count); - - void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding - - /** \brief Run the model asynchronously in a thread owned by intra op thread pool - * - * Wraps OrtApi::RunAsync - * - * \param[in] run_options - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input_values Array of Value objects of length input_count - * \param[in] input_count Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[out] output_values Array of provided Values to be filled with outputs. 
- * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. - * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. - * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. - * NOTE: it is customer's duty to finally release output_values and each of its member, - * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. - * \param[in] output_count Number of elements in the output_names and outputs array - * \param[in] callback Callback function on model run completion - * \param[in] user_data User data that pass back to the callback - */ - void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); - - /** \brief End profiling and return a copy of the profiling file name. - * - * \param allocator to allocate memory for the copy of the string returned - * \return a instance of smart pointer that would deallocate the buffer when out of scope. - * The OrtAllocator instances must be valid at the point of memory release. 
- */ - AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling -}; - -} // namespace detail - -using ConstSession = detail::ConstSessionImpl>; -using UnownedSession = detail::SessionImpl>; - -/** \brief Wrapper around ::OrtSession - * - */ -struct Session : detail::SessionImpl { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer - - ConstSession GetConst() const { return ConstSession{this->p_}; } - UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } -}; - -namespace detail { -template -struct MemoryInfoImpl : Base { - using B = Base; - using B::B; - - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemoryInfoDeviceType GetDeviceType() const; - OrtMemType GetMemoryType() const; - - template - bool operator==(const MemoryInfoImpl& o) const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstMemoryInfo = detail::MemoryInfoImpl>; - -/** \brief Wrapper around ::OrtMemoryInfo - * - */ -struct MemoryInfo : detail::MemoryInfoImpl { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - explicit 
MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); - ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } -}; - -namespace detail { -template -struct TensorTypeAndShapeInfoImpl : Base { - using B = Base; - using B::B; - - ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType - size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount - - size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount - - /** \deprecated use GetShape() returning std::vector - * [[deprecated]] - * This interface is unsafe to use - */ - [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions - - void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions - - std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape -}; - -} // namespace detail - -using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl>; - -/** \brief Wrapper around ::OrtTensorTypeAndShapeInfo - * - */ -struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl { - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API - ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } -}; - -namespace detail { -template -struct SequenceTypeInfoImpl : Base { - using B = Base; - using B::B; - TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType -}; - -} // 
namespace detail - -using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl>; - -/** \brief Wrapper around ::OrtSequenceTypeInfo - * - */ -struct SequenceTypeInfo : detail::SequenceTypeInfoImpl { - explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used - explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API - ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } -}; - -namespace detail { -template -struct OptionalTypeInfoImpl : Base { - using B = Base; - using B::B; - TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo -}; - -} // namespace detail - -// This is always owned by the TypeInfo and can only be obtained from it. -using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; - -namespace detail { -template -struct MapTypeInfoImpl : detail::Base { - using B = Base; - using B::B; - ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType - TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType -}; - -} // namespace detail - -using ConstMapTypeInfo = detail::MapTypeInfoImpl>; - -/** \brief Wrapper around ::OrtMapTypeInfo - * - */ -struct MapTypeInfo : detail::MapTypeInfoImpl { - explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used - explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API - ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } -}; - -namespace detail { -template -struct TypeInfoImpl : detail::Base { - using B = Base; - using B::B; - - ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo - ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo - ConstMapTypeInfo 
GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo - ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo - - ONNXType GetONNXType() const; -}; -} // namespace detail - -/// -/// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. -/// Provides access to const OrtTypeInfo APIs. -/// -using ConstTypeInfo = detail::TypeInfoImpl>; - -/// -/// Type information that may contain either TensorTypeAndShapeInfo or -/// the information about contained sequence or map depending on the ONNXType. -/// -struct TypeInfo : detail::TypeInfoImpl { - explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used - explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop - - ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } -}; - -namespace detail { -// This structure is used to feed sparse tensor values -// information for use with FillSparseTensor() API -// if the data type for the sparse tensor values is numeric -// use data.p_data, otherwise, use data.str pointer to feed -// values. data.str is an array of const char* that are zero terminated. -// number of strings in the array must match shape size. -// For fully sparse tensors use shape {0} and set p_data/str -// to nullptr. 
-struct OrtSparseValuesParam { - const int64_t* values_shape; - size_t values_shape_len; - union { - const void* p_data; - const char** str; - } data; -}; - -// Provides a way to pass shape in a single -// argument -struct Shape { - const int64_t* shape; - size_t shape_len; -}; - -template -struct ConstValueImpl : Base { - using B = Base; - using B::B; - - /// - /// Obtains a pointer to a user defined data for experimental purposes - /// - template - void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue - - bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc - bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None - - size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements - Value GetValue(int index, OrtAllocator* allocator) const; - - /// - /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. - /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful - /// for allocating necessary memory and calling GetStringTensorContent(). - /// - /// total length of UTF-8 encoded bytes contained. No zero terminators counted. - size_t GetStringTensorDataLength() const; - - /// - /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor - /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. - /// The user must also allocate offsets buffer with the number of entries equal to that of the contained - /// strings. - /// - /// Strings are always assumed to be on CPU, no X-device copy. 
- /// - /// user allocated buffer - /// length in bytes of the allocated buffer - /// a pointer to the offsets user allocated buffer - /// count of offsets, must be equal to the number of strings contained. - /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() - /// for sparse tensors - void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; - - /// - /// Returns a const typed pointer to the tensor contained data. - /// No type checking is performed, the caller must ensure the type matches the tensor type. - /// - /// - /// const pointer to data, no copies made - template - const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// - - /// - /// Returns a non-typed pointer to a tensor contained data. - /// - /// const pointer to data, no copies made - const void* GetTensorRawData() const; - - /// - /// The API returns type information for data contained in a tensor. For sparse - /// tensors it returns type information for contained non-zero values. - /// It returns dense shape for sparse tensors. - /// - /// TypeInfo - TypeInfo GetTypeInfo() const; - - /// - /// The API returns type information for data contained in a tensor. For sparse - /// tensors it returns type information for contained non-zero values. - /// It returns dense shape for sparse tensors. - /// - /// TensorTypeAndShapeInfo - TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; - - /// - /// This API returns information about the memory allocation used to hold data. - /// - /// Non owning instance of MemoryInfo - ConstMemoryInfo GetTensorMemoryInfo() const; - - /// - /// The API copies UTF-8 encoded bytes for the requested string element - /// contained within a tensor or a sparse tensor into a provided buffer. - /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. 
- /// - /// - /// - /// - void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; - - /// - /// Returns string tensor UTF-8 encoded string element. - /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer. - /// - /// - /// std::string - std::string GetStringTensorElement(size_t element_index) const; - - /// - /// The API returns a byte length of UTF-8 encoded string element - /// contained in either a tensor or a spare tensor values. - /// - /// - /// byte length for the specified string element - size_t GetStringTensorElementLength(size_t element_index) const; - -#if !defined(DISABLE_SPARSE_TENSORS) - /// - /// The API returns the sparse data format this OrtValue holds in a sparse tensor. - /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used - /// the value returned is ORT_SPARSE_UNDEFINED. - /// - /// Format enum - OrtSparseFormat GetSparseFormat() const; - - /// - /// The API returns type and shape information for stored non-zero values of the - /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. - /// - /// TensorTypeAndShapeInfo values information - TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; - - /// - /// The API returns type and shape information for the specified indices. Each supported - /// indices have their own enum values even if a give format has more than one kind of indices. - /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. - /// - /// enum requested - /// type and shape information - TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; - - /// - /// The API retrieves a pointer to the internal indices buffer. The API merely performs - /// a convenience data type casting on the return type pointer. 
Make sure you are requesting - /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); - /// - /// type to cast to - /// requested indices kind - /// number of indices entries - /// Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer. - template - const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; - - /// - /// Returns true if the OrtValue contains a sparse tensor - /// - /// - bool IsSparseTensor() const; - - /// - /// The API returns a pointer to an internal buffer of the sparse tensor - /// containing non-zero values. The API merely does casting. Make sure you - /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() - /// first. - /// - /// numeric data types only. Use GetStringTensor*() to retrieve strings. - /// a pointer to the internal values buffer. Do not free this pointer. - template - const R* GetSparseTensorValues() const; - -#endif -}; - -template -struct ValueImpl : ConstValueImpl { - using B = ConstValueImpl; - using B::B; - - /// - /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer - /// No type checking is performed, the caller must ensure the type matches the tensor type. - /// - /// non-const pointer to data, no copies made - template - R* GetTensorMutableData(); - - /// - /// Returns a non-typed non-const pointer to a tensor contained data. - /// - /// pointer to data, no copies made - void* GetTensorMutableRawData(); - - /// - // Obtain a reference to an element of data at the location specified - /// by the vector of dims. - /// - /// - /// [in] expressed by a vecotr of dimensions offsets - /// - template - R& At(const std::vector& location); - - /// - /// Set all strings at once in a string tensor - /// - /// [in] An array of strings. Each string in this array must be null terminated. 
- /// [in] Count of strings in s (Must match the size of \p value's tensor shape) - void FillStringTensor(const char* const* s, size_t s_len); - - /// - /// Set a single string in a string tensor - /// - /// [in] A null terminated UTF-8 encoded string - /// [in] Index of the string in the tensor to set - void FillStringTensorElement(const char* s, size_t index); - - /// - /// Allocate if necessary and obtain a pointer to a UTF-8 - /// encoded string element buffer indexed by the flat element index, - /// of the specified length. - /// - /// This API is for advanced usage. It avoids a need to construct - /// an auxiliary array of string pointers, and allows to write data directly - /// (do not zero terminate). - /// - /// - /// - /// a pointer to a writable buffer - char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length); - -#if !defined(DISABLE_SPARSE_TENSORS) - /// - /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. - /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user - /// allocated buffers lifespan must eclipse that of the OrtValue. - /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. - /// - /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors. - /// number of indices entries. Use 0 for fully sparse tensors - void UseCooIndices(int64_t* indices_data, size_t indices_num); - - /// - /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. - /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user - /// allocated buffers lifespan must eclipse that of the OrtValue. - /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. 
- /// - /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors - /// number of csr inner indices or 0 for fully sparse tensors - /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors - /// number of csr outer indices or 0 for fully sparse tensors - void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); - - /// - /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. - /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user - /// allocated buffers lifespan must eclipse that of the OrtValue. - /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. - /// - /// indices shape or a {0} for fully sparse - /// user allocated buffer with indices or nullptr for fully spare tensors - void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); - - /// - /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API - /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located - /// at difference device than the allocator, a X-device copy will be performed if possible. - /// - /// specified buffer memory description - /// values buffer information. - /// coo indices buffer or nullptr for fully sparse data - /// number of COO indices or 0 for fully sparse data - void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, - const int64_t* indices_data, size_t indices_num); - - /// - /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API - /// and copy the values and CSR indices into it. 
If data_mem_info specifies that the data is located - /// at difference device than the allocator, a X-device copy will be performed if possible. - /// - /// specified buffer memory description - /// values buffer information - /// csr inner indices pointer or nullptr for fully sparse tensors - /// number of csr inner indices or 0 for fully sparse tensors - /// pointer to csr indices data or nullptr for fully sparse tensors - /// number of csr outer indices or 0 - void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const int64_t* inner_indices_data, size_t inner_indices_num, - const int64_t* outer_indices_data, size_t outer_indices_num); - - /// - /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API - /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located - /// at difference device than the allocator, a X-device copy will be performed if possible. - /// - /// specified buffer memory description - /// values buffer information - /// indices shape. 
use {0} for fully sparse tensors - /// pointer to indices data or nullptr for fully sparse tensors - void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const Shape& indices_shape, - const int32_t* indices_data); - -#endif -}; - -} // namespace detail - -using ConstValue = detail::ConstValueImpl>; -using UnownedValue = detail::ValueImpl>; - -/** \brief Wrapper around ::OrtValue - * - */ -struct Value : detail::ValueImpl { - using Base = detail::ValueImpl; - using OrtSparseValuesParam = detail::OrtSparseValuesParam; - using Shape = detail::Shape; - - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used - explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API - Value(Value&&) = default; - Value& operator=(Value&&) = default; - - ConstValue GetConst() const { return ConstValue{this->p_}; } - UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } - - /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. - * \tparam T The numeric datatype. This API is not suitable for strings. - * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). - * \param p_data Pointer to the data buffer. - * \param p_data_element_count The number of elements in the data buffer. - * \param shape Pointer to the tensor shape dimensions. - * \param shape_len The number of tensor shape dimensions. - */ - template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); - - /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. - * - * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). - * \param p_data Pointer to the data buffer. - * \param p_data_byte_count The number of bytes in the data buffer. 
- * \param shape Pointer to the tensor shape dimensions. - * \param shape_len The number of tensor shape dimensions. - * \param type The data type. - */ - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type); - - /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. - * This overload will allocate the buffer for the tensor according to the supplied shape and data type. - * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. - * The input data would need to be copied into the allocated buffer. - * This API is not suitable for strings. - * - * \tparam T The numeric datatype. This API is not suitable for strings. - * \param allocator The allocator to use. - * \param shape Pointer to the tensor shape dimensions. - * \param shape_len The number of tensor shape dimensions. - */ - template - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); - - /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator. - * Wraps OrtApi::CreateTensorAsOrtValue. - * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. - * The input data would need to be copied into the allocated buffer. - * This API is not suitable for strings. - * - * \param allocator The allocator to use. - * \param shape Pointer to the tensor shape dimensions. - * \param shape_len The number of tensor shape dimensions. - * \param type The data type. - */ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); - - /** \brief Creates an OrtValue with a Map Onnx type representation. - * The API would ref-count the supplied OrtValues and they will be released - * when the returned OrtValue is released. 
The caller may release keys and values after the call - * returns. - * - * \param keys an OrtValue containing a tensor with primitive data type keys. - * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values. - */ - static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue - - /** \brief Creates an OrtValue with a Sequence Onnx type representation. - * The API would ref-count the supplied OrtValues and they will be released - * when the returned OrtValue is released. The caller may release the values after the call - * returns. - * - * \param values a vector of OrtValues that must have the same Onnx value type. - */ - static Value CreateSequence(const std::vector& values); ///< Wraps OrtApi::CreateValue - - /** \brief Creates an OrtValue wrapping an Opaque type. - * This is used for experimental support of non-tensor types. - * - * \tparam T - the type of the value. - * \param domain - zero terminated utf-8 string. Domain of the type. - * \param type_name - zero terminated utf-8 string. Name of the type. - * \param value - the value to be wrapped. - */ - template - static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue - -#if !defined(DISABLE_SPARSE_TENSORS) - /// - /// This is a simple forwarding method to the other overload that helps deducing - /// data type enum value from the type of the buffer. - /// - /// numeric datatype. This API is not suitable for strings. - /// Memory description where the user buffers reside (CPU vs GPU etc) - /// pointer to the user supplied buffer, use nullptr for fully sparse tensors - /// a would be dense shape of the tensor - /// non zero values shape. Use a single 0 shape for fully sparse tensors. 
- /// - template - static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, - const Shape& values_shape); - - /// - /// Creates an OrtValue instance containing SparseTensor. This constructs - /// a sparse tensor that makes use of user allocated buffers. It does not make copies - /// of the user provided data and does not modify it. The lifespan of user provided buffers should - /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain - /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below - /// to supply a sparse format specific indices. - /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings - /// can be properly copied into the allocated buffer. - /// - /// Memory description where the user buffers reside (CPU vs GPU etc) - /// pointer to the user supplied buffer, use nullptr for fully sparse tensors - /// a would be dense shape of the tensor - /// non zero values shape. Use a single 0 shape for fully sparse tensors. - /// data type - /// Ort::Value instance containing SparseTensor - static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, - const Shape& values_shape, ONNXTensorElementDataType type); - - /// - /// This is a simple forwarding method to the below CreateSparseTensor. - /// This helps to specify data type enum in terms of C++ data type. - /// Use CreateSparseTensor - /// - /// numeric data type only. String data enum must be specified explicitly. - /// allocator to use - /// a would be dense shape of the tensor - /// Ort::Value - template - static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); - - /// - /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. 
- /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values - /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. - /// Use this API to create OrtValues that contain sparse tensors with all supported data types including - /// strings. - /// - /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue - /// a would be dense shape of the tensor - /// data type - /// an instance of Ort::Value - static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); - -#endif // !defined(DISABLE_SPARSE_TENSORS) -}; - -/// -/// Represents native memory allocation coming from one of the -/// OrtAllocators registered with OnnxRuntime. -/// Use it to wrap an allocation made by an allocator -/// so it can be automatically released when no longer needed. -/// -struct MemoryAllocation { - MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); - ~MemoryAllocation(); - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - MemoryAllocation(MemoryAllocation&&) noexcept; - MemoryAllocation& operator=(MemoryAllocation&&) noexcept; - - void* get() { return p_; } - size_t size() const { return size_; } - - private: - OrtAllocator* allocator_; - void* p_; - size_t size_; -}; - -namespace detail { -template -struct AllocatorImpl : Base { - using B = Base; - using B::B; - - void* Alloc(size_t size); - MemoryAllocation GetAllocation(size_t size); - void Free(void* p); - ConstMemoryInfo GetInfo() const; -}; - -} // namespace detail - -/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime - * - */ -struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { - explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - 
AllocatorWithDefaultOptions(); -}; - -/** \brief Wrapper around ::OrtAllocator - * - */ -struct Allocator : detail::AllocatorImpl { - explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - Allocator(const Session& session, const OrtMemoryInfo*); -}; - -using UnownedAllocator = detail::AllocatorImpl>; - -namespace detail { -namespace binding_utils { -// Bring these out of template -std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); -std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); -} // namespace binding_utils - -template -struct ConstIoBindingImpl : Base { - using B = Base; - using B::B; - - std::vector GetOutputNames() const; - std::vector GetOutputNames(OrtAllocator*) const; - std::vector GetOutputValues() const; - std::vector GetOutputValues(OrtAllocator*) const; -}; - -template -struct IoBindingImpl : ConstIoBindingImpl { - using B = ConstIoBindingImpl; - using B::B; - - void BindInput(const char* name, const Value&); - void BindOutput(const char* name, const Value&); - void BindOutput(const char* name, const OrtMemoryInfo*); - void ClearBoundInputs(); - void ClearBoundOutputs(); - void SynchronizeInputs(); - void SynchronizeOutputs(); -}; - -} // namespace detail - -using ConstIoBinding = detail::ConstIoBindingImpl>; -using UnownedIoBinding = detail::IoBindingImpl>; - -/** \brief Wrapper around ::OrtIoBinding - * - */ -struct IoBinding : detail::IoBindingImpl { - explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. - explicit IoBinding(Session& session); - ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } - UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } -}; - -/*! 
\struct Ort::ArenaCfg - * \brief it is a structure that represents the configuration of an arena based allocator - * \details Please see docs/C_API.md for details - */ -struct ArenaCfg : detail::Base { - explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used - /** - * Wraps OrtApi::CreateArenaCfg - * \param max_mem - use 0 to allow ORT to choose the default - * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default - * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default - * See docs/C_API.md for details on what the following parameters mean and how to choose these values - */ - ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); -}; - -// -// Custom OPs (only needed to implement custom OPs) -// - -/// -/// This struct provides life time management for custom op attribute -/// -struct OpAttr : detail::Base { - OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); -}; - -/** - * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails. - * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); - * - * \param logger The Ort::Logger instance to use. Must be a value or reference. - * \param message_severity The logging severity level of the message. - * \param message A null-terminated UTF-8 message to log. - */ -#define ORT_CXX_LOG(logger, message_severity, message) \ - do { \ - if (message_severity >= logger.GetLoggingSeverityLevel()) { \ - Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ - static_cast(__FUNCTION__), message)); \ - } \ - } while (false) - -/** - * Macro that logs a message using the provided logger. 
Can be used in noexcept code since errors are silently ignored. - * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); - * - * \param logger The Ort::Logger instance to use. Must be a value or reference. - * \param message_severity The logging severity level of the message. - * \param message A null-terminated UTF-8 message to log. - */ -#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \ - do { \ - if (message_severity >= logger.GetLoggingSeverityLevel()) { \ - static_cast(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ - static_cast(__FUNCTION__), message)); \ - } \ - } while (false) - -/** - * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if - * OrtApi::Logger_LogMessage fails or if a formatting error occurs. - * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); - * - * \param logger The Ort::Logger instance to use. Must be a value or reference. - * \param message_severity The logging severity level of the message. - * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. - * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. - * \param ... Zero or more variadic arguments referenced by the format string. - */ -#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \ - do { \ - if (message_severity >= logger.GetLoggingSeverityLevel()) { \ - Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ - static_cast(__FUNCTION__), __VA_ARGS__)); \ - } \ - } while (false) - -/** - * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors - * are silently ignored. - * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); - * - * \param logger The Ort::Logger instance to use. Must be a value or reference. 
- * \param message_severity The logging severity level of the message. - * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. - * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. - * \param ... Zero or more variadic arguments referenced by the format string. - */ -#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \ - do { \ - if (message_severity >= logger.GetLoggingSeverityLevel()) { \ - static_cast(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ - static_cast(__FUNCTION__), __VA_ARGS__)); \ - } \ - } while (false) - -/// -/// This class represents an ONNX Runtime logger that can be used to log information with an -/// associated severity level and source code location (file path, line number, function name). -/// -/// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger(). -/// Instances of Ort::Logger are the size of two pointers and can be passed by value. -/// -/// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite -/// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API. -/// -struct Logger { - /** - * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. - */ - Logger() = default; - - /** - * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. - */ - explicit Logger(std::nullptr_t) {} - - /** - * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling - * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails. - * - * \param logger The ::OrtLogger to wrap. 
- */ - explicit Logger(const OrtLogger* logger); - - ~Logger() = default; - - Logger(const Logger&) = default; - Logger& operator=(const Logger&) = default; - - Logger(Logger&& v) noexcept = default; - Logger& operator=(Logger&& v) noexcept = default; - - /** - * Returns the logger's current severity level from the cached member. - * - * \return The current ::OrtLoggingLevel. - */ - OrtLoggingLevel GetLoggingSeverityLevel() const noexcept; - - /** - * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT - * macros to properly set the source code location and to use the cached severity level to potentially bypass - * calls to the underlying C API. - * - * \param log_severity_level The message's logging severity level. - * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. - * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. - * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. - * \param message The message to log. - * \return A Ort::Status value to indicate error or success. - */ - Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, - const char* func_name, const char* message) const noexcept; - - /** - * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT - * macros to properly set the source code location and to use the cached severity level to potentially bypass - * calls to the underlying C API. Returns an error status if a formatting error occurs. - * - * \param log_severity_level The message's logging severity level. - * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. - * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. 
- * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. - * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. - * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. - * \param args Zero or more variadic arguments referenced by the format string. - * \return A Ort::Status value to indicate error or success. - */ - template - Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, - const char* func_name, const char* format, Args&&... args) const noexcept; - - private: - const OrtLogger* logger_{}; - OrtLoggingLevel cached_severity_level_{}; -}; - -/// -/// This class wraps a raw pointer OrtKernelContext* that is being passed -/// to the custom kernel Compute() method. Use it to safely access context -/// attributes, input and output parameters with exception safety guarantees. -/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc -/// -struct KernelContext { - explicit KernelContext(OrtKernelContext* context); - size_t GetInputCount() const; - size_t GetOutputCount() const; - // If input is optional and is not present, the method returns en empty ConstValue - // which can be compared to nullptr. - ConstValue GetInput(size_t index) const; - // If outout is optional and is not present, the method returns en empty UnownedValue - // which can be compared to nullptr. 
- UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; - UnownedValue GetOutput(size_t index, const std::vector& dims) const; - void* GetGPUComputeStream() const; - Logger GetLogger() const; - OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; - OrtKernelContext* GetOrtKernelContext() const { return ctx_; } - void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; - - private: - OrtKernelContext* ctx_; -}; - -struct KernelInfo; - -namespace detail { -namespace attr_utils { -void GetAttr(const OrtKernelInfo* p, const char* name, float&); -void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); -void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); -void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); -void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); -} // namespace attr_utils - -template -struct KernelInfoImpl : Base { - using B = Base; - using B::B; - - KernelInfo Copy() const; - - template // R is only implemented for float, int64_t, and string - R GetAttribute(const char* name) const { - R val; - attr_utils::GetAttr(this->p_, name, val); - return val; - } - - template // R is only implemented for std::vector, std::vector - std::vector GetAttributes(const char* name) const { - std::vector result; - attr_utils::GetAttrs(this->p_, name, result); - return result; - } - - Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const; - - size_t GetInputCount() const; - size_t GetOutputCount() const; - - std::string GetInputName(size_t index) const; - std::string GetOutputName(size_t index) const; - - TypeInfo GetInputTypeInfo(size_t index) const; - TypeInfo GetOutputTypeInfo(size_t index) const; - - ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; - - std::string GetNodeName() const; - Logger GetLogger() const; -}; - -} // namespace detail - -using ConstKernelInfo = 
detail::KernelInfoImpl>; - -/// -/// This struct owns the OrtKernInfo* pointer when a copy is made. -/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor -/// and query attributes, warp the pointer with Ort::Unowned instance -/// so it does not destroy the pointer the kernel does not own. -/// -struct KernelInfo : detail::KernelInfoImpl { - explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later - explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance - ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } -}; - -/// -/// Create and own custom defined operation. -/// -struct Op : detail::Base { - explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used - - explicit Op(OrtOp*); ///< Take ownership of the OrtOp - - static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, - int version, const char** type_constraint_names, - const ONNXTensorElementDataType* type_constraint_values, - size_t type_constraint_count, - const OpAttr* attr_values, - size_t attr_count, - size_t input_count, size_t output_count); - - void Invoke(const OrtKernelContext* context, - const Value* input_values, - size_t input_count, - Value* output_values, - size_t output_count); - - // For easier refactoring - void Invoke(const OrtKernelContext* context, - const OrtValue* const* input_values, - size_t input_count, - OrtValue* const* output_values, - size_t output_count); -}; - -/// -/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. 
-/// -struct ShapeInferContext { - struct SymbolicInteger { - SymbolicInteger(int64_t i) : i_(i), is_int_(true) {}; - SymbolicInteger(const char* s) : s_(s), is_int_(false) {}; - SymbolicInteger(const SymbolicInteger&) = default; - SymbolicInteger(SymbolicInteger&&) = default; - - SymbolicInteger& operator=(const SymbolicInteger&) = default; - SymbolicInteger& operator=(SymbolicInteger&&) = default; - - bool operator==(const SymbolicInteger& dim) const { - if (is_int_ == dim.is_int_) { - if (is_int_) { - return i_ == dim.i_; - } else { - return std::string{s_} == std::string{dim.s_}; - } - } - return false; - } - - bool IsInt() const { return is_int_; } - int64_t AsInt() const { return i_; } - const char* AsSym() const { return s_; } - - static constexpr int INVALID_INT_DIM = -2; - - private: - union { - int64_t i_; - const char* s_; - }; - bool is_int_; - }; - - using Shape = std::vector; - - ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); - - const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } - - size_t GetInputCount() const { return input_shapes_.size(); } - - Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - - int64_t GetAttrInt(const char* attr_name); - - using Ints = std::vector; - Ints GetAttrInts(const char* attr_name); - - float GetAttrFloat(const char* attr_name); - - using Floats = std::vector; - Floats GetAttrFloats(const char* attr_name); - - std::string GetAttrString(const char* attr_name); - - using Strings = std::vector; - Strings GetAttrStrings(const char* attr_name); - - private: - const OrtOpAttr* GetAttrHdl(const char* attr_name) const; - const OrtApi* ort_api_; - OrtShapeInferContext* ctx_; - std::vector input_shapes_; -}; - -using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); - -#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 - -template -struct CustomOpBase : OrtCustomOp { - CustomOpBase() { - 
OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; - - OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; - - OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; - OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; - OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputMemoryType(index); }; - - OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; - OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; - -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -#pragma warning(disable : 26409) -#endif - OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; - - OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicInputMinArity(); }; - OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; - OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; - OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return 
static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; -#ifdef __cpp_if_constexpr - if constexpr (WithStatus) { -#else - if (WithStatus) { -#endif - OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { - return static_cast(this_)->CreateKernelV2(*api, info, op_kernel); - }; - OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { - return static_cast(op_kernel)->ComputeV2(context); - }; - } else { - OrtCustomOp::CreateKernelV2 = nullptr; - OrtCustomOp::KernelComputeV2 = nullptr; - - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - static_cast(op_kernel)->Compute(context); - }; - } - - SetShapeInferFn(0); - - OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { - return static_cast(this_)->start_ver_; - }; - - OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { - return static_cast(this_)->end_ver_; - }; - - OrtCustomOp::GetMayInplace = nullptr; - OrtCustomOp::ReleaseMayInplace = nullptr; - OrtCustomOp::GetAliasMap = nullptr; - OrtCustomOp::ReleaseAliasMap = nullptr; - } - - // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider - const char* GetExecutionProviderType() const { return nullptr; } - - // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below - // (inputs and outputs are required by default) - OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; - } - - OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; - } - - // 
Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault - OrtMemType GetInputMemoryType(size_t /*index*/) const { - return OrtMemTypeDefault; - } - - // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input - // should expect at least 1 argument. - int GetVariadicInputMinArity() const { - return 1; - } - - // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments - // to a variadic input should be of the same type. - bool GetVariadicInputHomogeneity() const { - return true; - } - - // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output - // should produce at least 1 output value. - int GetVariadicOutputMinArity() const { - return 1; - } - - // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values - // produced by a variadic output should be of the same type. - bool GetVariadicOutputHomogeneity() const { - return true; - } - - // Declare list of session config entries used by this Custom Op. - // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). - // This default implementation returns an empty vector of config entries. - std::vector GetSessionConfigKeys() const { - return std::vector{}; - } - - template - decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { - OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { - ShapeInferContext ctx(&GetApi(), ort_ctx); - return C::InferOutputShape(ctx); - }; - return {}; - } - - template - void SetShapeInferFn(...) { - OrtCustomOp::InferOutputShapeFn = {}; - } - - protected: - // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. 
- void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; - - int start_ver_ = 1; - int end_ver_ = MAX_CUSTOM_OP_END_VER; -}; - -} // namespace Ort - -#include "onnxruntime_cxx_inline.h" diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_inline.h b/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_inline.h deleted file mode 100644 index 9b9dd81a74..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_cxx_inline.h +++ /dev/null @@ -1,2128 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead. -// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead. -// -// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter -// the main C++ file with implementation details. - -#include -#include -#include -#include - -// Convert OrtStatus to Ort::Status and return -// instead of throwing -#define ORT_CXX_RETURN_ON_API_FAIL(expression) \ - { \ - auto ort_status = (expression); \ - if (ort_status) { \ - return Ort::Status(ort_status); \ - } \ - } - -#ifdef __cpp_if_constexpr -#define ORT_CXX_IF_CONSTEXPR if constexpr -#else -#define ORT_CXX_IF_CONSTEXPR if -#endif - -namespace Ort { - -namespace detail { -inline void ThrowStatus(const Status& st) { - std::string error_message = st.GetErrorMessage(); - OrtErrorCode error_code = st.GetErrorCode(); - ORT_CXX_API_THROW(std::move(error_message), error_code); -} -} // namespace detail - -inline void ThrowOnError(OrtStatus* ort_status) { - if (ort_status) { - Ort::Status st(ort_status); - detail::ThrowStatus(st); - } -} - -inline void ThrowOnError(const Status& st) { - if (st) { - detail::ThrowStatus(st); - } -} - -inline Status::Status(OrtStatus* status) noexcept : Base{status} { -} - -inline Status::Status(const std::exception& 
e) noexcept { - p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); -} - -inline Status::Status(const Exception& e) noexcept { - p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); -} - -inline Status::Status(const char* message, OrtErrorCode code) noexcept { - p_ = GetApi().CreateStatus(code, message); -} - -inline std::string Status::GetErrorMessage() const { - std::string message(GetApi().GetErrorMessage(p_)); - return message; -} - -inline OrtErrorCode Status::GetErrorCode() const { - return GetApi().GetErrorCode(p_); -} - -inline bool Status::IsOK() const noexcept { - return (p_ == nullptr); -} - -// This template converts a C++ type into it's ONNXTensorElementDataType -template -struct TypeToTensorType; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; -}; -template <> -struct TypeToTensorType { - static constexpr 
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; -}; - -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2; -}; -template <> -struct TypeToTensorType { - static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; -}; - -inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept { - if (IsNaN() || rhs.IsNaN()) { - // IEEE defines that NaN is not equal to anything, including itself. - return false; - } - return val == rhs.val; -} - -inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept { - if (IsNaN() || rhs.IsNaN()) { - // IEEE defines that NaN is unordered with respect to everything, including itself. - return false; - } - - const bool left_is_negative = IsNegative(); - if (left_is_negative != rhs.IsNegative()) { - // When the signs of left and right differ, we know that left is less than right if it is - // the negative value. The exception to this is if both values are zero, in which case IEEE - // says they should be equal, even if the signs differ. 
- return left_is_negative && !AreZero(*this, rhs); - } - return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); -} - -inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) - : allocator_(allocator), p_(p), size_(size) { -} - -inline MemoryAllocation::~MemoryAllocation() { - if (p_ != nullptr) { - // We do not throw out of destructor - auto ret = GetApi().AllocatorFree(allocator_, p_); - static_cast(ret); - } -} - -inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) { - *this = std::move(o); -} - -inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept { - OrtAllocator* alloc = nullptr; - void* p = nullptr; - size_t sz = 0; - - // Swap out this - std::swap(alloc, allocator_); - std::swap(p, p_); - std::swap(sz, size_); - - // Swap with incoming - std::swap(allocator_, o.allocator_); - std::swap(p_, o.p_); - std::swap(size_, o.size_); - - // Destroy this instance if needed - MemoryAllocation this_alloc(alloc, p, sz); - return *this; -} - -namespace detail { - -template -inline void* AllocatorImpl::Alloc(size_t size) { - void* out; - ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); - return out; -} - -template -inline MemoryAllocation AllocatorImpl::GetAllocation(size_t size) { - void* out; - ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out)); - MemoryAllocation result(this->p_, out, size); - return result; -} - -template -inline void AllocatorImpl::Free(void* p) { - ThrowOnError(GetApi().AllocatorFree(this->p_, p)); -} - -template -inline ConstMemoryInfo AllocatorImpl::GetInfo() const { - const OrtMemoryInfo* out; - ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out)); - return ConstMemoryInfo{out}; -} - -} // namespace detail - -inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { - ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_)); -} - -inline Allocator::Allocator(const 
Session& sess, const OrtMemoryInfo* mem_info) { - ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_)); -} - -namespace detail { - -template -inline std::string MemoryInfoImpl::GetAllocatorName() const { - const char* name = nullptr; - ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name)); - return std::string(name); -} - -template -inline OrtAllocatorType MemoryInfoImpl::GetAllocatorType() const { - OrtAllocatorType type; - ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type)); - return type; -} - -template -inline int MemoryInfoImpl::GetDeviceId() const { - int id = 0; - ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id)); - return id; -} - -template -inline OrtMemoryInfoDeviceType MemoryInfoImpl::GetDeviceType() const { - OrtMemoryInfoDeviceType type; - GetApi().MemoryInfoGetDeviceType(this->p_, &type); - return type; -} - -template -inline OrtMemType MemoryInfoImpl::GetMemoryType() const { - OrtMemType type; - ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type)); - return type; -} - -template -template -inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { - int comp_result = 0; - ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result)); - return comp_result == 0; -} - -} // namespace detail - -inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { - OrtMemoryInfo* p; - ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); - return MemoryInfo(p); -} - -inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); -} - -namespace detail { -template -inline std::vector ConstIoBindingImpl::GetOutputNames() const { - AllocatorWithDefaultOptions allocator; - return binding_utils::GetOutputNamesHelper(this->p_, allocator); -} - -template -inline std::vector ConstIoBindingImpl::GetOutputNames(OrtAllocator* allocator) const { - return 
binding_utils::GetOutputNamesHelper(this->p_, allocator); -} - -template -inline std::vector ConstIoBindingImpl::GetOutputValues() const { - AllocatorWithDefaultOptions allocator; - return binding_utils::GetOutputValuesHelper(this->p_, allocator); -} - -template -inline std::vector ConstIoBindingImpl::GetOutputValues(OrtAllocator* allocator) const { - return binding_utils::GetOutputValuesHelper(this->p_, allocator); -} - -template -inline void IoBindingImpl::BindInput(const char* name, const Value& value) { - ThrowOnError(GetApi().BindInput(this->p_, name, value)); -} - -template -inline void IoBindingImpl::BindOutput(const char* name, const Value& value) { - ThrowOnError(GetApi().BindOutput(this->p_, name, value)); -} - -template -inline void IoBindingImpl::BindOutput(const char* name, const OrtMemoryInfo* mem_info) { - ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info)); -} - -template -inline void IoBindingImpl::ClearBoundInputs() { - GetApi().ClearBoundInputs(this->p_); -} - -template -inline void IoBindingImpl::ClearBoundOutputs() { - GetApi().ClearBoundOutputs(this->p_); -} - -template -inline void IoBindingImpl::SynchronizeInputs() { - ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_)); -} - -template -inline void IoBindingImpl::SynchronizeOutputs() { - ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_)); -} - -namespace binding_utils { -inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { - std::vector result; - auto free_fn = detail::AllocatedFree(allocator); - using Ptr = std::unique_ptr; - - char* buffer = nullptr; - size_t* lengths = nullptr; - size_t count = 0; - ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count)); - - if (count == 0) { - return result; - } - - Ptr buffer_g(buffer, free_fn); - Ptr lengths_g(lengths, free_fn); - - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - auto sz = *lengths; - result.emplace_back(buffer, sz); 
- buffer += sz; - ++lengths; - } - return result; -} - -inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { - std::vector result; - size_t owned = 0; - size_t output_count = 0; - // Lambda to release the buffer when no longer needed and - // make sure that we destroy all instances on exception - auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { - if (buffer) { - while (owned < output_count) { - auto* p = buffer + owned++; - GetApi().ReleaseValue(*p); - } - allocator->Free(allocator, buffer); - } - }; - using Ptr = std::unique_ptr; - - OrtValue** output_buffer = nullptr; - ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); - if (output_count == 0) { - return result; - } - - Ptr buffer_g(output_buffer, free_fn); - - result.reserve(output_count); - for (size_t i = 0; i < output_count; ++i) { - result.emplace_back(output_buffer[i]); - ++owned; - } - return result; -} - -} // namespace binding_utils -} // namespace detail - -inline IoBinding::IoBinding(Session& session) { - ThrowOnError(GetApi().CreateIoBinding(session, &this->p_)); -} - -inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) { - ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); -} - -inline ThreadingOptions::ThreadingOptions() { - ThrowOnError(GetApi().CreateThreadingOptions(&p_)); -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) { - ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) { - ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int 
allow_spinning) { - ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() { - ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { - ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { - ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options)); - return *this; -} - -inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { - ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn)); - return *this; -} - -inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { - ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); - if (strcmp(logid, "onnxruntime-node") == 0) { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); - } else { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); - } -} - -inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { - ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_)); - if (strcmp(logid, "onnxruntime-node") == 0) { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); - } else { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); - } -} - -inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) { - 
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_)); - if (strcmp(logid, "onnxruntime-node") == 0) { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); - } else { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); - } -} - -inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, - OrtLoggingLevel logging_level, _In_ const char* logid) { - ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_)); - if (strcmp(logid, "onnxruntime-node") == 0) { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); - } else { - ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); - } -} - -inline Env& Env::EnableTelemetryEvents() { - ThrowOnError(GetApi().EnableTelemetryEvents(p_)); - return *this; -} - -inline Env& Env::DisableTelemetryEvents() { - ThrowOnError(GetApi().DisableTelemetryEvents(p_)); - return *this; -} - -inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) { - ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level)); - return *this; -} - -inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { - ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); - return *this; -} - -inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { - std::vector keys, values; - auto num_entries = options.size(); - if (num_entries > 0) { - keys.reserve(num_entries); - values.reserve(num_entries); - for (const auto& entry : options) { - keys.push_back(entry.first.c_str()); - 
values.push_back(entry.second.c_str()); - } - } - ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries)); - return *this; -} - -inline CustomOpDomain::CustomOpDomain(const char* domain) { - ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); -} - -inline void CustomOpDomain::Add(const OrtCustomOp* op) { - ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); -} - -inline RunOptions::RunOptions() { - ThrowOnError(GetApi().CreateRunOptions(&p_)); -} - -inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) { - ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level)); - return *this; -} - -inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { - ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level)); - return *this; -} - -inline int RunOptions::GetRunLogVerbosityLevel() const { - int out; - ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out)); - return out; -} - -inline int RunOptions::GetRunLogSeverityLevel() const { - int out; - ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out)); - return out; -} - -inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { - ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag)); - return *this; -} - -inline const char* RunOptions::GetRunTag() const { - const char* out; - ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out)); - return out; -} - -inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) { - ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value)); - return *this; -} - -inline RunOptions& RunOptions::SetTerminate() { - ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); - return *this; -} - -inline RunOptions& RunOptions::UnsetTerminate() { - ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_)); - return *this; -} - -namespace detail { - -template -inline Ort::SessionOptions 
ConstSessionOptionsImpl::Clone() const { - OrtSessionOptions* out; - ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out)); - return SessionOptions{out}; -} - -template -inline std::string ConstSessionOptionsImpl::GetConfigEntry(const char* config_key) const { - size_t size = 0; - // Feed nullptr for the data buffer to query the true size of the string value - Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size)); - - std::string out; - out.resize(size); - Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - - return out; -} - -template -inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) const { - int out = 0; - Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out)); - return static_cast(out); -} - -template -inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { - if (!this->HasConfigEntry(config_key)) { - return def; - } - - return this->GetConfigEntry(config_key); -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { - ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetInterOpNumThreads(int inter_op_num_threads) { - ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { - ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetDeterministicCompute(bool value) { - ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value)); - return *this; -} - -template 
-inline SessionOptionsImpl& SessionOptionsImpl::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { - ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { - ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::DisableProfiling() { - ThrowOnError(GetApi().DisableProfiling(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::EnableOrtCustomOps() { - ThrowOnError(GetApi().EnableOrtCustomOps(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::EnableMemPattern() { - ThrowOnError(GetApi().EnableMemPattern(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::DisableMemPattern() { - ThrowOnError(GetApi().DisableMemPattern(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::EnableCpuMemArena() { - ThrowOnError(GetApi().EnableCpuMemArena(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::DisableCpuMemArena() { - ThrowOnError(GetApi().DisableCpuMemArena(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionMode execution_mode) { - ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { - ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetLogSeverityLevel(int level) { - ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level)); - return *this; -} - -template -inline SessionOptionsImpl& 
SessionOptionsImpl::Add(OrtCustomOpDomain* custom_op_domain) { - ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AddConfigEntry(const char* config_key, const char* config_value) { - ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AddInitializer(const char* name, const OrtValue* ort_val) { - ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::DisablePerSessionThreads() { - ThrowOnError(GetApi().DisablePerSessionThreads(this->p_)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializers(const std::vector& names, - const std::vector& ort_values) { - const size_t inputs_num = names.size(); - if (inputs_num != ort_values.size()) { - ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT); - } - std::vector names_ptr; - std::vector ort_values_ptrs; - names_ptr.reserve(inputs_num); - ort_values_ptrs.reserve(inputs_num); - for (size_t i = 0; i < inputs_num; ++i) { - names_ptr.push_back(names[i].c_str()); - ort_values_ptrs.push_back(ort_values[i]); - } - ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFromFilesInMemory(const std::vector>& file_names, - const std::vector& buffer_array, - const std::vector& file_lengths) { - const size_t inputs_num = file_names.size(); - if (inputs_num != buffer_array.size()) { - ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT); - } - if (inputs_num != file_lengths.size()) { - ORT_CXX_API_THROW("Expecting names and file_lengths to have the 
same length", ORT_INVALID_ARGUMENT); - } - std::vector names_ptr; - names_ptr.reserve(inputs_num); - for (size_t i = 0; i < inputs_num; ++i) { - names_ptr.push_back(file_names[i].c_str()); - } - ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(), - file_lengths.data(), inputs_num)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& 
SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( - const std::string& provider_name, - const std::unordered_map& provider_options) { - auto num_entries = provider_options.size(); - std::vector keys, values; - if (num_entries > 0) { - keys.reserve(num_entries); - values.reserve(num_entries); - - for (const auto& entry : provider_options) { - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); - } - } - - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(), - keys.data(), values.data(), num_entries)); - - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { - ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) { - ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) { - ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO(const 
OrtOpenVINOProviderOptions& provider_options) { - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options) { - auto num_entries = provider_options.size(); - std::vector keys, values; - if (num_entries > 0) { - keys.reserve(num_entries); - values.reserve(num_entries); - - for (const auto& entry : provider_options) { - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); - } - } - - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_, - keys.data(), values.data(), num_entries)); - - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { - auto num_entries = provider_options.size(); - std::vector keys, values; - if (num_entries > 0) { - keys.reserve(num_entries); - values.reserve(num_entries); - - for (const auto& entry : provider_options) { - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); - } - } - - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); - - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, - const CustomOpConfigs& custom_op_configs) { - // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by - // the custom op library. 
- for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) { - AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str()); - } - - ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name)); - return *this; -} - -template -inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsUsingFunction(const char* registration_function_name) { - ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name)); - return *this; -} - -/// Session -template -inline size_t ConstSessionImpl::GetInputCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out)); - return out; -} - -template -inline size_t ConstSessionImpl::GetOutputCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out)); - return out; -} - -template -inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { - size_t out; - ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out)); - return out; -} - -template -inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -template -inline AllocatedStringPtr ConstSessionImpl::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -template -inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -template -inline uint64_t 
ConstSessionImpl::GetProfilingStartTimeNs() const { - uint64_t out; - ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out)); - return out; -} - -template -inline ModelMetadata ConstSessionImpl::GetModelMetadata() const { - OrtModelMetadata* out; - ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out)); - return ModelMetadata{out}; -} - -template -inline TypeInfo ConstSessionImpl::GetInputTypeInfo(size_t index) const { - OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out)); - return TypeInfo{out}; -} - -template -inline TypeInfo ConstSessionImpl::GetOutputTypeInfo(size_t index) const { - OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out)); - return TypeInfo{out}; -} - -template -inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t index) const { - OrtTypeInfo* out; - ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out)); - return TypeInfo{out}; -} - -template -inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, size_t output_count) { - std::vector output_values; - output_values.reserve(output_count); - for (size_t i = 0; i < output_count; i++) - output_values.emplace_back(nullptr); - Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count); - return output_values; -} - -template -inline void SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count) { - static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(input_values); - auto ort_output_values = 
reinterpret_cast(output_values); - ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); -} - -template -inline void SessionImpl::Run(const RunOptions& run_options, const IoBinding& io_binding) { - ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding)); -} - -template -inline void SessionImpl::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, - const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) { - auto ort_input_values = reinterpret_cast(input_values); - auto ort_output_values = reinterpret_cast(output_values); - ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, - ort_input_values, input_count, output_names, output_count, - ort_output_values, callback, user_data)); -} - -template -inline AllocatedStringPtr SessionImpl::EndProfilingAllocated(OrtAllocator* allocator) { - char* out = nullptr; - ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -} // namespace detail - -inline SessionOptions::SessionOptions() { - ThrowOnError(GetApi().CreateSessionOptions(&this->p_)); -} - -/// CustomOpConfigs -inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) { - std::string config_key = "custom_op."; - - config_key += custom_op_name; - config_key += "."; - config_key += config; - - return config_key; -} - -inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) { - const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key); - flat_configs_[full_flat_key] = config_value; - return *this; -} - -inline const std::unordered_map& CustomOpConfigs::GetFlattenedConfigs() const { - 
return flat_configs_; -} - -inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { - ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_)); -} - -inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container) { - ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_)); -} - -inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { - ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_)); -} - -inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, - const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) { - ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options, - prepacked_weights_container, &this->p_)); -} - -inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const { - char* out; - 
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const { - char* out; - ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); - return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); -} - -inline std::vector ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const { - auto deletor = detail::AllocatedFree(allocator); - std::vector result; - - char** out = nullptr; - int64_t num_keys = 0; - ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); - if (num_keys <= 0) { - return result; - } - - // array of pointers will be freed - std::unique_ptr array_guard(out, deletor); - // reserve may throw - auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); }; - std::unique_ptr strings_guard(out, strings_deletor); - result.reserve(static_cast(num_keys)); - strings_guard.release(); - for (int64_t i = 0; i < num_keys; ++i) { - result.push_back(AllocatedStringPtr(out[i], deletor)); - } - - return result; -} - -inline int64_t ModelMetadata::GetVersion() const { - int64_t out; - ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); - return out; -} - -namespace detail { - -template -inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl::GetElementType() const { - ONNXTensorElementDataType out; - ThrowOnError(GetApi().GetTensorElementType(this->p_, &out)); - return out; -} - -template -inline size_t 
TensorTypeAndShapeInfoImpl::GetElementCount() const { - size_t out; - ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); - return static_cast(out); -} - -template -inline size_t TensorTypeAndShapeInfoImpl::GetDimensionsCount() const { - size_t out; - ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out)); - return out; -} - -template -inline void TensorTypeAndShapeInfoImpl::GetDimensions(int64_t* values, size_t values_count) const { - ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count)); -} - -template -inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** values, size_t values_count) const { - ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); -} - -template -inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { - std::vector out(GetDimensionsCount(), 0); - ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); - return out; -} - -template -inline ConstTensorTypeAndShapeInfo TypeInfoImpl::GetTensorTypeAndShapeInfo() const { - const OrtTensorTypeAndShapeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out)); - return ConstTensorTypeAndShapeInfo{out}; -} - -template -inline ConstSequenceTypeInfo TypeInfoImpl::GetSequenceTypeInfo() const { - const OrtSequenceTypeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out)); - return ConstSequenceTypeInfo{out}; -} - -template -inline ConstMapTypeInfo TypeInfoImpl::GetMapTypeInfo() const { - const OrtMapTypeInfo* out; - ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out)); - return ConstMapTypeInfo{out}; -} - -template -inline ONNXType TypeInfoImpl::GetONNXType() const { - ONNXType out; - ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out)); - return out; -} - -template -inline TypeInfo SequenceTypeInfoImpl::GetSequenceElementType() const { - OrtTypeInfo* output; - ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output)); - return 
TypeInfo{output}; -} - -template -inline TypeInfo OptionalTypeInfoImpl::GetOptionalElementType() const { - OrtTypeInfo* info; - ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info)); - return TypeInfo{info}; -} - -template -inline ONNXTensorElementDataType MapTypeInfoImpl::GetMapKeyType() const { - ONNXTensorElementDataType out; - ThrowOnError(GetApi().GetMapKeyType(this->p_, &out)); - return out; -} - -template -inline TypeInfo MapTypeInfoImpl::GetMapValueType() const { - OrtTypeInfo* output; - ThrowOnError(GetApi().GetMapValueType(this->p_, &output)); - return TypeInfo{output}; -} - -template -inline ConstOptionalTypeInfo TypeInfoImpl::GetOptionalTypeInfo() const { - const OrtOptionalTypeInfo* info; - ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info)); - return ConstOptionalTypeInfo{info}; -} - -} // namespace detail - -namespace detail { - -template -template -inline void ConstValueImpl::GetOpaqueData(const char* domain, const char* type_name, R& out) const { - ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R))); -} - -template -inline bool ConstValueImpl::IsTensor() const { - int out; - ThrowOnError(GetApi().IsTensor(this->p_, &out)); - return out != 0; -} - -template -inline bool ConstValueImpl::HasValue() const { - int out; - ThrowOnError(GetApi().HasValue(this->p_, &out)); - return out != 0; -} - -template -inline size_t ConstValueImpl::GetCount() const { - size_t out; - ThrowOnError(GetApi().GetValueCount(this->p_, &out)); - return out; -} - -template -inline Value ConstValueImpl::GetValue(int index, OrtAllocator* allocator) const { - OrtValue* out; - ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out)); - return Value{out}; -} - -template -inline size_t ConstValueImpl::GetStringTensorDataLength() const { - size_t out; - ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out)); - return out; -} - -template -inline size_t 
ConstValueImpl::GetStringTensorElementLength(size_t element_index) const { - size_t out; - ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out)); - return out; -} - -template -template -inline const R* ConstValueImpl::GetTensorData() const { - R* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); - return out; -} - -template -inline const void* ConstValueImpl::GetTensorRawData() const { - void* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); - return out; -} - -template -inline TypeInfo ConstValueImpl::GetTypeInfo() const { - OrtTypeInfo* output; - ThrowOnError(GetApi().GetTypeInfo(this->p_, &output)); - return TypeInfo{output}; -} - -template -inline TensorTypeAndShapeInfo ConstValueImpl::GetTensorTypeAndShapeInfo() const { - OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output)); - return TensorTypeAndShapeInfo{output}; -} - -template -inline ConstMemoryInfo ConstValueImpl::GetTensorMemoryInfo() const { - const OrtMemoryInfo* mem_info; - ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info)); - return ConstMemoryInfo(mem_info); -} - -template -inline void ConstValueImpl::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { - ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer)); -} - -template -inline std::string ConstValueImpl::GetStringTensorElement(size_t element_index) const { - size_t buffer_length; - ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length)); - - std::string s; - s.resize(buffer_length); - ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0])); - return s; -} - -template -inline void ConstValueImpl::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { - 
ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count)); -} - -#if !defined(DISABLE_SPARSE_TENSORS) -template -inline OrtSparseFormat ConstValueImpl::GetSparseFormat() const { - OrtSparseFormat format; - ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format)); - return format; -} - -template -inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorValuesTypeAndShapeInfo() const { - OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output)); - return TensorTypeAndShapeInfo{output}; -} - -template -inline TensorTypeAndShapeInfo ConstValueImpl::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const { - OrtTensorTypeAndShapeInfo* output; - ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output)); - return TensorTypeAndShapeInfo{output}; -} - -template -template -inline const R* ConstValueImpl::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const { - const void* out; - ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out)); - return reinterpret_cast(out); -} - -template -inline bool ConstValueImpl::IsSparseTensor() const { - int out; - ThrowOnError(GetApi().IsSparseTensor(this->p_, &out)); - return out != 0; -} - -template -template -inline const R* ConstValueImpl::GetSparseTensorValues() const { - const void* out; - ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out)); - return reinterpret_cast(out); -} - -#endif - -template -void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { - ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); -} - -template -void ValueImpl::FillStringTensorElement(const char* s, size_t index) { - ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index)); -} - -template -inline char* ValueImpl::GetResizedStringTensorElementBuffer(size_t index, size_t 
buffer_length) { - char* result; - ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result)); - return result; -} - -template -void* ValueImpl::GetTensorMutableRawData() { - void* out; - ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out)); - return out; -} - -template -template -R* ValueImpl::GetTensorMutableData() { - R* out; - ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out)); - return out; -} - -template -template -R& ValueImpl::At(const std::vector& location) { - static_assert(!std::is_same::value, "this api does not support std::string"); - R* out; - ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out)); - return *out; -} - -#if !defined(DISABLE_SPARSE_TENSORS) -template -void ValueImpl::UseCooIndices(int64_t* indices_data, size_t indices_num) { - ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num)); -} - -template -void ValueImpl::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) { - ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num)); -} - -template -void ValueImpl::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) { - ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data)); -} - -template -void ValueImpl::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param, - const int64_t* indices_data, size_t indices_num) { - ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape, - values_param.values_shape_len, values_param.data.p_data, - indices_data, indices_num)); -} - -template -void ValueImpl::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const int64_t* inner_indices_data, size_t inner_indices_num, - const int64_t* outer_indices_data, size_t outer_indices_num) 
{ - ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, - inner_indices_data, inner_indices_num, - outer_indices_data, outer_indices_num)); -} - -template -void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, - const OrtSparseValuesParam& values, - const Shape& indices_shape, - const int32_t* indices_data) { - ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data, - indices_shape.shape, indices_shape.shape_len, - indices_data)); -} - -#endif // !defined(DISABLE_SPARSE_TENSORS) - -} // namespace detail - -template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { - return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); -} - -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); - return Value{out}; -} - -template -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { - return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); -} - -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); - return Value{out}; -} - -#if !defined(DISABLE_SPARSE_TENSORS) - -template -inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, - const Shape& values_shape) { - return CreateSparseTensor(info, p_data, 
dense_shape, values_shape, TypeToTensorType::type); -} - -inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, - const Shape& values_shape, ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); - return Value{out}; -} - -template -inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) { - return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type); -} - -inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, - ONNXTensorElementDataType type) { - OrtValue* out; - ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out)); - return Value{out}; -} -#endif // !defined(DISABLE_SPARSE_TENSORS) - -inline Value Value::CreateMap(const Value& keys, const Value& values) { - OrtValue* out; - const OrtValue* inputs[2] = {keys, values}; - ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); - return Value{out}; -} - -inline Value Value::CreateSequence(const std::vector& values) { - OrtValue* out; - std::vector values_ort{values.data(), values.data() + values.size()}; - ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); - return Value{out}; -} - -template -inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) { - OrtValue* out; - ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out)); - return Value{out}; -} - -// -// Custom OP Inlines -// -inline Logger::Logger(const OrtLogger* logger) : logger_(logger) { - Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_)); -} - -inline OrtLoggingLevel 
Logger::GetLoggingSeverityLevel() const noexcept { - return cached_severity_level_; -} - -inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, - const char* func_name, const char* message) const noexcept { - OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number, - func_name); - return Status{status}; -} - -// Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security) -// for gcc and clang. The alternative is to use actual C-style variadic parameters and apply -// __attribute__(format(printf...)), which does not work with variadic templates. -#if defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wformat-security" -#elif defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#pragma clang diagnostic ignored "-Wformat-security" -#endif -template -inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, - int line_number, const char* func_name, const char* format, - Args&&... args) const noexcept { - int msg_len = std::snprintf(nullptr, 0U, format, std::forward(args)...); - - if (msg_len < 0) { // Formatting error - return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL); - } - - OrtStatus* status = nullptr; - const size_t buffer_size = static_cast(msg_len) + 1U; - - constexpr size_t kStackBufferSize = 1024; - - if (buffer_size < kStackBufferSize) { - char buffer[kStackBufferSize]; - snprintf(buffer, kStackBufferSize, format, std::forward(args)...); - status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name); - } else { - // std::make_unique is only supported starting at C++14. 
-#if (__cplusplus >= 201402L) || (_MSC_VER >= 1900) - auto buffer = std::make_unique(buffer_size); -#else - std::unique_ptr buffer(new char[buffer_size]); -#endif - std::snprintf(buffer.get(), buffer_size, format, std::forward(args)...); - status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name); - } - - return Status{status}; -} -// Re-enable -Wformat-nonliteral and -Wformat-security -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(__clang__) -#pragma clang diagnostic pop -#endif - -inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) { -} - -inline size_t KernelContext::GetInputCount() const { - size_t out = 0; - Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out)); - return out; -} - -inline size_t KernelContext::GetOutputCount() const { - size_t out = 0; - Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out)); - return out; -} - -inline ConstValue KernelContext::GetInput(size_t index) const { - const OrtValue* out = nullptr; - Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out)); - return ConstValue{out}; -} - -inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const { - OrtValue* out = nullptr; - Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out)); - return UnownedValue(out); -} - -inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector& dims) const { - OrtValue* out = nullptr; - Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out)); - return UnownedValue(out); -} - -inline void* KernelContext::GetGPUComputeStream() const { - void* out = nullptr; - Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out)); - return out; -} - -inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { - OrtAllocator* out = 
nullptr; - Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); - return out; -} - -inline Logger KernelContext::GetLogger() const { - const OrtLogger* out = nullptr; - ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out)); - return Logger{out}; -} - -inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const { - ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); -} - -inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { - Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); -} - -namespace detail { -template -inline KernelInfo KernelInfoImpl::Copy() const { - OrtKernelInfo* info_copy = nullptr; - Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy)); - return KernelInfo{info_copy}; -} - -template -inline size_t KernelInfoImpl::GetInputCount() const { - size_t out = 0; - ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out)); - return out; -} - -template -inline size_t KernelInfoImpl::GetOutputCount() const { - size_t out = 0; - ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out)); - return out; -} - -template -inline std::string KernelInfoImpl::GetInputName(size_t index) const { - size_t size = 0; - - // Feed nullptr for the data buffer to query the true size of the string value - Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size)); - - std::string out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - - return out; -} - -template -inline std::string KernelInfoImpl::GetOutputName(size_t index) const { - size_t size = 0; - - // Feed nullptr for the data buffer to query the true size of the string value - Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size)); - 
- std::string out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - - return out; -} - -template -inline TypeInfo KernelInfoImpl::GetInputTypeInfo(size_t index) const { - OrtTypeInfo* out = nullptr; - ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out)); - return TypeInfo{out}; -} - -template -inline TypeInfo KernelInfoImpl::GetOutputTypeInfo(size_t index) const { - OrtTypeInfo* out = nullptr; - ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out)); - return TypeInfo{out}; -} - -template -inline Value KernelInfoImpl::GetTensorAttribute(const char* name, OrtAllocator* allocator) const { - OrtValue* out = nullptr; - ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out)); - return Value{out}; -} - -template -inline ConstValue KernelInfoImpl::GetTensorConstantInput(size_t index, int* is_constant) const { - const OrtValue* out = nullptr; - ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out)); - return ConstValue{out}; -} - -template -inline std::string KernelInfoImpl::GetNodeName() const { - size_t size = 0; - - // Feed nullptr for the data buffer to query the true size of the string value - Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size)); - - std::string out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - - return out; -} - -template -inline Logger KernelInfoImpl::GetLogger() const { - const OrtLogger* out = nullptr; - ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out)); - return Logger{out}; -} - -inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { - Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); -} - -inline void 
attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) { - Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out)); -} - -inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) { - size_t size = 0; - // Feed nullptr for the data buffer to query the true size of the string attribute - Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size)); - - std::string out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size)); - out.resize(size - 1); // remove the terminating character '\0' - out.swap(result); -} - -inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { - size_t size = 0; - // Feed nullptr for the data buffer to query the true size of the attribute - Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size)); - - std::vector out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size)); - out.swap(result); -} - -inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { - size_t size = 0; - - // Feed nullptr for the data buffer to query the true size of the attribute - Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size)); - - std::vector out; - out.resize(size); - Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); - out.swap(result); -} -} // namespace detail - -inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} - -inline Op::Op(OrtOp* p) : Base(p) {} - -inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, - const char** type_constraint_names, - const ONNXTensorElementDataType* type_constraint_values, - size_t type_constraint_count, - const OpAttr* attr_values, size_t attr_count, - size_t 
input_count, size_t output_count) { - static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*), - "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely"); - auto attr_input_values = reinterpret_cast(attr_values); - OrtOp* op; - Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values, - static_cast(type_constraint_count), - attr_input_values, - static_cast(attr_count), - static_cast(input_count), - static_cast(output_count), &op)); - return Op{op}; -} - -inline void Op::Invoke(const OrtKernelContext* context, - const Value* input_values, - size_t input_count, - Value* output_values, - size_t output_count) { - static_assert(sizeof(Value) == sizeof(OrtValue*), - "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); - auto ort_input_values = reinterpret_cast(input_values); - auto ort_output_values = reinterpret_cast(output_values); - Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast(input_count), - ort_output_values, static_cast(output_count))); -} - -inline void Op::Invoke(const OrtKernelContext* context, - const OrtValue* const* input_values, - size_t input_count, - OrtValue* const* output_values, - size_t output_count) { - Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast(input_count), - output_values, static_cast(output_count))); -} - -inline std::string GetVersionString() { - return OrtGetApiBase()->GetVersionString(); -} - -inline std::string GetBuildInfoString() { - return GetApi().GetBuildInfoString(); -} - -inline std::vector GetAvailableProviders() { - char** providers; - int len; - - auto release_fn = [&len](char** providers) { - // This should always return nullptr. 
- ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len)); - }; - - ThrowOnError(GetApi().GetAvailableProviders(&providers, &len)); - std::unique_ptr guard(providers, release_fn); - std::vector available_providers; - available_providers.reserve(static_cast(len)); - for (int i = 0; i < len; ++i) { - available_providers.emplace_back(providers[i]); - } - return available_providers; -} - -template -void CustomOpBase::GetSessionConfigs(std::unordered_map& out, - ConstSessionOptions options) const { - const TOp* derived = static_cast(this); - std::vector keys = derived->GetSessionConfigKeys(); - - out.reserve(keys.size()); - - std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), ""); - const size_t prefix_size = config_entry_key.length(); - - for (const auto& key : keys) { - config_entry_key.resize(prefix_size); - config_entry_key.append(key); - out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), ""); - } -} - -inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, - OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) { - size_t input_count = 0; - Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count)); - for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { - OrtTensorTypeAndShapeInfo* info{}; - Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info)); - TensorTypeAndShapeInfo type_shape_info(info); - auto integer_shape = type_shape_info.GetShape(); - std::vector symbolic_shape(integer_shape.size(), {}); - if (!integer_shape.empty()) { - type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size()); - } - Shape shape; - for (size_t ith = 0; ith < integer_shape.size(); ++ith) { - if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) { - shape.emplace_back(symbolic_shape[ith]); - } else { - shape.emplace_back(integer_shape[ith]); - } - } - input_shapes_.push_back(std::move(shape)); - 
type_shape_info.release(); - } -} - -inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { - OrtTensorTypeAndShapeInfo* info = {}; - ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); - ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); - - using InfoPtr = std::unique_ptr>; - - InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) { - ort_api_->ReleaseTensorTypeAndShapeInfo(obj); - }); - - std::vector integer_dims; - std::vector symbolic_dims; - - for (const auto dim : shape) { - if (dim.IsInt()) { - integer_dims.push_back(dim.AsInt()); - symbolic_dims.push_back(""); - } else { - if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) { - ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT); - } - integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM); - symbolic_dims.push_back(dim.AsSym()); - } - } - - ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size())); - ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size())); - ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info)); - return Status{nullptr}; -} - -inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); - return i; -} - -inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); - if (status) { - size_t num_i = out / sizeof(int64_t); - ShapeInferContext::Ints ints(num_i, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, 
ORT_OP_ATTR_INTS, ints.data(), out, &out)); - return ints; - } else { - return {i}; - } -} - -inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); - return f; -} - -inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); - if (status) { - size_t num_f = out / sizeof(float); - ShapeInferContext::Floats floats(num_f, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); - return floats; - } else { - return {f}; - } -} - -inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); - return {chars.data()}; - } else { - return {c}; - } -} - -inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); - ShapeInferContext::Strings strings; - char* char_st = chars.data(); - char* char_ed = char_st + out; - while (char_st < char_ed) { - strings.emplace_back(char_st); - while 
(*char_st != '\0') { - char_st++; - } - char_st++; - } - return strings; - } else { - return {std::string{c}}; - } -} - -inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { - const OrtOpAttr* attr_hdl = {}; - Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); - return attr_hdl; -} - -} // namespace Ort diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_float16.h b/tools/onnx_lib/Source/include_rel/onnxruntime_float16.h deleted file mode 100644 index 0b066a9cc9..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_float16.h +++ /dev/null @@ -1,540 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include - -namespace onnxruntime_float16 { - -namespace detail { - -enum class endian { -#if defined(_WIN32) - little = 0, - big = 1, - native = little, -#elif defined(__GNUC__) || defined(__clang__) - little = __ORDER_LITTLE_ENDIAN__, - big = __ORDER_BIG_ENDIAN__, - native = __BYTE_ORDER__, -#else -#error onnxruntime_float16::detail::endian is not implemented in this environment. -#endif -}; - -static_assert( - endian::native == endian::little || endian::native == endian::big, - "Only little-endian or big-endian native byte orders are supported."); - -} // namespace detail - -/// -/// Shared implementation between public and internal classes. CRTP pattern. -/// -template -struct Float16Impl { - protected: - /// - /// Converts from float to uint16_t float16 representation - /// - /// - /// - constexpr static uint16_t ToUint16Impl(float v) noexcept; - - /// - /// Converts float16 to float - /// - /// float representation of float16 value - float ToFloatImpl() const noexcept; - - /// - /// Creates an instance that represents absolute value. 
- /// - /// Absolute value - uint16_t AbsImpl() const noexcept { - return static_cast(val & ~kSignMask); - } - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - uint16_t NegateImpl() const noexcept { - return IsNaN() ? val : static_cast(val ^ kSignMask); - } - - public: - // uint16_t special values - static constexpr uint16_t kSignMask = 0x8000U; - static constexpr uint16_t kBiasedExponentMask = 0x7C00U; - static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; - static constexpr uint16_t kNegativeInfinityBits = 0xFC00U; - static constexpr uint16_t kPositiveQNaNBits = 0x7E00U; - static constexpr uint16_t kNegativeQNaNBits = 0xFE00U; - static constexpr uint16_t kEpsilonBits = 0x4170U; - static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number - static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number - static constexpr uint16_t kOneBits = 0x3C00U; - static constexpr uint16_t kMinusOneBits = 0xBC00U; - - uint16_t val{0}; - - Float16Impl() = default; - - /// - /// Checks if the value is negative - /// - /// true if negative - bool IsNegative() const noexcept { - return static_cast(val) < 0; - } - - /// - /// Tests if the value is NaN - /// - /// true if NaN - bool IsNaN() const noexcept { - return AbsImpl() > kPositiveInfinityBits; - } - - /// - /// Tests if the value is finite - /// - /// true if finite - bool IsFinite() const noexcept { - return AbsImpl() < kPositiveInfinityBits; - } - - /// - /// Tests if the value represents positive infinity. - /// - /// true if positive infinity - bool IsPositiveInfinity() const noexcept { - return val == kPositiveInfinityBits; - } - - /// - /// Tests if the value represents negative infinity - /// - /// true if negative infinity - bool IsNegativeInfinity() const noexcept { - return val == kNegativeInfinityBits; - } - - /// - /// Tests if the value is either positive or negative infinity. 
- /// - /// True if absolute value is infinity - bool IsInfinity() const noexcept { - return AbsImpl() == kPositiveInfinityBits; - } - - /// - /// Tests if the value is NaN or zero. Useful for comparisons. - /// - /// True if NaN or zero. - bool IsNaNOrZero() const noexcept { - auto abs = AbsImpl(); - return (abs == 0 || abs > kPositiveInfinityBits); - } - - /// - /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). - /// - /// True if so - bool IsNormal() const noexcept { - auto abs = AbsImpl(); - return (abs < kPositiveInfinityBits) // is finite - && (abs != 0) // is not zero - && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) - } - - /// - /// Tests if the value is subnormal (denormal). - /// - /// True if so - bool IsSubnormal() const noexcept { - auto abs = AbsImpl(); - return (abs < kPositiveInfinityBits) // is finite - && (abs != 0) // is not zero - && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) - } - - /// - /// Creates an instance that represents absolute value. - /// - /// Absolute value - Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } - - /// - /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check - /// for two values by or'ing the private bits together and stripping the sign. They are both zero, - /// and therefore equivalent, if the resulting value is still zero. 
- /// - /// first value - /// second value - /// True if both arguments represent zero - static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept { - return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; - } - - bool operator==(const Float16Impl& rhs) const noexcept { - if (IsNaN() || rhs.IsNaN()) { - // IEEE defines that NaN is not equal to anything, including itself. - return false; - } - return val == rhs.val; - } - - bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } - - bool operator<(const Float16Impl& rhs) const noexcept { - if (IsNaN() || rhs.IsNaN()) { - // IEEE defines that NaN is unordered with respect to everything, including itself. - return false; - } - - const bool left_is_negative = IsNegative(); - if (left_is_negative != rhs.IsNegative()) { - // When the signs of left and right differ, we know that left is less than right if it is - // the negative value. The exception to this is if both values are zero, in which case IEEE - // says they should be equal, even if the signs differ. - return left_is_negative && !AreZero(*this, rhs); - } - return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); - } -}; - -// The following Float16_t conversions are based on the code from -// Eigen library. - -// The conversion routines are Copyright (c) Fabian Giesen, 2016. -// The original license follows: -// -// Copyright (c) Fabian Giesen, 2016 -// All rights reserved. -// Redistribution and use in source and binary forms, with or without -// modification, are permitted. -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -namespace detail { -union float32_bits { - unsigned int u; - float f; -}; -} // namespace detail - -template -inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { - detail::float32_bits f{}; - f.f = v; - - constexpr detail::float32_bits f32infty = {255 << 23}; - constexpr detail::float32_bits f16max = {(127 + 16) << 23}; - constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; - constexpr unsigned int sign_mask = 0x80000000u; - uint16_t val = static_cast(0x0u); - - unsigned int sign = f.u & sign_mask; - f.u ^= sign; - - // NOTE all the integer compares in this function can be safely - // compiled into signed compares since all operands are below - // 0x80000000. Important if you want fast straight SSE2 code - // (since there's no unsigned PCMPGTD). - - if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) - val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf - } else { // (De)normalized number or zero - if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero - // use a magic value to align our 10 mantissa bits at the bottom of - // the float. as long as FP addition is round-to-nearest-even this - // just works. - f.f += denorm_magic.f; - - // and one integer subtract of the bias later, we have our final float! 
- val = static_cast(f.u - denorm_magic.u); - } else { - unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd - - // update exponent, rounding bias part 1 - // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but - // without arithmetic overflow. - f.u += 0xc8000fffU; - // rounding bias part 2 - f.u += mant_odd; - // take the bits! - val = static_cast(f.u >> 13); - } - } - - val |= static_cast(sign >> 16); - return val; -} - -template -inline float Float16Impl::ToFloatImpl() const noexcept { - constexpr detail::float32_bits magic = {113 << 23}; - constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift - detail::float32_bits o{}; - - o.u = (val & 0x7fff) << 13; // exponent/mantissa bits - unsigned int exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // re-normalize - } - - // Attempt to workaround the Internal Compiler Error on ARM64 - // for bitwise | operator, including std::bitset -#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) - if (IsNegative()) { - return -o.f; - } -#else - // original code: - o.u |= (val & 0x8000U) << 16U; // sign bit -#endif - return o.f; -} - -/// Shared implementation between public and internal classes. CRTP pattern. -template -struct BFloat16Impl { - protected: - /// - /// Converts from float to uint16_t float16 representation - /// - /// - /// - static uint16_t ToUint16Impl(float v) noexcept; - - /// - /// Converts bfloat16 to float - /// - /// float representation of bfloat16 value - float ToFloatImpl() const noexcept; - - /// - /// Creates an instance that represents absolute value. 
- /// - /// Absolute value - uint16_t AbsImpl() const noexcept { - return static_cast(val & ~kSignMask); - } - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - uint16_t NegateImpl() const noexcept { - return IsNaN() ? val : static_cast(val ^ kSignMask); - } - - public: - // uint16_t special values - static constexpr uint16_t kSignMask = 0x8000U; - static constexpr uint16_t kBiasedExponentMask = 0x7F80U; - static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; - static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; - static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; - static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; - static constexpr uint16_t kSignaling_NaNBits = 0x7F80U; - static constexpr uint16_t kEpsilonBits = 0x0080U; - static constexpr uint16_t kMinValueBits = 0xFF7FU; - static constexpr uint16_t kMaxValueBits = 0x7F7FU; - static constexpr uint16_t kRoundToNearest = 0x7FFFU; - static constexpr uint16_t kOneBits = 0x3F80U; - static constexpr uint16_t kMinusOneBits = 0xBF80U; - - uint16_t val{0}; - - BFloat16Impl() = default; - - /// - /// Checks if the value is negative - /// - /// true if negative - bool IsNegative() const noexcept { - return static_cast(val) < 0; - } - - /// - /// Tests if the value is NaN - /// - /// true if NaN - bool IsNaN() const noexcept { - return AbsImpl() > kPositiveInfinityBits; - } - - /// - /// Tests if the value is finite - /// - /// true if finite - bool IsFinite() const noexcept { - return AbsImpl() < kPositiveInfinityBits; - } - - /// - /// Tests if the value represents positive infinity. 
- /// - /// true if positive infinity - bool IsPositiveInfinity() const noexcept { - return val == kPositiveInfinityBits; - } - - /// - /// Tests if the value represents negative infinity - /// - /// true if negative infinity - bool IsNegativeInfinity() const noexcept { - return val == kNegativeInfinityBits; - } - - /// - /// Tests if the value is either positive or negative infinity. - /// - /// True if absolute value is infinity - bool IsInfinity() const noexcept { - return AbsImpl() == kPositiveInfinityBits; - } - - /// - /// Tests if the value is NaN or zero. Useful for comparisons. - /// - /// True if NaN or zero. - bool IsNaNOrZero() const noexcept { - auto abs = AbsImpl(); - return (abs == 0 || abs > kPositiveInfinityBits); - } - - /// - /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). - /// - /// True if so - bool IsNormal() const noexcept { - auto abs = AbsImpl(); - return (abs < kPositiveInfinityBits) // is finite - && (abs != 0) // is not zero - && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) - } - - /// - /// Tests if the value is subnormal (denormal). - /// - /// True if so - bool IsSubnormal() const noexcept { - auto abs = AbsImpl(); - return (abs < kPositiveInfinityBits) // is finite - && (abs != 0) // is not zero - && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) - } - - /// - /// Creates an instance that represents absolute value. - /// - /// Absolute value - Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } - - /// - /// Creates a new instance with the sign flipped. - /// - /// Flipped sign instance - Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } - - /// - /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check - /// for two values by or'ing the private bits together and stripping the sign. 
They are both zero, - /// and therefore equivalent, if the resulting value is still zero. - /// - /// first value - /// second value - /// True if both arguments represent zero - static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { - // IEEE defines that positive and negative zero are equal, this gives us a quick equality check - // for two values by or'ing the private bits together and stripping the sign. They are both zero, - // and therefore equivalent, if the resulting value is still zero. - return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; - } -}; - -template -inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { - uint16_t result; - if (std::isnan(v)) { - result = kPositiveQNaNBits; - } else { - auto get_msb_half = [](float fl) { - uint16_t result; -#ifdef __cpp_if_constexpr - if constexpr (detail::endian::native == detail::endian::little) { -#else - if (detail::endian::native == detail::endian::little) { -#endif - std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); - } else { - std::memcpy(&result, &fl, sizeof(uint16_t)); - } - return result; - }; - - uint16_t upper_bits = get_msb_half(v); - union { - uint32_t U32; - float F32; - }; - F32 = v; - U32 += (upper_bits & 1) + kRoundToNearest; - result = get_msb_half(F32); - } - return result; -} - -template -inline float BFloat16Impl::ToFloatImpl() const noexcept { - if (IsNaN()) { - return std::numeric_limits::quiet_NaN(); - } - float result; - char* const first = reinterpret_cast(&result); - char* const second = first + sizeof(uint16_t); -#ifdef __cpp_if_constexpr - if constexpr (detail::endian::native == detail::endian::little) { -#else - if (detail::endian::native == detail::endian::little) { -#endif - std::memset(first, 0, sizeof(uint16_t)); - std::memcpy(second, &val, sizeof(uint16_t)); - } else { - std::memcpy(first, &val, sizeof(uint16_t)); - std::memset(second, 0, sizeof(uint16_t)); - } - return result; -} - -} // namespace 
onnxruntime_float16 diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_run_options_config_keys.h b/tools/onnx_lib/Source/include_rel/onnxruntime_run_options_config_keys.h deleted file mode 100644 index c80b8c0c16..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_run_options_config_keys.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* - * This file defines RunOptions Config Keys and format of the Config Values. - * - * The Naming Convention for a RunOptions Config Key, - * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" - * Such as "ep.cuda.use_arena" - * The Config Key cannot be empty - * The maximum length of the Config Key is 128 - * - * The string format of a RunOptions Config Value is defined individually for each Config. - * The maximum length of the Config Value is 1024 - */ - -// Key for enabling shrinkages of user listed device memory arenas. -// Expects a list of semi-colon separated key value pairs separated by colon in the following format: -// "device_0:device_id_0;device_1:device_id_1" -// No white-spaces allowed in the provided list string. -// Currently, the only supported devices are : "cpu", "gpu" (case sensitive). -// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled. -// Example usage: "cpu:0;gpu:0" (or) "gpu:0" -// By default, the value for this key is empty (i.e.) no memory arenas are shrunk -static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage"; - -// Set to '1' to not synchronize execution providers with CPU at the end of session run. -// Per default it will be set to '0' -// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. 
-static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; - -// Set HTP performance mode for QNN HTP backend before session run. -// options for HTP performance mode: "burst", "balanced", "default", "high_performance", -// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", -// "sustained_high_performance". Default to "default". -static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; - -// Set HTP performance mode for QNN HTP backend post session run. -static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; - -// Set RPC control latency for QNN HTP backend -static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; - -// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. -// The value should be an integer. If the value is not set, the default value is 0 and -// ORT session only captures one cuda graph before another capture is requested. -// If the value is set to -1, cuda graph capture/replay is disabled in that run. -// User are not expected to set the value to 0 as it is reserved for internal use. -static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/tools/onnx_lib/Source/include_rel/onnxruntime_session_options_config_keys.h b/tools/onnx_lib/Source/include_rel/onnxruntime_session_options_config_keys.h deleted file mode 100644 index 209fd4279c..0000000000 --- a/tools/onnx_lib/Source/include_rel/onnxruntime_session_options_config_keys.h +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/* - * This file defines SessionOptions Config Keys and format of the Config Values. 
- * - * The Naming Convention for a SessionOptions Config Key, - * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" - * Such as "ep.cuda.use_arena" - * The Config Key cannot be empty - * The maximum length of the Config Key is 128 - * - * The string format of a SessionOptions Config Value is defined individually for each Config. - * The maximum length of the Config Value is 1024 - */ - -// Key for disable PrePacking, -// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value) -static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking"; - -// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session -// will be used. Use this to override the usage of env allocators on a per session level. -static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators"; - -// Set to 'ORT' (case sensitive) to load an ORT format model. -// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT -static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format"; - -// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set. -// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'. -static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format"; - -// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0". -// When multiple sessions are created, a main thread doesn't override changes from succeeding session options, -// but threads in session thread pools follow option changes. -// When ORT runs with OpenMP, the same rule is applied, i.e. 
the first session option to flush-to-zero and -// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool. -// Note that an alternative way not using this option at runtime is to train and export a model without denormals -// and that's recommended because turning this option on may hurt model accuracy. -static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero"; - -// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not. -// "0": enable. ORT does fusion logic for QDQ format. -// "1": disable. ORT doesn't do fusion logic for QDQ format. -// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1". -static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq"; - -// It controls whether to enable Double QDQ remover and Identical Children Consolidation -// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs -// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs -// Its default value is "0" -static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover"; - -// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been -// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the -// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to -// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on -// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. -// As such, it's best to test to determine if enabling this works well for your scenario. 
-// The default value is "0" -// Available since version 1.11. -static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; - -// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0". -// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. -static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; - -// This setting controls whether to enable AheadOfTime function inlining. -// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model -// as possible with the help of enabled execution providers. -// This can reduce the number of function calls and improve performance because it is done before -// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, -// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. -// "0": enable; "1": disable. -// Its default value is "0". -static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; - -#ifdef ENABLE_TRAINING -// Specifies a path of the file containing a list of memory optimization configurations. -// The value should be a string indicating the file path of the config file. -// The content of the config file is a JSON struct like this: -// [ -// "Gelu+Cast+:1:0", -// "Dropout+:1:1" -// ] -// Taking the example of "Gelu+Cast+:1:0", -// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation -// output by ORT graph transformations. -// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute. -// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization, -// to avoid "oversaving" the memory. 
-static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config"; - -// Specifies the config for detecting subgraphs for memory footprint reduction. -// The value should be a string contains int separated using commas. The default value is "0:0". -static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config"; -#endif - -// This setting if set should contain a comma separated list of optimizers names that should be disabled. -// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer -// does not provider runtime benefits, but affects your model loading time you may disable it using this config -// entry. This option is not enabled in ORT_MINIMAL_BUILD build. -// A list of optimizes is available in onnxruntime/core/optimizer/graph_transformer_utils.cc -// -// Default is an empty string which means no optimizers are disabled. -static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers"; - -// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0". -// Using device allocators means the memory allocation is made using malloc/new. 
-static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers"; - -// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking -// "0": thread will block if found no job to run -// "1": default, thread will spin a number of times before blocking -static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; -static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; - -// Key for using model bytes directly for ORT format -// If a session is created using an input byte array contains the ORT format model data, -// By default we will copy the model bytes at the time of session creation to ensure the model bytes -// buffer is valid. -// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller -// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. -static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; - -/// -/// Key for using the ORT format model flatbuffer bytes directly for initializers. -/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization. -/// Requires `session.use_ort_model_bytes_directly` to be true. -/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire -/// duration of the InferenceSession. -/// -static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers = - "session.use_ort_model_bytes_for_initializers"; - -// This should only be specified when exporting an ORT format model for use on a different platform. -// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" -// Available since version 1.11. 
-static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; - -// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8. -// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if -// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512 -// platforms. -static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision"; - -// Specifies how minimal build graph optimizations are handled in a full build. -// These optimizations are at the extended level or higher. -// Possible values and their effects are: -// "save": Save runtime optimizations when saving an ORT format model. -// "apply": Only apply optimizations available in a minimal build. -// ""/: Apply optimizations available in a full build. -// Available since version 1.11. -static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = - "optimization.minimal_build_optimizations"; - -// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in -// order for them to take effect. - -// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be -// run by the NNAPI EP. -// The value should be a ","-delimited list of op types. For example, "Add,Sub". -// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op -// exclusion, set the value to "". -static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; - -// Enabling dynamic block-sizing for multithreading. 
-// With a positive value, thread pool will split a task of N iterations to blocks of size starting from: -// N / (num_of_threads * dynamic_block_base) -// As execution progresses, the size will decrease according to the diminishing residual of N, -// meaning the task will be distributed in smaller granularity for better parallelism. -// For some models, it helps to reduce the variance of E2E inference latency and boost performance. -// The feature will not function by default, specify any positive integer, e.g. "4", to enable it. -// Available since version 1.11. -static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base"; - -// This option allows to decrease CPU usage between infrequent -// requests and forces any TP threads spinning stop immediately when the last of -// concurrent Run() call returns. -// Spinning is restarted on the next Run() call. -// Applies only to internal thread-pools -static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop"; - -// "1": all inconsistencies encountered during shape and type inference -// will result in failures. -// "0": in some cases warnings will be logged but processing will continue. The default. -// May be useful to expose bugs in models. -static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference"; - -// "1": every model using a more recent opset than the latest released one will fail -// "0": the model may or may not work if onnxruntime cannot find an implementation, this option -// is used for development purpose. -static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only"; - -// The file saves configuration for partitioning node among logic streams -static const char* const kNodePartitionConfigFile = "session.node_partition_config_file"; - -// This Option allows setting affinities for intra op threads. 
-// Affinity string follows format: -// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id -// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. -// e.g.1,2,3;4,5 -// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. -// To ease the configuration, an "interval" is also allowed: -// e.g. 1-8;8-16;17-24 -// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. -// Note: -// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which -// is started and managed by the calling app; -// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, -// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. -// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. -static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities"; - -// This option will dump out the model to assist debugging any issues with layout transformation, -// and is primarily intended for developer usage. It is only relevant if an execution provider that requests -// NHWC layout is enabled such as NNAPI, XNNPACK or QNN. -// -// Default is off. Set to "1" to enable. -// -// If modified by layout transformation the model will be dumped after these steps: -// 1) insertion of the layout transformation Transpose nodes -// 2) after those are optimized using the transpose optimizer, -// 3) after the L1 transformers are applied to the updated graph. 
-// The model will be saved to filename post_layout_transform_step_.onnx. -static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation"; - -// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are -// assigned (i.e., "fallback") to the CPU EP by default. -// -// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP. -// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot -// fully support all of the nodes in the graph. -// -// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation -// will also fail with an error. -// -// Option values: -// - "0": CPU EP fallback is not disabled. [DEFAULT] -// - "1": CPU EP fallback is disabled. -static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback"; - -// Use this config when serializing a large model after optimization to specify an external initializers file -static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName = - "session.optimized_model_external_initializers_file_name"; - -// Use this config to control the minimum size of the initializer when externalizing it during serialization -static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = - "session.optimized_model_external_initializers_min_size_in_bytes"; - -// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. -// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. -// "0": disable. (default) -// "1": enable. -static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; - -// Specify the file path for the Onnx model which has EP context. 
-// Default to original_file_name_ctx.onnx if not specified -static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; - -// Flag to specify whether to dump the EP context into the Onnx model. -// "0": dump the EP context into separate file, keep the file name in the Onnx model. -// "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; - -// Specify the EPContext node name prefix to make it unique -// in case user need to merge/connect multiple EPContext nodes in one model -static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; - -// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. -// Option values: -// - "0": Gemm FastMath mode is not enabled. [DEFAULT] -// - "1": Gemm FastMath mode is enabled. -static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; - -// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. -// Refer to MatMulNBits op schema for more details. -// If not provided, default is 4. -static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; diff --git a/tools/onnx_lib/Source/include_rel/provider_options.h b/tools/onnx_lib/Source/include_rel/provider_options.h deleted file mode 100644 index aab13e808e..0000000000 --- a/tools/onnx_lib/Source/include_rel/provider_options.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include - -namespace onnxruntime { - -// data types for execution provider options - -using ProviderOptions = std::unordered_map; -using ProviderOptionsVector = std::vector; -using ProviderOptionsMap = std::unordered_map; - -} // namespace onnxruntime diff --git a/tools/onnx_lib/onnx_hise_library.jucer b/tools/onnx_lib/onnx_hise_library.jucer index 17ef8ff288..07beb1bd2d 100644 --- a/tools/onnx_lib/onnx_hise_library.jucer +++ b/tools/onnx_lib/onnx_hise_library.jucer @@ -1,60 +1,78 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +