onnx: add plugin to apply ONNX neural network models to video

This MR provides a transform element that leverages the ONNX runtime
to run AI inference on models produced by a broad range of neural
network toolkits, running on either CPU or GPU. ONNX Runtime supports
16 different execution providers at the moment, so with ONNX we
immediately get support for Nvidia, AMD, Xilinx and many others.

For the first release, this plugin adds a gstonnxobjectdetector element to
detect objects in video frames. Metadata generated by the model is
attached to the video buffer as a custom GstObjectDetectorMeta meta.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-bad/-/merge_requests/1997>
This commit is contained in:
Aaron Boxer 2021-02-23 11:56:53 -05:00 committed by Nicolas Dufresne
parent f1ec6ddd5e
commit f71eb29497
10 changed files with 1545 additions and 0 deletions

View file

@ -36,6 +36,7 @@ subdir('mplex')
subdir('musepack')
subdir('neon')
subdir('ofa')
subdir('onnx')
subdir('openal')
subdir('openaptx')
subdir('opencv')

40
ext/onnx/gstonnx.c Normal file
View file

@ -0,0 +1,40 @@
/*
* GStreamer gstreamer-onnx
* Copyright (C) 2021 Collabora Ltd
*
* gstonnx.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstonnxobjectdetector.h"
/**
 * plugin_init:
 * @plugin: the #GstPlugin being loaded
 *
 * Registers the onnxobjectdetector element.
 *
 * Fix: the original discarded the result of GST_ELEMENT_REGISTER and
 * always returned %TRUE, so a failed element registration would leave
 * the plugin "loaded" but empty.  Propagate the registration result
 * instead.
 *
 * Returns: %TRUE if the element was registered successfully
 */
static gboolean
plugin_init (GstPlugin * plugin)
{
  return GST_ELEMENT_REGISTER (onnx_object_detector, plugin);
}
/* Plugin descriptor: exports the "onnx" plugin; VERSION, GST_LICENSE,
 * GST_PACKAGE_NAME and GST_PACKAGE_ORIGIN come from config.h. */
GST_PLUGIN_DEFINE (GST_VERSION_MAJOR,
    GST_VERSION_MINOR,
    onnx,
    "ONNX neural network plugin",
    plugin_init, VERSION, GST_LICENSE, GST_PACKAGE_NAME, GST_PACKAGE_ORIGIN);

423
ext/onnx/gstonnxclient.cpp Normal file
View file

@ -0,0 +1,423 @@
/*
* GStreamer gstreamer-onnxclient
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxclient.cpp
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "gstonnxclient.h"
#include <providers/cpu/cpu_provider_factory.h>
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
#include <providers/cuda/cuda_provider_factory.h>
#endif
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <cmath>
#include <sstream>
namespace GstOnnxNamespace
{
/* Stream a vector as "[a, b, c]"; used below to log tensor dimensions. */
template < typename T >
    std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
{
  os << "[";
  const char *separator = "";
  for (const auto & element : v) {
    os << separator << element;
    separator = ", ";
  }
  os << "]";
  return os;
}
/* Default node info: node disabled, element type float (the common
 * case; createSession overwrites it with the model's actual type). */
GstMlOutputNodeInfo::GstMlOutputNodeInfo (void)
{
  index = GST_ML_NODE_INDEX_DISABLED;
  type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
/* Construct a client with no session; image geometry is filled in when
 * a model is loaded by createSession(). */
GstOnnxClient::GstOnnxClient ():session (nullptr),
    width (0),
    height (0),
    channels (0),
    dest (nullptr),
    m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
    inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
    fixedInputImageSize (true)
{
  /* Identity mapping until properties remap node indices. */
  size_t i = 0;
  while (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
    outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
    ++i;
  }
}
/* Releases the inference session and the staging buffer.  Both may be
 * null (never created, or a model was never loaded); deleting a null
 * pointer is a no-op. */
GstOnnxClient::~GstOnnxClient ()
{
  delete session;
  delete[]dest;
}
/* Returns the process-wide ONNX Runtime environment.  Ort::Env must
 * outlive every session; a function-local static gives it program
 * lifetime with thread-safe (C++11 magic-static) initialization. */
Ort::Env & GstOnnxClient::getEnv (void)
{
  static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
      "GstOnnxNamespace");
  return env;
}
/* Model input width in pixels, as read from the model's input tensor
 * (0 until a model is loaded; may track the frame for dynamic models). */
int32_t GstOnnxClient::getWidth (void)
{
  return width;
}

/* Model input height in pixels; see getWidth(). */
int32_t GstOnnxClient::getHeight (void)
{
  return height;
}

/* Whether the loaded model declares a fixed (positive) input width and
 * height, as opposed to dynamic dimensions. */
bool GstOnnxClient::isFixedInputImageSize (void)
{
  return fixedInputImageSize;
}
/* Human-readable name for an output-node function; used in debug and
 * error messages.  Unknown values yield an empty string. */
std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
{
  if (nodeType == GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)
    return "detection";
  if (nodeType == GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)
    return "bounding box";
  if (nodeType == GST_ML_OUTPUT_NODE_FUNCTION_SCORE)
    return "score";
  if (nodeType == GST_ML_OUTPUT_NODE_FUNCTION_CLASS)
    return "label";
  return "";
}
/* Selects whether the model expects interleaved (HWC) or planar (CHW)
 * pixel data; must be set before createSession() so the input tensor
 * dimensions are interpreted correctly. */
void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
{
  inputImageFormat = format;
}

/* Currently configured model input layout. */
GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
{
  return inputImageFormat;
}

/* Output node names discovered by createSession() (empty before then).
 * The pointers are owned by the client; callers must not free them. */
