Skip to content

Commit

Permalink
Major changes to make clusterizer parallelizable. Problem remains: di…
Browse files Browse the repository at this point in the history
…fferent sizes of nnClusterizerBatchedMode lead to different number of clusters if nnClusterizerBatchedMode < clusterer.mPmemory->counters.nClusters
  • Loading branch information
ChSonnabend committed Feb 17, 2025
1 parent 3c4c587 commit 95bb2ff
Show file tree
Hide file tree
Showing 9 changed files with 445 additions and 712 deletions.
4 changes: 3 additions & 1 deletion Common/ML/include/ML/OrtInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class OrtModel
OrtModel(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
void init(std::unordered_map<std::string, std::string> optionsMap) { reset(optionsMap); }
void reset(std::unordered_map<std::string, std::string>);
bool isInitialized() { return mInitialized; }

virtual ~OrtModel() = default;

Expand Down Expand Up @@ -79,6 +80,7 @@ class OrtModel
std::vector<std::vector<int64_t>> mInputShapes, mOutputShapes;

// Environment settings
bool mInitialized = false;
std::string modelPath, device = "cpu", dtype = "float"; // device options should be cpu, rocm, migraphx, cuda
int intraOpNumThreads = 0, deviceId = 0, enableProfiling = 0, loggingLevel = 0, allocateDeviceMemory = 0, enableOptimizations = 0;

Expand All @@ -89,4 +91,4 @@ class OrtModel

} // namespace o2

#endif // O2_ML_ORTINTERFACE_H
#endif // O2_ML_ORTINTERFACE_H
168 changes: 83 additions & 85 deletions Common/ML/src/OrtInterface.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,19 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
if (!optionsMap.contains("model-path")) {
LOG(fatal) << "(ORT) Model path cannot be empty!";
}
modelPath = optionsMap["model-path"];
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

std::string dev_mem_str = "Hip";

Check failure on line 47 in Common/ML/src/OrtInterface.cxx

View workflow job for this annotation

GitHub Actions / PR formatting / whitespace

Trailing spaces

Remove the trailing spaces at the end of the line.
if (!optionsMap["model-path"].empty()) {
modelPath = optionsMap["model-path"];
device = (optionsMap.contains("device") ? optionsMap["device"] : "CPU");
dtype = (optionsMap.contains("dtype") ? optionsMap["dtype"] : "float");
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);

std::string dev_mem_str = "Hip";
#if defined(ORT_ROCM_BUILD)
#if ORT_ROCM_BUILD == 1
if (device == "ROCM") {
Expand All @@ -81,89 +83,85 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
#endif
#endif

if (allocateDeviceMemory) {
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
LOG(info) << "(ORT) Memory info set to on-device memory";
}
if (allocateDeviceMemory) {
pImplOrt->memoryInfo = Ort::MemoryInfo(dev_mem_str.c_str(), OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemType::OrtMemTypeDefault);
LOG(info) << "(ORT) Memory info set to on-device memory";
}

if (device == "CPU") {
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
if (intraOpNumThreads > 1) {
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
} else if (intraOpNumThreads == 1) {
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
if (device == "CPU") {
(pImplOrt->sessionOptions).SetIntraOpNumThreads(intraOpNumThreads);
if (intraOpNumThreads > 1) {
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_PARALLEL);
} else if (intraOpNumThreads == 1) {
(pImplOrt->sessionOptions).SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
}
if (loggingLevel < 2) {
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
}
}
LOG(info) << "(ORT) CPU execution provider set with " << intraOpNumThreads << " threads";
}

(pImplOrt->sessionOptions).DisableMemPattern();
(pImplOrt->sessionOptions).DisableCpuMemArena();
(pImplOrt->sessionOptions).DisableMemPattern();
(pImplOrt->sessionOptions).DisableCpuMemArena();

if (enableProfiling) {
if (optionsMap.contains("profiling-output-path")) {
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
if (enableProfiling) {
if (optionsMap.contains("profiling-output-path")) {
(pImplOrt->sessionOptions).EnableProfiling((optionsMap["profiling-output-path"] + "/ORT_LOG_").c_str());
} else {
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
(pImplOrt->sessionOptions).DisableProfiling();
}
} else {
LOG(warning) << "(ORT) If profiling is enabled, optionsMap[\"profiling-output-path\"] should be set. Disabling profiling for now.";
(pImplOrt->sessionOptions).DisableProfiling();
}
} else {
(pImplOrt->sessionOptions).DisableProfiling();
}
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

pImplOrt->env = std::make_shared<Ort::Env>(
OrtLoggingLevel(loggingLevel),
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
// Integrate ORT logging into Fairlogger
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
}
},
(void*)3);
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
}
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
}
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
}
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
}
mInitialized = true;

inputNamesChar.resize(mInputNames.size(), nullptr);
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
[&](const std::string& str) { return str.c_str(); });
outputNamesChar.resize(mOutputNames.size(), nullptr);
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
[&](const std::string& str) { return str.c_str(); });

// Print names
LOG(info) << "\tInput Nodes:";
for (size_t i = 0; i < mInputNames.size(); i++) {
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
}
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));

pImplOrt->env = std::make_shared<Ort::Env>(
OrtLoggingLevel(loggingLevel),
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
// Integrate ORT logging into Fairlogger
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
} else {
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
}
},
(void*)3);
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);

