diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index f47abf1c92..a8600d2052 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -73,6 +73,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr), GstOnnxClient::~GstOnnxClient () { + outputNames.clear(); delete session; delete[]dest; } @@ -115,6 +116,10 @@ std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType) case GST_ML_OUTPUT_NODE_FUNCTION_CLASS: return "label"; break; + case GST_ML_OUTPUT_NODE_NUMBER_OF: + g_assert_not_reached(); + GST_WARNING("Invalid parameter"); + break; }; return ""; @@ -130,9 +135,16 @@ GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void) return inputImageFormat; } -std::vector < const char *>GstOnnxClient::getOutputNodeNames (void) +std::vector< const char *> GstOnnxClient::getOutputNodeNames (void) { - return outputNames; + if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) { + outputNamesRaw.resize(outputNames.size()); + for (size_t i = 0; i < outputNamesRaw.size(); i++) { + outputNamesRaw[i] = outputNames[i].get(); + } + } + + return outputNamesRaw; } void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node, @@ -227,11 +239,13 @@ bool GstOnnxClient::createSession (std::string modelFile, GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); Ort::AllocatorWithDefaultOptions allocator; - GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator)); + auto input_name = session->GetInputNameAllocated (0, allocator); + GST_DEBUG ("Input name: %s", input_name.get()); for (size_t i = 0; i < session->GetOutputCount (); ++i) { - auto output_name = session->GetOutputName (i, allocator); - outputNames.push_back (output_name); + auto output_name = session->GetOutputNameAllocated (i, allocator); + GST_DEBUG("Output name %lu:%s", i, output_name.get()); + outputNames.push_back (std::move(output_name)); auto type_info = session->GetOutputTypeInfo (i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); @@ -278,7 +292,7 @@ template < typename T > std::vector < GstMlBoundingBox > parseDimensions (vmeta); Ort::AllocatorWithDefaultOptions allocator; - auto inputName = session->GetInputName (0, allocator); + auto inputName = session->GetInputNameAllocated (0, allocator); auto inputTypeInfo = session->GetInputTypeInfo (0); std::vector < int64_t > inputDims = inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); @@ -366,11 +380,11 @@ template < typename T > std::vector < GstMlBoundingBox > std::vector < Ort::Value > inputTensors; inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo, dest, inputTensorSize, inputDims.data (), inputDims.size ())); - std::vector < const char *>inputNames { inputName }; + std::vector < const char *>inputNames { inputName.get () }; std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr}, inputNames.data (), - inputTensors.data (), 1, outputNames.data (), outputNames.size ()); + inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ()); auto numDetections = modelOutput[getOutputNodeIndex diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h index 769cd11550..edbc2f4655 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h @@ -108,7 +108,8 @@ namespace GstOnnxNamespace { GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF]; // !! indexed by array index size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF]; - std::vector < const char *>outputNames; + std::vector < const char *> outputNamesRaw; + std::vector < Ort::AllocatedStringPtr > outputNames; GstMlModelInputImageFormat inputImageFormat; bool fixedInputImageSize; }; diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp index 28f4cf2fa0..680b02f3f8 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp @@ -55,16 +55,24 @@ * * ## Example launch command: * - * (note: an object detection model has 3 or 4 output nodes, but there is no naming convention - * to indicate which node outputs the bounding box, which node outputs the label, etc. - * So, the `onnxobjectdetector` element has properties to map each node's functionality to its - * respective node index in the specified model ) + * (note: an object detection model has 3 or 4 output nodes, but there is no + * naming convention to indicate which node outputs the bounding box, which + * node outputs the label, etc. So, the `onnxobjectdetector` element has + * properties to map each node's functionality to its respective node index in + * the specified model. Image resolution also need to be adapted to the model. + * The videoscale in the pipeline below will scale the image, using padding if + * required, to 640x383 resolution required by the model.) + * + * model.onnx can be found here: + * https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx * * ``` * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \ * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \ * videoconvert ! \ - * onnxobjectdetector \ + * videoscale ! \ + * 'video/x-raw,width=640,height=383' ! \ + * onnxobjectdetector ! \ * box-node-index=0 \ * class-node-index=1 \ * score-node-index=2 \ diff --git a/subprojects/gst-plugins-bad/ext/onnx/meson.build b/subprojects/gst-plugins-bad/ext/onnx/meson.build index ff91739746..e66d649e03 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/meson.build +++ b/subprojects/gst-plugins-bad/ext/onnx/meson.build @@ -3,7 +3,7 @@ if get_option('onnx').disabled() endif -onnxrt_dep = dependency('libonnxruntime',required : get_option('onnx')) +onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx')) if onnxrt_dep.found() onnxrt_include_root = onnxrt_dep.get_variable('includedir')