std::vector < const char *>GstOnnxClient::getOutputNodeNames (void)
{
  return outputNames;
}
/* Map an output-node function (detection count, boxes, scores, labels)
 * to the index of the model output node providing it.
 *
 * @node:  the function being mapped
 * @index: model output node index, or GST_ML_NODE_INDEX_DISABLED (-1)
 *
 * Fix: also assert the lower bound.  The original only checked the
 * upper bound, so any index below -1 would pass the assert and, being
 * != GST_ML_NODE_INDEX_DISABLED, write out of bounds into the reverse
 * map outputNodeIndexToFunction.
 */
void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
    gint index)
{
  g_assert (index >= GST_ML_NODE_INDEX_DISABLED);
  g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
  outputNodeInfo[node].index = index;
  if (index != GST_ML_NODE_INDEX_DISABLED)
    outputNodeIndexToFunction[index] = node;
}
/* Model output node index configured for @node, or
 * GST_ML_NODE_INDEX_DISABLED if unmapped. */
gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
{
  return outputNodeInfo[node].index;
}

/* Overrides the tensor element type recorded for @node (normally
 * discovered from the model in createSession()). */
void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
    ONNXTensorElementDataType type)
{
  outputNodeInfo[node].type = type;
}

/* Tensor element type recorded for @node; drives the float/int
 * dispatch in run(). */
ONNXTensorElementDataType
GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
{
  return outputNodeInfo[node].type;
}

/* True once createSession() has successfully loaded a model. */
bool GstOnnxClient::hasSession (void)
{
  return session != nullptr;
}
bool GstOnnxClient::createSession (std::string modelFile,
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
{
if (session)
return true;
GraphOptimizationLevel onnx_optim;
switch (optim) {
case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
break;
case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC:
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC;
break;
case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED:
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
break;
case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL:
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL;
break;
default:
onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
break;
};
Ort::SessionOptions sessionOptions;
// for debugging
//sessionOptions.SetIntraOpNumThreads (1);
sessionOptions.SetGraphOptimizationLevel (onnx_optim);
m_provider = provider;
switch (m_provider) {
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
(sessionOptions, 0));
#else
return false;
#endif
break;
default:
break;
};
session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
auto inputTypeInfo = session->GetInputTypeInfo (0);
std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
height = inputDims[1];
width = inputDims[2];
channels = inputDims[3];
} else {
channels = inputDims[1];
height = inputDims[2];
width = inputDims[3];
}
fixedInputImageSize = width > 0 && height > 0;
GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());
Ort::AllocatorWithDefaultOptions allocator;
GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator));
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
auto output_name = session->GetOutputName (i, allocator);
outputNames.push_back (output_name);
auto type_info = session->GetOutputTypeInfo (i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
auto function = outputNodeIndexToFunction[i];
outputNodeInfo[function].type = tensor_info.GetElementType ();
}
}
return true;
}
std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
{
auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
return (type ==
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
: doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
}
/* Determine the working frame size for the next inference and (re)size
 * the staging buffer accordingly.  For fixed-size models the model's
 * width/height win; otherwise the incoming frame's dimensions are used.
 *
 * NOTE(review): the capacity check compares against the *previous*
 * frame's area rather than the largest buffer ever allocated, so
 * alternating frame sizes can cause avoidable reallocations.  It never
 * under-allocates (any growth beyond the last area reallocates), but
 * tracking actual capacity would avoid the churn — confirm intent. */
void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
{
  int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
  int32_t newHeight = fixedInputImageSize ? height : vmeta->height;

  if (!dest || width * height < newWidth * newHeight) {
    delete[] dest;
    dest = new uint8_t[newWidth * newHeight * channels];
  }
  width = newWidth;
  height = newHeight;
}
/* doRun:
 * @img_data: packed RGB/RGBA-family pixel data for one frame
 * @vmeta: video meta describing format, dimensions and stride
 * @labelPath: path to the class-label file (one label per line)
 * @scoreThreshold: minimum score for a detection to be kept
 *
 * Copies the frame into the staging buffer in the layout the model
 * expects (HWC interleaved or CHW planar, RGB channel order), runs the
 * session, and converts the detection/box/score/label output tensors
 * into GstMlBoundingBox entries in pixel coordinates.
 *
 * @T is the element type of the class (label) output tensor — float or
 * int32 depending on the model; see run().
 */
template < typename T > std::vector < GstMlBoundingBox >
GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
    std::string labelPath, float scoreThreshold)
{
  std::vector < GstMlBoundingBox > boundingBoxes;
  if (!img_data)
    return boundingBoxes;

  parseDimensions (vmeta);

  Ort::AllocatorWithDefaultOptions allocator;
  /* NOTE(review): GetInputName returns allocator-owned memory that is
   * never freed here — looks like a small per-frame leak; confirm. */
  auto inputName = session->GetInputName (0, allocator);
  auto inputTypeInfo = session->GetInputTypeInfo (0);
  std::vector < int64_t > inputDims =
      inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
  /* Batch of one; patch dynamic height/width with the working size. */
  inputDims[0] = 1;
  if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
    inputDims[1] = height;
    inputDims[2] = width;
  } else {
    inputDims[2] = height;
    inputDims[3] = width;
  }
  std::ostringstream buffer;
  buffer << inputDims;
  GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());

  // copy video frame
  /* srcPtr[k] walks the k-th (R/G/B) sample of each pixel; the offsets
   * below select the right byte per input pixel format. */
  uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
  uint32_t srcSamplesPerPixel = 3;
  switch (vmeta->format) {
    case GST_VIDEO_FORMAT_RGBA:
      srcSamplesPerPixel = 4;
      break;
    case GST_VIDEO_FORMAT_BGRA:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 2;
      srcPtr[1] = img_data + 1;
      srcPtr[2] = img_data + 0;
      break;
    case GST_VIDEO_FORMAT_ARGB:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 1;
      srcPtr[1] = img_data + 2;
      srcPtr[2] = img_data + 3;
      break;
    case GST_VIDEO_FORMAT_ABGR:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 3;
      srcPtr[1] = img_data + 2;
      srcPtr[2] = img_data + 1;
      break;
    case GST_VIDEO_FORMAT_BGR:
      srcPtr[0] = img_data + 2;
      srcPtr[1] = img_data + 1;
      srcPtr[2] = img_data + 0;
      break;
    default:
      break;
  }

  size_t destIndex = 0;
  uint32_t stride = vmeta->stride[0];
  if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
    /* Interleaved destination: R,G,B per pixel, row by row. */
    for (int32_t j = 0; j < height; ++j) {
      for (int32_t i = 0; i < width; ++i) {
        for (int32_t k = 0; k < channels; ++k) {
          dest[destIndex++] = *srcPtr[k];
          srcPtr[k] += srcSamplesPerPixel;
        }
      }
      // correct for stride
      for (uint32_t k = 0; k < 3; ++k)
        srcPtr[k] += stride - srcSamplesPerPixel * width;
    }
  } else {
    /* Planar destination: full R plane, then G, then B. */
    size_t frameSize = width * height;
    uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
    for (int32_t j = 0; j < height; ++j) {
      for (int32_t i = 0; i < width; ++i) {
        for (int32_t k = 0; k < channels; ++k) {
          destPtr[k][destIndex] = *srcPtr[k];
          srcPtr[k] += srcSamplesPerPixel;
        }
        destIndex++;
      }
      // correct for stride
      for (uint32_t k = 0; k < 3; ++k)
        srcPtr[k] += stride - srcSamplesPerPixel * width;
    }
  }

  /* Wrap the staging buffer as a CPU tensor (no copy) and run. */
  const size_t inputTensorSize = width * height * channels;
  auto memoryInfo =
      Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
      OrtMemType::OrtMemTypeDefault);
  std::vector < Ort::Value > inputTensors;
  inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
          dest, inputTensorSize, inputDims.data (), inputDims.size ()));
  std::vector < const char *>inputNames { inputName };

  std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
      inputNames.data (),
      inputTensors.data (), 1, outputNames.data (), outputNames.size ());

  /* Pull the raw tensors out by their configured node indices. */
  auto numDetections =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
  auto bboxes =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
  auto scores =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
  T *labelIndex = nullptr;
  if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
      GST_ML_NODE_INDEX_DISABLED) {
    labelIndex =
        modelOutput[getOutputNodeIndex
        (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
  }
  if (labels.empty () && !labelPath.empty ())
    labels = ReadLabels (labelPath);

  /* Boxes are normalized [y0, x0, y1, x1]; convert to pixel x/y/w/h. */
  for (int i = 0; i < numDetections[0]; ++i) {
    if (scores[i] > scoreThreshold) {
      std::string label = "";

      /* NOTE(review): assumes 1-based class ids and does no bounds
       * check against labels.size() — an out-of-range id from the
       * model would index past the vector; confirm. */
      if (labelIndex && !labels.empty ())
        label = labels[labelIndex[i] - 1];
      auto score = scores[i];
      auto y0 = bboxes[i * 4] * height;
      auto x0 = bboxes[i * 4 + 1] * width;
      auto bheight = bboxes[i * 4 + 2] * height - y0;
      auto bwidth = bboxes[i * 4 + 3] * width - x0;
      boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
              bheight));
    }
  }
  return boundingBoxes;
}
/* Load class labels, one per line, from @labelsFile.  A missing or
 * unreadable file simply yields an empty list. */