for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
mInputNames.push_back((pImplOrt->session)->GetInputNameAllocated(i, pImplOrt->allocator).get());
}
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
mInputShapes.emplace_back((pImplOrt->session)->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
}
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
mOutputNames.push_back((pImplOrt->session)->GetOutputNameAllocated(i, pImplOrt->allocator).get());
}
for (size_t i = 0; i < (pImplOrt->session)->GetOutputCount(); ++i) {
mOutputShapes.emplace_back((pImplOrt->session)->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape());
}

inputNamesChar.resize(mInputNames.size(), nullptr);
std::transform(std::begin(mInputNames), std::end(mInputNames), std::begin(inputNamesChar),
[&](const std::string& str) { return str.c_str(); });
outputNamesChar.resize(mOutputNames.size(), nullptr);
std::transform(std::begin(mOutputNames), std::end(mOutputNames), std::begin(outputNamesChar),
[&](const std::string& str) { return str.c_str(); });

LOG(info) << "\tOutput Nodes:";
for (size_t i = 0; i < mOutputNames.size(); i++) {
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
}
}

Expand Down Expand Up @@ -301,4 +299,4 @@ std::vector<OrtDataType::Float16_t> OrtModel::inference<OrtDataType::Float16_t,

} // namespace ml

} // namespace o2
} // namespace o2
62 changes: 31 additions & 31 deletions GPU/GPUTracking/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,37 +159,37 @@ set(HDRS_INSTALL

set(SRCS_NO_CINT ${SRCS_NO_CINT} display/GPUDisplayInterface.cxx)
set(SRCS_NO_CINT
${SRCS_NO_CINT}
Global/GPUChainITS.cxx
ITS/GPUITSFitter.cxx
ITS/GPUITSFitterKernels.cxx
dEdx/GPUdEdx.cxx
TPCConvert/GPUTPCConvert.cxx
TPCConvert/GPUTPCConvertKernel.cxx
DataCompression/GPUTPCCompression.cxx
DataCompression/GPUTPCCompressionTrackModel.cxx
DataCompression/GPUTPCCompressionKernels.cxx
DataCompression/GPUTPCDecompression.cxx
DataCompression/GPUTPCDecompressionKernels.cxx
DataCompression/TPCClusterDecompressor.cxx
DataCompression/GPUTPCClusterStatistics.cxx
TPCClusterFinder/GPUTPCClusterFinder.cxx
TPCClusterFinder/ClusterAccumulator.cxx
TPCClusterFinder/MCLabelAccumulator.cxx
TPCClusterFinder/GPUTPCCFCheckPadBaseline.cxx
TPCClusterFinder/GPUTPCCFStreamCompaction.cxx
TPCClusterFinder/GPUTPCCFChargeMapFiller.cxx
TPCClusterFinder/GPUTPCCFPeakFinder.cxx
TPCClusterFinder/GPUTPCCFNoiseSuppression.cxx
TPCClusterFinder/GPUTPCCFClusterizer.cxx
TPCClusterFinder/GPUTPCNNClusterizer.cxx
TPCClusterFinder/GPUTPCCFDeconvolution.cxx
TPCClusterFinder/GPUTPCCFMCLabelFlattener.cxx
TPCClusterFinder/GPUTPCCFDecodeZS.cxx
TPCClusterFinder/GPUTPCCFGather.cxx
Refit/GPUTrackingRefit.cxx
Refit/GPUTrackingRefitKernel.cxx
Merger/GPUTPCGMO2Output.cxx)
${SRCS_NO_CINT}
Global/GPUChainITS.cxx
ITS/GPUITSFitter.cxx
ITS/GPUITSFitterKernels.cxx
dEdx/GPUdEdx.cxx
TPCConvert/GPUTPCConvert.cxx
TPCConvert/GPUTPCConvertKernel.cxx
DataCompression/GPUTPCCompression.cxx
DataCompression/GPUTPCCompressionTrackModel.cxx
DataCompression/GPUTPCCompressionKernels.cxx
DataCompression/GPUTPCDecompression.cxx
DataCompression/GPUTPCDecompressionKernels.cxx
DataCompression/TPCClusterDecompressor.cxx
DataCompression/GPUTPCClusterStatistics.cxx
TPCClusterFinder/GPUTPCClusterFinder.cxx
TPCClusterFinder/ClusterAccumulator.cxx
TPCClusterFinder/MCLabelAccumulator.cxx
TPCClusterFinder/GPUTPCCFCheckPadBaseline.cxx
TPCClusterFinder/GPUTPCCFStreamCompaction.cxx
TPCClusterFinder/GPUTPCCFChargeMapFiller.cxx
TPCClusterFinder/GPUTPCCFPeakFinder.cxx
TPCClusterFinder/GPUTPCCFNoiseSuppression.cxx
TPCClusterFinder/GPUTPCCFClusterizer.cxx
TPCClusterFinder/GPUTPCNNClusterizer.cxx
TPCClusterFinder/GPUTPCCFDeconvolution.cxx
TPCClusterFinder/GPUTPCCFMCLabelFlattener.cxx
TPCClusterFinder/GPUTPCCFDecodeZS.cxx
TPCClusterFinder/GPUTPCCFGather.cxx
Refit/GPUTrackingRefit.cxx
Refit/GPUTrackingRefitKernel.cxx
Merger/GPUTPCGMO2Output.cxx)

set(SRCS_DATATYPES
${SRCS_DATATYPES}
Expand Down
Loading

0 comments on commit 95bb2ff

Please sign in to comment.