mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2024-11-26 11:41:09 +00:00
onnx: Update to OnnxRT >= 1.13.1 API
- Replace deprecated methods - Add a check on ORT version we are compatible with. - Add clarification to the example given. - Add the url to retrieve the model mentioned in the example. Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/3388>
This commit is contained in:
parent
e7d584a816
commit
855f84c558
4 changed files with 38 additions and 15 deletions
|
@ -73,6 +73,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
|||
|
||||
GstOnnxClient::~GstOnnxClient ()
|
||||
{
|
||||
outputNames.clear();
|
||||
delete session;
|
||||
delete[]dest;
|
||||
}
|
||||
|
@ -115,6 +116,10 @@ std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
|
|||
case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
|
||||
return "label";
|
||||
break;
|
||||
case GST_ML_OUTPUT_NODE_NUMBER_OF:
|
||||
g_assert_not_reached();
|
||||
GST_WARNING("Invalid parameter");
|
||||
break;
|
||||
};
|
||||
|
||||
return "";
|
||||
|
@ -130,9 +135,16 @@ GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
|
|||
return inputImageFormat;
|
||||
}
|
||||
|
||||
std::vector < const char *>GstOnnxClient::getOutputNodeNames (void)
|
||||
std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
|
||||
{
|
||||
return outputNames;
|
||||
if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
|
||||
outputNamesRaw.resize(outputNames.size());
|
||||
for (size_t i = 0; i < outputNamesRaw.size(); i++) {
|
||||
outputNamesRaw[i] = outputNames[i].get();
|
||||
}
|
||||
}
|
||||
|
||||
return outputNamesRaw;
|
||||
}
|
||||
|
||||
void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
|
||||
|
@ -227,11 +239,13 @@ bool GstOnnxClient::createSession (std::string modelFile,
|
|||
GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator));
|
||||
auto input_name = session->GetInputNameAllocated (0, allocator);
|
||||
GST_DEBUG ("Input name: %s", input_name.get());
|
||||
|
||||
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
|
||||
auto output_name = session->GetOutputName (i, allocator);
|
||||
outputNames.push_back (output_name);
|
||||
auto output_name = session->GetOutputNameAllocated (i, allocator);
|
||||
GST_DEBUG("Output name %lu:%s", i, output_name.get());
|
||||
outputNames.push_back (std::move(output_name));
|
||||
auto type_info = session->GetOutputTypeInfo (i);
|
||||
auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
|
||||
|
||||
|
@ -278,7 +292,7 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
parseDimensions (vmeta);
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto inputName = session->GetInputName (0, allocator);
|
||||
auto inputName = session->GetInputNameAllocated (0, allocator);
|
||||
auto inputTypeInfo = session->GetInputTypeInfo (0);
|
||||
std::vector < int64_t > inputDims =
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
|
@ -366,11 +380,11 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
std::vector < Ort::Value > inputTensors;
|
||||
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
|
||||
dest, inputTensorSize, inputDims.data (), inputDims.size ()));
|
||||
std::vector < const char *>inputNames { inputName };
|
||||
std::vector < const char *>inputNames { inputName.get () };
|
||||
|
||||
std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
|
||||
inputNames.data (),
|
||||
inputTensors.data (), 1, outputNames.data (), outputNames.size ());
|
||||
inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
|
||||
|
||||
auto numDetections =
|
||||
modelOutput[getOutputNodeIndex
|
||||
|
|
|
@ -108,7 +108,8 @@ namespace GstOnnxNamespace {
|
|||
GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
|
||||
// !! indexed by array index
|
||||
size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
|
||||
std::vector < const char *>outputNames;
|
||||
std::vector < const char *> outputNamesRaw;
|
||||
std::vector < Ort::AllocatedStringPtr > outputNames;
|
||||
GstMlModelInputImageFormat inputImageFormat;
|
||||
bool fixedInputImageSize;
|
||||
};
|
||||
|
|
|
@ -55,16 +55,24 @@
|
|||
*
|
||||
* ## Example launch command:
|
||||
*
|
||||
* (note: an object detection model has 3 or 4 output nodes, but there is no naming convention
|
||||
* to indicate which node outputs the bounding box, which node outputs the label, etc.
|
||||
* So, the `onnxobjectdetector` element has properties to map each node's functionality to its
|
||||
* respective node index in the specified model )
|
||||
* (note: an object detection model has 3 or 4 output nodes, but there is no
|
||||
* naming convention to indicate which node outputs the bounding box, which
|
||||
* node outputs the label, etc. So, the `onnxobjectdetector` element has
|
||||
* properties to map each node's functionality to its respective node index in
|
||||
* the specified model. Image resolution also need to be adapted to the model.
|
||||
* The videoscale in the pipeline below will scale the image, using padding if
|
||||
* required, to 640x383 resolution required by the model.)
|
||||
*
|
||||
* model.onnx can be found here:
|
||||
* https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx
|
||||
*
|
||||
* ```
|
||||
* GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
|
||||
* location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
|
||||
* videoconvert ! \
|
||||
* onnxobjectdetector \
|
||||
* videoscale ! \
|
||||
* 'video/x-raw,width=640,height=383' ! \
|
||||
* onnxobjectdetector ! \
|
||||
* box-node-index=0 \
|
||||
* class-node-index=1 \
|
||||
* score-node-index=2 \
|
||||
|
|
|
@ -3,7 +3,7 @@ if get_option('onnx').disabled()
|
|||
endif
|
||||
|
||||
|
||||
onnxrt_dep = dependency('libonnxruntime',required : get_option('onnx'))
|
||||
onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx'))
|
||||
|
||||
if onnxrt_dep.found()
|
||||
onnxrt_include_root = onnxrt_dep.get_variable('includedir')
|
||||
|
|
Loading…
Reference in a new issue