std::vector < std::string >
GstOnnxClient::ReadLabels (const std::string & labelsFile)
{
  std::vector < std::string > labels;
  std::ifstream fp (labelsFile);

  for (std::string line; std::getline (fp, line);)
    labels.push_back (line);

  return labels;
}
}

117
ext/onnx/gstonnxclient.h Normal file
View file

@ -0,0 +1,117 @@
/*
* GStreamer gstreamer-onnxclient
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxclient.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ONNX_CLIENT_H__
#define __GST_ONNX_CLIENT_H__
#include <gst/gst.h>
#include <onnxruntime_cxx_api.h>
#include <gst/video/video.h>
#include "gstonnxelement.h"
#include <string>
#include <vector>
namespace GstOnnxNamespace {
  /* Functional role of each model output node.  An object detection
   * model exposes up to four output tensors; element properties map
   * each role to a concrete node index. */
  enum GstMlOutputNodeFunction {
    GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
    GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
    GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
    GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
    GST_ML_OUTPUT_NODE_NUMBER_OF,
  };

  /* Sentinel index meaning "this function has no output node". */
  const gint GST_ML_NODE_INDEX_DISABLED = -1;

  /* Per-function output node bookkeeping: model node index and the
   * tensor element type the model emits for it. */
  struct GstMlOutputNodeInfo {
    GstMlOutputNodeInfo(void);
    gint index;
    ONNXTensorElementDataType type;
  };

  /* One detection result, in pixel coordinates of the analyzed frame. */
  struct GstMlBoundingBox {
    GstMlBoundingBox(std::string lbl,
        float score,
        float _x0,
        float _y0,
        float _width,
        float _height):label(lbl),
        score(score), x0(_x0), y0(_y0), width(_width), height(_height) {
    }
    GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) {
    }
    std::string label;
    float score;
    float x0;
    float y0;
    float width;
    float height;
  };

  /* Thin C++ wrapper around an ONNX Runtime session: loads a model,
   * converts video frames to the model's input layout, runs inference
   * and decodes object-detection outputs into bounding boxes. */
  class GstOnnxClient {
  public:
    GstOnnxClient(void);
    ~GstOnnxClient(void);
    bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
        GstOnnxExecutionProvider provider);
    bool hasSession(void);
    void setInputImageFormat(GstMlModelInputImageFormat format);
    GstMlModelInputImageFormat getInputImageFormat(void);
    void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index);
    gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType);
    void setOutputNodeType(GstMlOutputNodeFunction nodeType,
        ONNXTensorElementDataType type);
    ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type);
    std::string getOutputNodeName(GstMlOutputNodeFunction nodeType);
    std::vector < GstMlBoundingBox > run(uint8_t * img_data,
        GstVideoMeta * vmeta,
        std::string labelPath,
        float scoreThreshold);
    std::vector < GstMlBoundingBox > &getBoundingBoxes(void);
    std::vector < const char *>getOutputNodeNames(void);
    bool isFixedInputImageSize(void);
    int32_t getWidth(void);
    int32_t getHeight(void);
  private:
    void parseDimensions(GstVideoMeta * vmeta);
    template < typename T > std::vector < GstMlBoundingBox >
        doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath,
        float scoreThreshold);
    std::vector < std::string > ReadLabels(const std::string & labelsFile);
    Ort::Env & getEnv(void);
    /* Owned; created lazily by createSession(), deleted in the dtor. */
    Ort::Session * session;
    /* Model input geometry (from the model, or per-frame if dynamic). */
    int32_t width;
    int32_t height;
    int32_t channels;
    /* Owned staging buffer the frame is converted into before Run(). */
    uint8_t *dest;
    GstOnnxExecutionProvider m_provider;
    std::vector < Ort::Value > modelOutput;
    std::vector < std::string > labels;
    // !! indexed by function
    GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
    // !! indexed by array index
    size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
    std::vector < const char *>outputNames;
    GstMlModelInputImageFormat inputImageFormat;
    bool fixedInputImageSize;
  };
}
#endif /* __GST_ONNX_CLIENT_H__ */

