onnx: Update to OnnxRT >= 1.13.1 API

- Replace deprecated methods
- Add a check on the minimum ORT version (>= 1.13.1) we are compatible with.
- Add clarification to the example given.
- Add the url to retrieve the model mentioned in the example.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/3388>
This commit is contained in:
Daniel Morin 2022-11-10 08:50:35 -05:00 committed by GStreamer Marge Bot
parent e7d584a816
commit 855f84c558
4 changed files with 38 additions and 15 deletions

View file

@ -73,6 +73,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
GstOnnxClient::~GstOnnxClient () GstOnnxClient::~GstOnnxClient ()
{ {
outputNames.clear();
delete session; delete session;
delete[]dest; delete[]dest;
} }
@ -115,6 +116,10 @@ std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
case GST_ML_OUTPUT_NODE_FUNCTION_CLASS: case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
return "label"; return "label";
break; break;
case GST_ML_OUTPUT_NODE_NUMBER_OF:
g_assert_not_reached();
GST_WARNING("Invalid parameter");
break;
}; };
return ""; return "";
@ -130,9 +135,16 @@ GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
return inputImageFormat; return inputImageFormat;
} }
std::vector < const char *>GstOnnxClient::getOutputNodeNames (void) std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
{ {
return outputNames; if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
outputNamesRaw.resize(outputNames.size());
for (size_t i = 0; i < outputNamesRaw.size(); i++) {
outputNamesRaw[i] = outputNames[i].get();
}
}
return outputNamesRaw;
} }
void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node, void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
@ -227,11 +239,13 @@ bool GstOnnxClient::createSession (std::string modelFile,
GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator)); auto input_name = session->GetInputNameAllocated (0, allocator);
GST_DEBUG ("Input name: %s", input_name.get());
for (size_t i = 0; i < session->GetOutputCount (); ++i) { for (size_t i = 0; i < session->GetOutputCount (); ++i) {
auto output_name = session->GetOutputName (i, allocator); auto output_name = session->GetOutputNameAllocated (i, allocator);
outputNames.push_back (output_name); GST_DEBUG("Output name %lu:%s", i, output_name.get());
outputNames.push_back (std::move(output_name));
auto type_info = session->GetOutputTypeInfo (i); auto type_info = session->GetOutputTypeInfo (i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
@ -278,7 +292,7 @@ template < typename T > std::vector < GstMlBoundingBox >
parseDimensions (vmeta); parseDimensions (vmeta);
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto inputName = session->GetInputName (0, allocator); auto inputName = session->GetInputNameAllocated (0, allocator);
auto inputTypeInfo = session->GetInputTypeInfo (0); auto inputTypeInfo = session->GetInputTypeInfo (0);
std::vector < int64_t > inputDims = std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
@ -366,11 +380,11 @@ template < typename T > std::vector < GstMlBoundingBox >
std::vector < Ort::Value > inputTensors; std::vector < Ort::Value > inputTensors;
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo, inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
dest, inputTensorSize, inputDims.data (), inputDims.size ())); dest, inputTensorSize, inputDims.data (), inputDims.size ()));
std::vector < const char *>inputNames { inputName }; std::vector < const char *>inputNames { inputName.get () };
std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr}, std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
inputNames.data (), inputNames.data (),
inputTensors.data (), 1, outputNames.data (), outputNames.size ()); inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
auto numDetections = auto numDetections =
modelOutput[getOutputNodeIndex modelOutput[getOutputNodeIndex

View file

@ -108,7 +108,8 @@ namespace GstOnnxNamespace {
GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF]; GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
// !! indexed by array index // !! indexed by array index
size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF]; size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
std::vector < const char *>outputNames; std::vector < const char *> outputNamesRaw;
std::vector < Ort::AllocatedStringPtr > outputNames;
GstMlModelInputImageFormat inputImageFormat; GstMlModelInputImageFormat inputImageFormat;
bool fixedInputImageSize; bool fixedInputImageSize;
}; };

View file

@ -55,16 +55,24 @@
* *
* ## Example launch command: * ## Example launch command:
* *
* (note: an object detection model has 3 or 4 output nodes, but there is no naming convention * (note: an object detection model has 3 or 4 output nodes, but there is no
* to indicate which node outputs the bounding box, which node outputs the label, etc. * naming convention to indicate which node outputs the bounding box, which
* So, the `onnxobjectdetector` element has properties to map each node's functionality to its * node outputs the label, etc. So, the `onnxobjectdetector` element has
* respective node index in the specified model ) * properties to map each node's functionality to its respective node index in
 * the specified model. Image resolution also needs to be adapted to the model.
* The videoscale in the pipeline below will scale the image, using padding if
* required, to 640x383 resolution required by the model.)
*
* model.onnx can be found here:
* https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx
* *
* ``` * ```
* GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \ * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
* location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \ * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
* videoconvert ! \ * videoconvert ! \
* onnxobjectdetector \ * videoscale ! \
* 'video/x-raw,width=640,height=383' ! \
 * onnxobjectdetector \
* box-node-index=0 \ * box-node-index=0 \
* class-node-index=1 \ * class-node-index=1 \
* score-node-index=2 \ * score-node-index=2 \

View file

@ -3,7 +3,7 @@ if get_option('onnx').disabled()
endif endif
onnxrt_dep = dependency('libonnxruntime',required : get_option('onnx')) onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx'))
if onnxrt_dep.found() if onnxrt_dep.found()
onnxrt_include_root = onnxrt_dep.get_variable('includedir') onnxrt_include_root = onnxrt_dep.get_variable('includedir')