104
ext/onnx/gstonnxelement.c Normal file
View file

@ -0,0 +1,104 @@
/*
* GStreamer gstreamer-onnxelement
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxelement.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstonnxelement.h"
/**
 * gst_onnx_optimization_level_get_type:
 *
 * Registers (once, thread-safely) and returns the GType for the
 * #GstOnnxOptimizationLevel enum exposed via the "optimization-level"
 * property.
 *
 * Fix: the user-visible description of "enable-basic" had a stray
 * closing parenthesis ("removals))").
 *
 * Returns: the registered enum GType
 */
GType
gst_onnx_optimization_level_get_type (void)
{
  static GType onnx_optimization_type = 0;

  if (g_once_init_enter (&onnx_optimization_type)) {
    static GEnumValue optimization_level_types[] = {
      {GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization",
          "disable-all"},
      {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
            "Enable basic optimizations (redundant node removals)",
          "enable-basic"},
      {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
            "Enable extended optimizations (redundant node removals + node fusions)",
          "enable-extended"},
      {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
          "Enable all possible optimizations", "enable-all"},
      {0, NULL, NULL},
    };

    GType temp = g_enum_register_static ("GstOnnxOptimizationLevel",
        optimization_level_types);
    g_once_init_leave (&onnx_optimization_type, temp);
  }

  return onnx_optimization_type;
}
/**
 * gst_onnx_execution_provider_get_type:
 *
 * Registers (once, thread-safely) and returns the GType for the
 * #GstOnnxExecutionProvider enum exposed via the "execution-provider"
 * property.
 *
 * Returns: the registered enum GType
 */
GType
gst_onnx_execution_provider_get_type (void)
{
  static GType provider_type = 0;

  if (g_once_init_enter (&provider_type)) {
    static GEnumValue providers[] = {
      {GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
          "cpu"},
      {GST_ONNX_EXECUTION_PROVIDER_CUDA,
            "CUDA execution provider",
          "cuda"},
      {0, NULL, NULL},
    };

    g_once_init_leave (&provider_type,
        g_enum_register_static ("GstOnnxExecutionProvider", providers));
  }

  return provider_type;
}
/**
 * gst_ml_model_input_image_format_get_type:
 *
 * Registers (once, thread-safely) and returns the GType for the
 * #GstMlModelInputImageFormat enum exposed via the
 * "input-image-format" property.
 *
 * Returns: the registered enum GType
 */
GType
gst_ml_model_input_image_format_get_type (void)
{
  static GType ml_model_input_image_format = 0;

  if (g_once_init_enter (&ml_model_input_image_format)) {
    static GEnumValue ml_model_input_image_format_types[] = {
      {GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
            "Height Width Channel (HWC) a.k.a. interleaved image data format",
          "hwc"},
      {GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
            "Channel Height Width (CHW) a.k.a. planar image data format",
          "chw"},
      {0, NULL, NULL},
    };

    GType temp = g_enum_register_static ("GstMlModelInputImageFormat",
        ml_model_input_image_format_types);
    g_once_init_leave (&ml_model_input_image_format, temp);
  }

  return ml_model_input_image_format;
}

64
ext/onnx/gstonnxelement.h Normal file
View file

@ -0,0 +1,64 @@
/*
* GStreamer gstreamer-onnxelement
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxelement.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ONNX_ELEMENT_H__
#define __GST_ONNX_ELEMENT_H__

#include <gst/gst.h>

/* Graph optimization levels; mirror ONNX Runtime's
 * GraphOptimizationLevel values (see gstonnxclient.cpp mapping). */
typedef enum
{
  GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL,
  GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
  GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
  GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
} GstOnnxOptimizationLevel;

/* Which ONNX Runtime execution provider to run inference on. */
typedef enum
{
  GST_ONNX_EXECUTION_PROVIDER_CPU,
  GST_ONNX_EXECUTION_PROVIDER_CUDA,
} GstOnnxExecutionProvider;

typedef enum {
  /* Height Width Channel (a.k.a. interleaved) format */
  GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
  /* Channel Height Width (a.k.a. planar) format */
  GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
} GstMlModelInputImageFormat;

G_BEGIN_DECLS

/* GType registration for the enums above, so they can be used as
 * GObject property types. */
GType gst_onnx_optimization_level_get_type (void);
#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ())

GType gst_onnx_execution_provider_get_type (void);
#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ())

GType gst_ml_model_input_image_format_get_type (void);
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())

G_END_DECLS

#endif

View file

@ -0,0 +1,670 @@
/*
* GStreamer gstreamer-onnxobjectdetector
* Copyright (C) 2021 Collabora Ltd.
*
* gstonnxobjectdetector.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-onnxobjectdetector
* @short_description: Detect objects in video frame
*
* This element can apply a generic ONNX object detection model such as YOLO or SSD
* to each video frame.
*
* To install ONNX on your system, recursively clone this repository
* https://github.com/microsoft/onnxruntime.git
*
* and build and install with cmake:
*
* CPU:
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
*
*
* GPU :
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
* -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
*
*
* where :
*
* 1. $SRC_DIR and $BUILD_DIR are local source and build directories
* 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
* $CUDA_PATH is an environment variable set to the CUDA root path.
* On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA.
*
*
* ## Example launch command:
*
* (note: an object detection model has 3 or 4 output nodes, but there is no naming convention
* to indicate which node outputs the bounding box, which node outputs the label, etc.
* So, the `onnxobjectdetector` element has properties to map each node's functionality to its
* respective node index in the specified model )
*
* ```
 * GST_DEBUG=onnxobjectdetector:5 gst-launch-1.0 multifilesrc \
* location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
* videoconvert ! \
* onnxobjectdetector \
* box-node-index=0 \
* class-node-index=1 \
* score-node-index=2 \
* detection-node-index=3 \
* execution-provider=cpu \
* model-file=model.onnx \
* label-file=COCO_classes.txt ! \
* videoconvert ! \
* autovideosink
* ```
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstonnxobjectdetector.h"
#include "gstonnxclient.h"
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideometa.h>
#include <stdlib.h>
#include <string.h>
#include <glib.h>
GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug);
#define GST_CAT_DEFAULT onnx_object_detector_debug

/* Access the C++ GstOnnxClient stored behind the opaque onnx_ptr
 * member of the C instance struct. */
#define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr))

GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector",
    GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR);

/* GstOnnxObjectDetector properties */
enum
{
  PROP_0,
  PROP_MODEL_FILE,
  PROP_LABEL_FILE,
  PROP_SCORE_THRESHOLD,
  PROP_DETECTION_NODE_INDEX,
  PROP_BOUNDING_BOX_NODE_INDEX,
  PROP_SCORE_NODE_INDEX,
  PROP_CLASS_NODE_INDEX,
  PROP_INPUT_IMAGE_FORMAT,
  PROP_OPTIMIZATION_LEVEL,
  PROP_EXECUTION_PROVIDER
};

/* Property defaults. */
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */

/* Both pads carry packed RGB-family raw video; detection runs in place. */
static GstStaticPadTemplate gst_onnx_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
    GST_PAD_SRC,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
    );

static GstStaticPadTemplate gst_onnx_object_detector_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
    );
/* Forward declarations for the vfuncs and helpers implemented below. */
static void gst_onnx_object_detector_set_property (GObject * object,
    guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_onnx_object_detector_get_property (GObject * object,
    guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_onnx_object_detector_finalize (GObject * object);
static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform *
    trans, GstBuffer * buf);
static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans,
    GstBuffer * buf);
static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans);
static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform *
    trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);

G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector,
    GST_TYPE_BASE_TRANSFORM);
/* Class init: registers properties, element metadata, pad templates and
 * the basetransform vfuncs.  (gtk-doc id for the box property corrected
 * to match the actual property name "box-node-index".) */
static void
gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass)
{
  GObjectClass *gobject_class = (GObjectClass *) klass;
  GstElementClass *element_class = (GstElementClass *) klass;
  GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;

  GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector",
      0, "onnx_objectdetector");
  gobject_class->set_property = gst_onnx_object_detector_set_property;
  gobject_class->get_property = gst_onnx_object_detector_get_property;
  gobject_class->finalize = gst_onnx_object_detector_finalize;

  /**
   * GstOnnxObjectDetector:model-file
   *
   * ONNX model file
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
      g_param_spec_string ("model-file",
          "ONNX model file", "ONNX model file", NULL, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:label-file
   *
   * Label file for ONNX model
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
      g_param_spec_string ("label-file",
          "Label file", "Label file associated with model", NULL, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:detection-node-index
   *
   * Index of model detection node
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_DETECTION_NODE_INDEX,
      g_param_spec_int ("detection-node-index",
          "Detection node index",
          "Index of neural network output node corresponding to number of detected objects",
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
          GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:box-node-index
   *
   * Index of model bounding box node
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_BOUNDING_BOX_NODE_INDEX,
      g_param_spec_int ("box-node-index",
          "Bounding box node index",
          "Index of neural network output node corresponding to bounding box",
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
          GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:score-node-index
   *
   * Index of model score node
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_SCORE_NODE_INDEX,
      g_param_spec_int ("score-node-index",
          "Score node index",
          "Index of neural network output node corresponding to score",
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
          GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:class-node-index
   *
   * Index of model class (label) node
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_CLASS_NODE_INDEX,
      g_param_spec_int ("class-node-index",
          "Class node index",
          "Index of neural network output node corresponding to class (label)",
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
          GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:score-threshold
   *
   * Threshold for deciding when to remove boxes based on score
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
      g_param_spec_float ("score-threshold",
          "Score threshold",
          "Threshold for deciding when to remove boxes based on score",
          0.0, 1.0,
          GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:input-image-format
   *
   * Model input image format
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_INPUT_IMAGE_FORMAT,
      g_param_spec_enum ("input-image-format",
          "Input image format",
          "Input image format",
          GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
          GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:optimization-level
   *
   * ONNX optimization level
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_OPTIMIZATION_LEVEL,
      g_param_spec_enum ("optimization-level",
          "Optimization level",
          "ONNX optimization level",
          GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
          GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstOnnxObjectDetector:execution-provider
   *
   * ONNX execution provider
   *
   * Since: 1.20
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass),
      PROP_EXECUTION_PROVIDER,
      g_param_spec_enum ("execution-provider",
          "Execution provider",
          "ONNX execution provider",
          GST_TYPE_ONNX_EXECUTION_PROVIDER,
          GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  gst_element_class_set_static_metadata (element_class, "onnxobjectdetector",
      "Filter/Effect/Video",
      "Apply neural network to detect objects in video frames",
      "Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
  gst_element_class_add_pad_template (element_class,
      gst_static_pad_template_get (&gst_onnx_object_detector_sink_template));
  gst_element_class_add_pad_template (element_class,
      gst_static_pad_template_get (&gst_onnx_object_detector_src_template));
  basetransform_class->transform_ip =
      GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip);
  basetransform_class->transform_caps =
      GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps);
}
static void
gst_onnx_object_detector_init (GstOnnxObjectDetector * self)
{
  /* Inference starts enabled; it may be turned off later if session
   * creation or output-node validation fails. */
  self->onnx_disabled = false;
  /* Opaque pointer to the C++ client wrapping the ONNX Runtime session. */
  self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient ();
}
/* Free all memory owned by the element: the model/label file paths and the
 * ONNX client instance, then chain up to the parent finalize. */
static void
gst_onnx_object_detector_finalize (GObject * object)
{
  GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);

  g_free (self->model_file);
  /* FIX: label_file is g_strdup'd in set_property but was never freed here,
   * leaking the string on object destruction. g_free() is NULL-safe. */
  g_free (self->label_file);
  delete GST_ONNX_MEMBER (self);
  G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object);
}
/* Standard GObject property setter.
 *
 * Model and label file paths are validated against the filesystem before
 * being stored; a missing model file switches the element to passthrough.
 * Node-index and input-image-format properties are forwarded directly to
 * the ONNX client implementation. */
static void
gst_onnx_object_detector_set_property (GObject * object, guint prop_id,
    const GValue * value, GParamSpec * pspec)
{
  GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
  const gchar *filename;
  auto onnxClient = GST_ONNX_MEMBER (self);

  switch (prop_id) {
    case PROP_MODEL_FILE:
      filename = g_value_get_string (value);
      if (filename
          && g_file_test (filename,
              (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
        /* g_free() is NULL-safe, no guard needed */
        g_free (self->model_file);
        self->model_file = g_strdup (filename);
      } else {
        GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
        /* without a model we can only forward buffers untouched */
        gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
      }
      break;
    case PROP_LABEL_FILE:
      filename = g_value_get_string (value);
      if (filename
          && g_file_test (filename,
              (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
        g_free (self->label_file);
        self->label_file = g_strdup (filename);
      } else {
        GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
      }
      break;
    case PROP_SCORE_THRESHOLD:
      /* read concurrently from the streaming thread, so lock */
      GST_OBJECT_LOCK (self);
      self->score_threshold = g_value_get_float (value);
      GST_OBJECT_UNLOCK (self);
      break;
    case PROP_OPTIMIZATION_LEVEL:
      self->optimization_level =
          (GstOnnxOptimizationLevel) g_value_get_enum (value);
      break;
    case PROP_EXECUTION_PROVIDER:
      self->execution_provider =
          (GstOnnxExecutionProvider) g_value_get_enum (value);
      break;
    case PROP_DETECTION_NODE_INDEX:
      onnxClient->setOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
          g_value_get_int (value));
      break;
    case PROP_BOUNDING_BOX_NODE_INDEX:
      /* FIX: removed duplicated dead `break;` statements after this and
       * the score case */
      onnxClient->setOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
          g_value_get_int (value));
      break;
    case PROP_SCORE_NODE_INDEX:
      onnxClient->setOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
          g_value_get_int (value));
      break;
    case PROP_CLASS_NODE_INDEX:
      onnxClient->setOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
          g_value_get_int (value));
      break;
    case PROP_INPUT_IMAGE_FORMAT:
      onnxClient->setInputImageFormat ((GstMlModelInputImageFormat)
          g_value_get_enum (value));
      break;
    default:
      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
      break;
  }
}
/* Standard GObject property getter; node indices and the input image
 * format are read back from the ONNX client implementation. */
static void
gst_onnx_object_detector_get_property (GObject * object, guint prop_id,
    GValue * value, GParamSpec * pspec)
{
  GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
  auto onnxClient = GST_ONNX_MEMBER (self);

  switch (prop_id) {
    case PROP_MODEL_FILE:
      g_value_set_string (value, self->model_file);
      break;
    case PROP_LABEL_FILE:
      g_value_set_string (value, self->label_file);
      break;
    case PROP_SCORE_THRESHOLD:
      /* written concurrently from set_property, so lock */
      GST_OBJECT_LOCK (self);
      g_value_set_float (value, self->score_threshold);
      GST_OBJECT_UNLOCK (self);
      break;
    case PROP_OPTIMIZATION_LEVEL:
      g_value_set_enum (value, self->optimization_level);
      break;
    case PROP_EXECUTION_PROVIDER:
      g_value_set_enum (value, self->execution_provider);
      break;
    case PROP_DETECTION_NODE_INDEX:
      g_value_set_int (value,
          onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION));
      break;
    case PROP_BOUNDING_BOX_NODE_INDEX:
      /* FIX: removed duplicated dead `break;` statements after this and
       * the score case */
      g_value_set_int (value,
          onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX));
      break;
    case PROP_SCORE_NODE_INDEX:
      g_value_set_int (value,
          onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE));
      break;
    case PROP_CLASS_NODE_INDEX:
      g_value_set_int (value,
          onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS));
      break;
    case PROP_INPUT_IMAGE_FORMAT:
      g_value_set_enum (value, onnxClient->getInputImageFormat ());
      break;
    default:
      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
      break;
  }
}
/* Lazily create the ONNX Runtime session from the configured model file and
 * sanity-check that the model exposes the output nodes required for object
 * detection (detection count, bounding boxes, scores, optionally classes).
 *
 * Returns FALSE only when a model was loaded but is unusable for detection;
 * a missing model simply switches the element to passthrough and returns
 * TRUE. Called with no locks held; takes/releases the object lock itself. */
static gboolean
gst_onnx_object_detector_create_session (GstBaseTransform * trans)
{
  GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
  auto onnxClient = GST_ONNX_MEMBER (self);

  GST_OBJECT_LOCK (self);
  if (self->onnx_disabled || onnxClient->hasSession ()) {
    GST_OBJECT_UNLOCK (self);
    return TRUE;
  }
  if (self->model_file) {
    gboolean ret = onnxClient->createSession (self->model_file,
        self->optimization_level,
        self->execution_provider);
    if (!ret) {
      GST_ERROR_OBJECT (self,
          "Unable to create ONNX session. Detection disabled.");
    } else {
      auto outputNames = onnxClient->getOutputNodeNames ();

      for (size_t i = 0; i < outputNames.size (); ++i)
        GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i,
            outputNames[i]);
      if (outputNames.size () < 3) {
        GST_ERROR_OBJECT (self,
            "Number of output tensor nodes %d does not match the 3 or 4 nodes "
            "required for an object detection model. Detection is disabled.",
            (gint) outputNames.size ());
        self->onnx_disabled = TRUE;
      }
      // sanity check on output node indices
      if (onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) ==
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
        GST_ERROR_OBJECT (self,
            "Output detection node index not set. Detection disabled.");
        self->onnx_disabled = TRUE;
      }
      if (onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) ==
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
        GST_ERROR_OBJECT (self,
            "Output bounding box node index not set. Detection disabled.");
        self->onnx_disabled = TRUE;
      }
      if (onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) ==
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
        GST_ERROR_OBJECT (self,
            "Output score node index not set. Detection disabled.");
        self->onnx_disabled = TRUE;
      }
      /* the class node is only mandatory for 4-node models */
      if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex
          (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) ==
          GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
        GST_ERROR_OBJECT (self,
            "Output class node index not set. Detection disabled.");
        self->onnx_disabled = TRUE;
      }
      // model is not usable, so fail
      if (self->onnx_disabled) {
        /* FIX: unlock before returning; the early return previously leaked
         * the object lock, and posting an element message while holding it
         * is unsafe. */
        GST_OBJECT_UNLOCK (self);
        GST_ELEMENT_WARNING (self, RESOURCE, FAILED,
            ("ONNX model cannot be used for object detection"), (NULL));
        return FALSE;
      }
    }
  } else {
    self->onnx_disabled = TRUE;
  }
  GST_OBJECT_UNLOCK (self);
  if (self->onnx_disabled) {
    gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
  }
  return TRUE;
}
/* Transform caps between sink and src pads.  When the loaded model has a
 * fixed input resolution, the width/height of every caps structure are
 * pinned to the model's dimensions; otherwise the caps pass through
 * unchanged.  Returns NULL if the ONNX session could not be created. */
static GstCaps *
gst_onnx_object_detector_transform_caps (GstBaseTransform *
    trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
{
  GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
  auto client = GST_ONNX_MEMBER (self);
  GstCaps *result;
  guint idx;

  /* session creation is deferred until caps negotiation */
  if (!gst_onnx_object_detector_create_session (trans))
    return NULL;

  GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);

  /* nothing to constrain in passthrough mode, or when the model accepts
   * arbitrary input sizes */
  if (gst_base_transform_is_passthrough (trans)
      || !client->isFixedInputImageSize ())
    return gst_caps_ref (caps);

  result = gst_caps_new_empty ();
  for (idx = 0; idx < gst_caps_get_size (caps); ++idx) {
    GstStructure *entry = gst_caps_get_structure (caps, idx);
    GstStructure *pinned = gst_structure_copy (entry);

    gst_structure_set (pinned, "width", G_TYPE_INT,
        client->getWidth (), "height", G_TYPE_INT,
        client->getHeight (), NULL);
    GST_LOG_OBJECT (self,
        "transformed structure %2d: %" GST_PTR_FORMAT " => %"
        GST_PTR_FORMAT, idx, entry, pinned);
    gst_caps_append_structure (result, pinned);
  }

  /* honor any downstream filter caps */
  if (!gst_caps_is_empty (result) && filter_caps) {
    GstCaps *intersection = gst_caps_intersect_full (result, filter_caps,
        GST_CAPS_INTERSECT_FIRST);

    gst_caps_replace (&result, intersection);
    gst_caps_unref (intersection);
  }
  return result;
}
/* In-place transform: run object detection on each buffer unless the
 * element is in passthrough mode. */
static GstFlowReturn
gst_onnx_object_detector_transform_ip (GstBaseTransform * trans,
    GstBuffer * buf)
{
  if (gst_base_transform_is_passthrough (trans))
    return GST_FLOW_OK;

  if (gst_onnx_object_detector_process (trans, buf))
    return GST_FLOW_OK;

  GST_ELEMENT_WARNING (trans, STREAM, FAILED,
      ("ONNX object detection failed"), (NULL));
  return GST_FLOW_ERROR;
}
/* Run inference on one mapped video buffer and attach a
 * GstVideoRegionOfInterestMeta (with label and score parameters) for every
 * detected object.  Requires a GstVideoMeta on the buffer to describe the
 * frame layout.  Returns FALSE on missing meta, failed meta attach, or
 * (implicitly TRUE) if the buffer could not be mapped. */
static gboolean
gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
  GstMapInfo info;
  GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);

  if (!vmeta) {
    GST_WARNING_OBJECT (trans, "missing video meta");
    return FALSE;
  }
  if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
    GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
    auto boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
        self->label_file ? self->label_file : "",
        self->score_threshold);
    for (auto & b:boxes) {
      auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
          GST_ONNX_OBJECT_DETECTOR_META_NAME,
          b.x0, b.y0,
          b.width,
          b.height);
      if (!vroi_meta) {
        GST_WARNING_OBJECT (trans,
            "Unable to attach GstVideoRegionOfInterestMeta to buffer");
        /* FIX: the early return previously leaked the buffer mapping */
        gst_buffer_unmap (buf, &info);
        return FALSE;
      }
      auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME,
          GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL,
          G_TYPE_STRING,
          b.label.c_str (),
          GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE,
          G_TYPE_DOUBLE,
          b.score,
          NULL);
      /* the meta takes ownership of the structure */
      gst_video_region_of_interest_meta_add_param (vroi_meta, s);
      GST_DEBUG_OBJECT (self,
          "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
          b.label.c_str (), b.score, b.x0, b.y0,
          b.x0 + b.width, b.y0 + b.height);
    }
    gst_buffer_unmap (buf, &info);
  }
  return TRUE;
}

View file

@ -0,0 +1,93 @@
/*
* GStreamer gstreamer-onnxobjectdetector
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxobjectdetector.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ONNX_OBJECT_DETECTOR_H__
#define __GST_ONNX_OBJECT_DETECTOR_H__
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideofilter.h>
#include "gstonnxelement.h"
G_BEGIN_DECLS
#define GST_TYPE_ONNX_OBJECT_DETECTOR (gst_onnx_object_detector_get_type())
G_DECLARE_FINAL_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector, GST, ONNX_OBJECT_DETECTOR, GstBaseTransform)
/* Conventional GObject cast/check macros for the element type. */
#define GST_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_CAST((obj),GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetector))
#define GST_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_CAST((klass), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
#define GST_ONNX_OBJECT_DETECTOR_GET_CLASS(obj) (G_TYPE_INSTANCE_GET_CLASS((obj), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
#define GST_IS_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_TYPE((obj),GST_TYPE_ONNX_OBJECT_DETECTOR))
#define GST_IS_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_TYPE((klass), GST_TYPE_ONNX_OBJECT_DETECTOR))
/* Name of the GstVideoRegionOfInterestMeta attached per detected object,
 * and the field names of the structure parameter carrying detection data. */
#define GST_ONNX_OBJECT_DETECTOR_META_NAME "onnx-object_detector"
#define GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME "extra-data"
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE "score"
/**
 * GstOnnxObjectDetector:
 *
 * @model_file path to the ONNX model file
 * @label_file path to the labels file
 * @score_threshold score threshold
 * @confidence_threshold confidence threshold
 * @iou_threshold intersection-over-union threshold
 * @optimization_level ONNX optimization level
 * @execution_provider ONNX execution provider
 * @onnx_ptr opaque pointer to ONNX implementation
 * @onnx_disabled whether inference is disabled
 *
 * Since: 1.20
 */
struct _GstOnnxObjectDetector
{
  GstBaseTransform basetransform;
  gchar *model_file;
  gchar *label_file;
  gfloat score_threshold;
  gfloat confidence_threshold;
  gfloat iou_threshold;
  GstOnnxOptimizationLevel optimization_level;
  GstOnnxExecutionProvider execution_provider;
  gpointer onnx_ptr;
  gboolean onnx_disabled;
  /* optional per-frame processing hook */
  void (*process) (GstOnnxObjectDetector * onnx_object_detector,
      GstVideoFrame * inframe, GstVideoFrame * outframe);
};
/**
 * GstOnnxObjectDetectorClass:
 *
 * @parent_class base transform base class
 *
 * Since: 1.20
 */
struct _GstOnnxObjectDetectorClass
{
  GstBaseTransformClass parent_class;
};
GST_ELEMENT_REGISTER_DECLARE (onnx_object_detector)
G_END_DECLS
#endif /* __GST_ONNX_OBJECT_DETECTOR_H__ */

32
ext/onnx/meson.build Normal file
View file

@ -0,0 +1,32 @@
# Skip the plugin entirely when it has been explicitly disabled.
if get_option('onnx').disabled()
subdir_done()
endif
# ONNX Runtime is the only hard dependency of this plugin.
onnxrt_dep = dependency('libonnxruntime',required : get_option('onnx'))
if onnxrt_dep.found()
# ONNX Runtime headers live under <includedir>/core and core/session.
onnxrt_include_root = onnxrt_dep.get_pkgconfig_variable('includedir')
onnxrt_includes = [onnxrt_include_root / 'core/session', onnxrt_include_root / 'core']
onnxrt_dep_args = []
# Enable the CUDA execution provider only when its header is available.
compiler = meson.get_compiler('cpp')
if compiler.has_header(onnxrt_include_root / 'core/providers/cuda/cuda_provider_factory.h')
onnxrt_dep_args = ['-DGST_ML_ONNX_RUNTIME_HAVE_CUDA']
endif
gstonnx = library('gstonnx',
'gstonnx.c',
'gstonnxelement.c',
'gstonnxobjectdetector.cpp',
'gstonnxclient.cpp',
c_args : gst_plugins_bad_args,
cpp_args: onnxrt_dep_args,
link_args : noseh_link_args,
include_directories : [configinc, libsinc, onnxrt_includes],
dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm],
install : true,
install_dir : plugins_install_dir,
)
pkgconfig.generate(gstonnx, install_dir : plugins_pkgconfig_install_dir)
plugins += [gstonnx]
endif

View file

@ -129,6 +129,7 @@ option('musepack', type : 'feature', value : 'auto', description : 'libmpcdec Mu
option('neon', type : 'feature', value : 'auto', description : 'NEON HTTP source plugin')
option('nvcodec', type : 'feature', value : 'auto', description : 'NVIDIA GPU codec plugin')
option('ofa', type : 'feature', value : 'auto', description : 'Open Fingerprint Architecture library plugin')
option('onnx', type : 'feature', value : 'auto', description : 'ONNX neural network plugin')
option('openal', type : 'feature', value : 'auto', description : 'OpenAL plugin')
option('openexr', type : 'feature', value : 'auto', description : 'OpenEXR plugin')
option('openh264', type : 'feature', value : 'auto', description : 'H.264 video codec plugin')