onnx: add gstonnxinference element

This element refactors functionality from gstonnxinference element,
namely separating out the ONNX inference from the subsequent analysis.

The new element runs an ONNX model on each video frame, and then
attaches a TensorMeta meta with the output tensor data. This tensor data
will then be consumed by downstream elements such as gstobjectdetector.

At the moment, a provisional TensorMeta is used just in the ONNX
plugin, but in future this will upgraded to a GStreamer API for other
plugins to consume.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/4916>
This commit is contained in:
Aaron Boxer 2023-06-26 10:55:53 -04:00 committed by GStreamer Marge Bot
parent 6053dd0d1b
commit 1ff585233a
20 changed files with 1934 additions and 1197 deletions

View file

@ -0,0 +1,195 @@
/*
* GStreamer gstreamer-objectdetectorutils
* Copyright (C) 2023 Collabora Ltd
*
* gstobjectdetectorutils.cpp
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "gstobjectdetectorutils.h"
#include <fstream>
#include "tensor/gsttensorid.h"
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
float _y0, float _width, float _height):
label (lbl),
score (score),
x0 (_x0),
y0 (_y0),
width (_width),
height (_height)
{
}
GstMlBoundingBox::GstMlBoundingBox ():
GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
{
}
namespace GstObjectDetectorUtils
{
GstObjectDetectorUtils::GstObjectDetectorUtils ()
{
}
std::vector < std::string >
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
{
std::vector < std::string > labels;
std::string line;
std::ifstream fp (labelsFile);
while (std::getline (fp, line))
labels.push_back (line);
return labels;
}
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
int32_t h, GstTensorMeta * tmeta, std::string labelPath,
float scoreThreshold)
{
auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (classIndex == GST_TENSOR_MISSING_ID) {
GST_ERROR ("Missing class tensor id");
return std::vector < GstMlBoundingBox > ();
}
auto type = tmeta->tensor[classIndex].type;
return (type == GST_TENSOR_TYPE_FLOAT32) ?
doRun < float >(w, h, tmeta, labelPath, scoreThreshold)
: doRun < int >(w, h, tmeta, labelPath, scoreThreshold);
}
template < typename T > std::vector < GstMlBoundingBox >
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold)
{
std::vector < GstMlBoundingBox > boundingBoxes;
GstMapInfo map_info[GstObjectDetectorMaxNodes];
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
std::vector < std::string > labels;
gint index;
T *numDetections = nullptr, *bboxes = nullptr, *scores =
nullptr, *labelIndex = nullptr;
// number of detections
index = gst_tensor_meta_get_index_from_id (tmeta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
if (index == GST_TENSOR_MISSING_ID) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_WARNING ("Failed to map tensor memory for index %d", index);
goto cleanup;
}
numDetections = (T *) map_info[index].data;
// bounding boxes
index =
gst_tensor_meta_get_index_from_id (tmeta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES));
if (index == GST_TENSOR_MISSING_ID) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Failed to map tensor memory for index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
bboxes = (T *) map_info[index].data;
// scores
index =
gst_tensor_meta_get_index_from_id (tmeta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES));
if (index == GST_TENSOR_MISSING_ID) {
GST_ERROR ("Missing scores tensor id");
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
scores = (T *) map_info[index].data;
// optional label
labelIndex = nullptr;
index =
gst_tensor_meta_get_index_from_id (tmeta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (index != GST_TENSOR_MISSING_ID) {
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
labelIndex = (T *) map_info[index].data;
}
if (!labelPath.empty ())
labels = ReadLabels (labelPath);
for (int i = 0; i < numDetections[0]; ++i) {
if (scores[i] > scoreThreshold) {
std::string label = "";
if (labelIndex && !labels.empty ())
label = labels[labelIndex[i] - 1];
auto score = scores[i];
auto y0 = bboxes[i * 4] * h;
auto x0 = bboxes[i * 4 + 1] * w;
auto bheight = bboxes[i * 4 + 2] * h - y0;
auto bwidth = bboxes[i * 4 + 3] * w - x0;
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
bheight));
}
}
cleanup:
for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
if (memory[i])
gst_memory_unmap (memory[i], map_info + i);
}
return boundingBoxes;
}
}

View file

@ -0,0 +1,82 @@
/*
* GStreamer gstreamer-objectdetectorutils
* Copyright (C) 2023 Collabora Ltd
*
* gstobjectdetectorutils.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
#define __GST_OBJECT_DETECTOR_UTILS_H__
#include <gst/gst.h>
#include <string>
#include <vector>
#include "gstml.h"
#include "tensor/gsttensormeta.h"
/* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
/**
* GstMlBoundingBox:
*
* @label label
* @score detection confidence
* @x0 top left hand x coordinate
* @y0 top left hand y coordinate
* @width width
* @height height
*
* Since: 1.20
*/
struct GstMlBoundingBox {
GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0,
float _width, float _height);
GstMlBoundingBox();
std::string label;
float score;
float x0;
float y0;
float width;
float height;
};
namespace GstObjectDetectorUtils {
const int GstObjectDetectorMaxNodes = 4;
class GstObjectDetectorUtils {
public:
GstObjectDetectorUtils(void);
~GstObjectDetectorUtils(void) = default;
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
GstTensorMeta *tmeta,
std::string labelPath,
float scoreThreshold);
private:
template < typename T > std::vector < GstMlBoundingBox >
doRun(int32_t w, int32_t h,
GstTensorMeta *tmeta, std::string labelPath,
float scoreThreshold);
std::vector < std::string > ReadLabels(const std::string & labelsFile);
};
}
#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */

View file

@ -0,0 +1,348 @@
/*
* GStreamer gstreamer-ssdobjectdetector
* Copyright (C) 2021 Collabora Ltd.
*
* gstssdobjectdetector.cpp
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-ssdobjectdetector
* @short_description: Detect objects in video buffers using SSD neural network
*
* This element can parse per-buffer inference tensor meta data generated by an upstream
* inference element
*
*
* ## Example launch command:
*
* note: image resolution may need to be adapted to the model, if the model expects
* a certain input resolution. The `videoscale` element in the pipeline below will scale
* the image, using padding if required, to 640x383 resolution required by model
*
*
* GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
* location=bus.jpg ! jpegdec ! videoconvert ! \
* videoscale ! 'video/x-raw,width=640,height=383' ! \
* onnxinference execution-provider=cpu model-file=model.onnx \
* ssdobjectdetector label-file=COCO_classes.txt ! \
* videoconvert ! autovideosink
*
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstssdobjectdetector.h"
#include "gstobjectdetectorutils.h"
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideometa.h>
#include "tensor/gsttensormeta.h"
#include "tensor/gsttensorid.h"
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
#define GST_CAT_DEFAULT ssd_object_detector_debug
#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
/* GstSsdObjectDetector properties */
enum
{
PROP_0,
PROP_LABEL_FILE,
PROP_SCORE_THRESHOLD,
};
#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
static GstStaticPadTemplate gst_ssd_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
GST_STATIC_CAPS ("video/x-raw")
);
static GstStaticPadTemplate gst_ssd_object_detector_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
GST_STATIC_CAPS ("video/x-raw")
);
static void gst_ssd_object_detector_set_property (GObject * object,
guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_ssd_object_detector_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_ssd_object_detector_finalize (GObject * object);
static GstFlowReturn gst_ssd_object_detector_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_ssd_object_detector_process (GstBaseTransform * trans,
GstBuffer * buf);
static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);
static void
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
{
GObjectClass *gobject_class = (GObjectClass *) klass;
GstElementClass *element_class = (GstElementClass *) klass;
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
GST_DEBUG_CATEGORY_INIT (ssd_object_detector_debug, "ssdobjectdetector",
0, "ssdobjectdetector");
gobject_class->set_property = gst_ssd_object_detector_set_property;
gobject_class->get_property = gst_ssd_object_detector_get_property;
gobject_class->finalize = gst_ssd_object_detector_finalize;
/**
* GstSsdObjectDetector:label-file
*
* Label file
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
g_param_spec_string ("label-file",
"Label file", "Label file", NULL, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstSsdObjectDetector:score-threshold
*
* Threshold for deciding when to remove boxes based on score
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
g_param_spec_float ("score-threshold",
"Score threshold",
"Threshold for deciding when to remove boxes based on score",
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "objectdetector",
"Filter/Effect/Video",
"Apply tensor output from inference to detect objects in video frames",
"Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_ssd_object_detector_sink_template));
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_ssd_object_detector_src_template));
basetransform_class->transform_ip =
GST_DEBUG_FUNCPTR (gst_ssd_object_detector_transform_ip);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_ssd_object_detector_set_caps);
}
static void
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
{
self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
}
static void
gst_ssd_object_detector_finalize (GObject * object)
{
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
delete GST_ODUTILS_MEMBER (self);
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
}
static void
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
{
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
const gchar *filename;
switch (prop_id) {
case PROP_LABEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
g_free (self->label_file);
self->label_file = g_strdup (filename);
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
self->score_threshold = g_value_get_float (value);
GST_OBJECT_UNLOCK (self);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static void
gst_ssd_object_detector_get_property (GObject * object, guint prop_id,
GValue * value, GParamSpec * pspec)
{
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
switch (prop_id) {
case PROP_LABEL_FILE:
g_value_set_string (value, self->label_file);
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
g_value_set_float (value, self->score_threshold);
GST_OBJECT_UNLOCK (self);
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static GstTensorMeta *
gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
GstBuffer * buf)
{
GstTensorMeta *tmeta = NULL;
GList *tensor_metas;
GList *iter;
// get all tensor metas
tensor_metas = gst_tensor_meta_get_all_from_buffer (buf);
if (!tensor_metas) {
GST_TRACE_OBJECT (object_detector,
"missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
goto cleanup;
}
// find object detector meta
for (iter = tensor_metas; iter != NULL; iter = g_list_next (iter)) {
GstTensorMeta *tensor_meta = (GstTensorMeta *) iter->data;
gint numTensors = tensor_meta->num_tensors;
/* SSD model must have either 3 or 4 output tensor nodes: 4 if there is a label node,
* and only 3 if there is no label */
if (numTensors != 3 && numTensors != 4)
continue;
gint boxesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES));
gint scoresIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES));
gint numDetectionsIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
continue;
if (numTensors == 4 && clasesIndex == GST_TENSOR_MISSING_ID)
continue;
tmeta = tensor_meta;
break;
}
cleanup:
g_list_free (tensor_metas);
return tmeta;
}
static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)
{
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
if (!gst_video_info_from_caps (&self->video_info, incaps)) {
GST_ERROR_OBJECT (self, "Failed to parse caps");
return FALSE;
}
return TRUE;
}
static GstFlowReturn
gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
{
if (!gst_base_transform_is_passthrough (trans)) {
if (!gst_ssd_object_detector_process (trans, buf)) {
GST_ELEMENT_ERROR (trans, STREAM, FAILED,
(NULL), ("ssd object detection failed"));
return GST_FLOW_ERROR;
}
}
return GST_FLOW_OK;
}
static gboolean
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstTensorMeta *tmeta = NULL;
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
// get all tensor metas
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
if (!tmeta) {
GST_WARNING_OBJECT (trans, "missing tensor meta");
return TRUE;
}
std::vector < GstMlBoundingBox > boxes =
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
self->video_info.height, tmeta, self->label_file ? self->label_file : "",
self->score_threshold);
for (auto & b:boxes) {
auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
GST_SSD_OBJECT_DETECTOR_META_NAME,
b.x0, b.y0,
b.width,
b.height);
if (!vroi_meta) {
GST_WARNING_OBJECT (trans,
"Unable to attach GstVideoRegionOfInterestMeta to buffer");
return FALSE;
}
auto s = gst_structure_new (GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME,
GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL,
G_TYPE_STRING,
b.label.c_str (),
GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE,
G_TYPE_DOUBLE,
b.score,
NULL);
gst_video_region_of_interest_meta_add_param (vroi_meta, s);
GST_DEBUG_OBJECT (self,
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
}
return TRUE;
}

View file

@ -0,0 +1,78 @@
/*
* GStreamer gstreamer-ssdobjectdetector
* Copyright (C) 2021 Collabora Ltd
*
* gstssdobjectdetector.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_SSD_OBJECT_DETECTOR_H__
#define __GST_SSD_OBJECT_DETECTOR_H__
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideofilter.h>
G_BEGIN_DECLS
#define GST_TYPE_SSD_OBJECT_DETECTOR (gst_ssd_object_detector_get_type())
G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OBJECT_DETECTOR, GstBaseTransform)
#define GST_SSD_OBJECT_DETECTOR_META_NAME "ssd-object-detector"
#define GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME "extra-data"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
/**
* GstSsdObjectDetector:
*
* @label_file label file
* @score_threshold score threshold
* @confidence_threshold confidence threshold
* @iou_threhsold iou threshold
* @od_ptr opaque pointer to GstOd object detection implementation
*
* Since: 1.20
*/
struct _GstSsdObjectDetector
{
GstBaseTransform basetransform;
gchar *label_file;
gfloat score_threshold;
gfloat confidence_threshold;
gfloat iou_threshold;
gpointer odutils;
GstVideoInfo video_info;
};
/**
* GstSsdObjectDetectorClass:
*
* @parent_class base transform base class
*
* Since: 1.20
*/
struct _GstSsdObjectDetectorClass
{
GstBaseTransformClass parent_class;
};
GST_ELEMENT_REGISTER_DECLARE (ssd_object_detector)
G_END_DECLS
#endif /* __GST_SSD_OBJECT_DETECTOR_H__ */

View file

@ -0,0 +1,41 @@
/*
* GStreamer gstreamer-ml
* Copyright (C) 2021 Collabora Ltd
*
* gstml.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ML_H__
#define __GST_ML_H__
/**
* GstMlInputImageFormat:
*
* @GST_ML_INPUT_IMAGE_FORMAT_HWC Height Width Channel (a.k.a. interleaved) format
* @GST_ML_INPUT_IMAGE_FORMAT_CHW Channel Height Width (a.k.a. planar) format
*
* Since: 1.20
*/
typedef enum {
GST_ML_INPUT_IMAGE_FORMAT_HWC,
GST_ML_INPUT_IMAGE_FORMAT_CHW,
} GstMlInputImageFormat;
#endif

View file

@ -1,3 +1,4 @@
/* /*
* GStreamer gstreamer-onnx * GStreamer gstreamer-onnx
* Copyright (C) 2021 Collabora Ltd * Copyright (C) 2021 Collabora Ltd
@ -23,14 +24,17 @@
#include "config.h" #include "config.h"
#endif #endif
#include "gstonnxobjectdetector.h" #include "decoders/gstssdobjectdetector.h"
#include "gstonnxinference.h"
#include "tensor/gsttensormeta.h"
static gboolean static gboolean
plugin_init (GstPlugin * plugin) plugin_init (GstPlugin * plugin)
{ {
GST_ELEMENT_REGISTER (onnx_object_detector, plugin); gboolean success = GST_ELEMENT_REGISTER (ssd_object_detector, plugin);
success |= GST_ELEMENT_REGISTER (onnx_inference, plugin);
return TRUE; return success;
} }
GST_PLUGIN_DEFINE (GST_VERSION_MAJOR, GST_PLUGIN_DEFINE (GST_VERSION_MAJOR,

View file

@ -21,23 +21,15 @@
*/ */
#include "gstonnxclient.h" #include "gstonnxclient.h"
#include <tensor/gsttensorid.h>
#include <providers/cpu/cpu_provider_factory.h> #include <providers/cpu/cpu_provider_factory.h>
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
#include <providers/cuda/cuda_provider_factory.h>
#endif
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <cmath>
#include <sstream> #include <sstream>
namespace GstOnnxNamespace namespace GstOnnxNamespace
{ {
template < typename T > template < typename T >
std::ostream & operator<< (std::ostream & os, const std::vector < T > &v) std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
{ {
os << "["; os << "[";
for (size_t i = 0; i < v.size (); ++i) for (size_t i = 0; i < v.size (); ++i)
{ {
@ -50,13 +42,7 @@ template < typename T >
os << "]"; os << "]";
return os; return os;
} }
GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
(GST_ML_NODE_INDEX_DISABLED),
type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
{
}
GstOnnxClient::GstOnnxClient ():session (nullptr), GstOnnxClient::GstOnnxClient ():session (nullptr),
width (0), width (0),
@ -64,123 +50,59 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
channels (0), channels (0),
dest (nullptr), dest (nullptr),
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC), inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
fixedInputImageSize (true) fixedInputImageSize (false) {
{ }
for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
}
GstOnnxClient::~GstOnnxClient () GstOnnxClient::~GstOnnxClient () {
{
outputNames.clear();
delete session; delete session;
delete[]dest; delete[]dest;
} }
Ort::Env & GstOnnxClient::getEnv (void) int32_t GstOnnxClient::getWidth (void)
{ {
static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
"GstOnnxNamespace");
return env;
}
int32_t GstOnnxClient::getWidth (void)
{
return width; return width;
} }
int32_t GstOnnxClient::getHeight (void) int32_t GstOnnxClient::getHeight (void)
{ {
return height; return height;
} }
bool GstOnnxClient::isFixedInputImageSize (void) bool GstOnnxClient::isFixedInputImageSize (void)
{ {
return fixedInputImageSize; return fixedInputImageSize;
} }
std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType) void GstOnnxClient::setInputImageFormat (GstMlInputImageFormat format)
{ {
switch (nodeType) {
case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
return "detection";
break;
case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
return "bounding box";
break;
case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
return "score";
break;
case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
return "label";
break;
case GST_ML_OUTPUT_NODE_NUMBER_OF:
g_assert_not_reached();
GST_WARNING("Invalid parameter");
break;
};
return "";
}
void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
{
inputImageFormat = format; inputImageFormat = format;
} }
GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void) GstMlInputImageFormat GstOnnxClient::getInputImageFormat (void)
{ {
return inputImageFormat; return inputImageFormat;
} }
std::vector< const char *> GstOnnxClient::getOutputNodeNames (void) std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void)
{ {
if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) { if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) {
outputNamesRaw.resize(outputNames.size()); outputNamesRaw.resize (outputNames.size ());
for (size_t i = 0; i < outputNamesRaw.size(); i++) { for (size_t i = 0; i < outputNamesRaw.size (); i++)
outputNamesRaw[i] = outputNames[i].get(); outputNamesRaw[i] = outputNames[i].get ();
}
} }
return outputNamesRaw; return outputNamesRaw;
} }
void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node, bool GstOnnxClient::hasSession (void)
gint index) {
{
g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
outputNodeInfo[node].index = index;
if (index != GST_ML_NODE_INDEX_DISABLED)
outputNodeIndexToFunction[index] = node;
}
gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
{
return outputNodeInfo[node].index;
}
void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
ONNXTensorElementDataType type)
{
outputNodeInfo[node].type = type;
}
ONNXTensorElementDataType
GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
{
return outputNodeInfo[node].type;
}
bool GstOnnxClient::hasSession (void)
{
return session != nullptr; return session != nullptr;
} }
bool GstOnnxClient::createSession (std::string modelFile, bool GstOnnxClient::createSession (std::string modelFile,
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider) GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
{ {
if (session) if (session)
return true; return true;
@ -205,30 +127,43 @@ bool GstOnnxClient::createSession (std::string modelFile,
try { try {
Ort::SessionOptions sessionOptions; Ort::SessionOptions sessionOptions;
const auto & api = Ort::GetApi ();
// for debugging // for debugging
//sessionOptions.SetIntraOpNumThreads (1); //sessionOptions.SetIntraOpNumThreads (1);
sessionOptions.SetGraphOptimizationLevel (onnx_optim); sessionOptions.SetGraphOptimizationLevel (onnx_optim);
m_provider = provider; m_provider = provider;
switch (m_provider) { switch (m_provider) {
case GST_ONNX_EXECUTION_PROVIDER_CUDA: case GST_ONNX_EXECUTION_PROVIDER_CUDA:
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA try {
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA OrtCUDAProviderOptionsV2 *cuda_options = nullptr;
(sessionOptions, 0)); Ort::ThrowOnError (api.CreateCUDAProviderOptions (&cuda_options));
#else std::unique_ptr < OrtCUDAProviderOptionsV2,
GST_ERROR ("ONNX CUDA execution provider not supported"); decltype (api.ReleaseCUDAProviderOptions) >
return false; rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
#endif Ort::ThrowOnError (api.SessionOptionsAppendExecutionProvider_CUDA_V2
(static_cast < OrtSessionOptions * >(sessionOptions),
rel_cuda_options.get ()));
}
catch (Ort::Exception & ortex) {
GST_WARNING
("Failed to create CUDA provider - dropping back to CPU");
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
(sessionOptions, 1));
}
break; break;
default: default:
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
(sessionOptions, 1));
break; break;
}; };
session = env =
new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions); Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
"GstOnnxNamespace");
session = new Ort::Session (env, modelFile.c_str (), sessionOptions);
auto inputTypeInfo = session->GetInputTypeInfo (0); auto inputTypeInfo = session->GetInputTypeInfo (0);
std::vector < int64_t > inputDims = std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
height = inputDims[1]; height = inputDims[1];
width = inputDims[2]; width = inputDims[2];
channels = inputDims[3]; channels = inputDims[3];
@ -250,14 +185,37 @@ bool GstOnnxClient::createSession (std::string modelFile,
auto output_name = session->GetOutputNameAllocated (i, allocator); auto output_name = session->GetOutputNameAllocated (i, allocator);
GST_DEBUG ("Output name %lu:%s", i, output_name.get ()); GST_DEBUG ("Output name %lu:%s", i, output_name.get ());
outputNames.push_back (std::move (output_name)); outputNames.push_back (std::move (output_name));
auto type_info = session->GetOutputTypeInfo (i); }
auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); genOutputNamesRaw ();
if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) { // look up tensor ids
auto function = outputNodeIndexToFunction[i]; auto metaData = session->GetModelMetadata ();
outputNodeInfo[function].type = tensor_info.GetElementType (); OrtAllocator *ortAllocator;
auto status =
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
if (status) {
// Handle the error case
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
GST_WARNING ("Failed to get allocator: %s", errorString);
// Clean up the error status
Ort::GetApi ().ReleaseStatus (status);
return false;
} else {
for (auto & name:outputNamesRaw) {
Ort::AllocatedStringPtr res =
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
if (res) {
GQuark quark = gst_tensorid_get_quark (res.get ());
outputIds.push_back (quark);
} else {
GST_ERROR ("Failed to look up id for key %s", name);
return false;
} }
} }
}
} }
catch (Ort::Exception & ortex) { catch (Ort::Exception & ortex) {
GST_ERROR ("%s", ortex.what ()); GST_ERROR ("%s", ortex.what ());
@ -265,40 +223,110 @@ bool GstOnnxClient::createSession (std::string modelFile,
} }
return true; return true;
} }
std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data, void GstOnnxClient::parseDimensions (GstVideoInfo vinfo)
GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold) {
{ int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS); int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
return (type ==
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
: doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
}
void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
{
int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
int32_t newHeight = fixedInputImageSize ? height : vmeta->height;
if (!dest || width * height < newWidth * newHeight) { if (!dest || width * height < newWidth * newHeight) {
delete[] dest; delete[]dest;
dest = new uint8_t[newWidth * newHeight * channels]; dest = new uint8_t[newWidth * newHeight * channels];
} }
width = newWidth; width = newWidth;
height = newHeight; height = newHeight;
} }
template < typename T > std::vector < GstMlBoundingBox > // copy tensor data to a GstTensorMeta
GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta, GstTensorMeta *GstOnnxClient::copy_tensors_to_meta (std::vector < Ort::Value >
std::string labelPath, float scoreThreshold) &outputs, GstBuffer * buffer)
{ {
std::vector < GstMlBoundingBox > boundingBoxes; size_t num_tensors = outputNamesRaw.size ();
GstTensorMeta *tmeta = (GstTensorMeta *) gst_buffer_add_meta (buffer,
gst_tensor_meta_get_info (),
NULL);
tmeta->num_tensors = num_tensors;
tmeta->tensor = (GstTensor *) g_malloc (num_tensors * sizeof (GstTensor));
bool hasIds = outputIds.size () == num_tensors;
for (size_t i = 0; i < num_tensors; i++) {
Ort::Value outputTensor = std::move (outputs[i]);
ONNXTensorElementDataType tensorType =
outputTensor.GetTensorTypeAndShapeInfo ().GetElementType ();
GstTensor *tensor = &tmeta->tensor[i];
if (hasIds)
tensor->id = outputIds[i];
tensor->data = gst_buffer_new ();
auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape ();
tensor->num_dims = tensorShape.size ();
tensor->dims = g_new (int64_t, tensor->num_dims);
for (size_t j = 0; j < tensorShape.size (); ++j) {
tensor->dims[j] = tensorShape[j];
}
size_t numElements =
outputTensor.GetTensorTypeAndShapeInfo ().GetElementCount ();
size_t buffer_size = 0;
guint8 *buffer_data = NULL;
if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
buffer_size = numElements * sizeof (float);
// Allocate memory for the buffer data
buffer_data = (guint8 *) malloc (buffer_size);
if (buffer_data == NULL) {
GST_ERROR ("Failed to allocate memory");
return NULL;
}
// Copy the data from the source buffer to the allocated memory
memcpy (buffer_data, outputTensor.GetTensorData < float >(),
buffer_size);
tensor->type = GST_TENSOR_TYPE_FLOAT32;
} else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
buffer_size = numElements * sizeof (int);
// Allocate memory for the buffer data
guint8 *buffer_data = (guint8 *) malloc (buffer_size);
if (buffer_data == NULL) {
GST_ERROR ("Failed to allocate memory");
return NULL;
}
// Copy the data from the source buffer to the allocated memory
memcpy (buffer_data, outputTensor.GetTensorData < int >(),
buffer_size);
tensor->type = GST_TENSOR_TYPE_INT32;
}
if (buffer_data) {
// Create a GstMemory object from the allocated memory
GstMemory *memory = gst_memory_new_wrapped ((GstMemoryFlags) 0,
buffer_data, buffer_size, 0, buffer_size, NULL, NULL);
// Append the GstMemory object to the GstBuffer
gst_buffer_append_memory (tmeta->tensor[i].data, memory);
}
}
return tmeta;
}
std::vector < Ort::Value > GstOnnxClient::run (uint8_t * img_data,
GstVideoInfo vinfo) {
std::vector < Ort::Value > modelOutput;
doRun (img_data, vinfo, modelOutput);
return modelOutput;
}
bool GstOnnxClient::doRun (uint8_t * img_data, GstVideoInfo vinfo,
std::vector < Ort::Value > &modelOutput)
{
if (!img_data) if (!img_data)
return boundingBoxes; return false;
parseDimensions (vmeta);
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto inputName = session->GetInputNameAllocated (0, allocator); auto inputName = session->GetInputNameAllocated (0, allocator);
@ -306,7 +334,7 @@ template < typename T > std::vector < GstMlBoundingBox >
std::vector < int64_t > inputDims = std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
inputDims[0] = 1; inputDims[0] = 1;
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
inputDims[1] = height; inputDims[1] = height;
inputDims[2] = width; inputDims[2] = width;
} else { } else {
@ -321,7 +349,7 @@ template < typename T > std::vector < GstMlBoundingBox >
// copy video frame // copy video frame
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 }; uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
uint32_t srcSamplesPerPixel = 3; uint32_t srcSamplesPerPixel = 3;
switch (vmeta->format) { switch (vinfo.finfo->format) {
case GST_VIDEO_FORMAT_RGBA: case GST_VIDEO_FORMAT_RGBA:
srcSamplesPerPixel = 4; srcSamplesPerPixel = 4;
break; break;
@ -352,8 +380,8 @@ template < typename T > std::vector < GstMlBoundingBox >
break; break;
} }
size_t destIndex = 0; size_t destIndex = 0;
uint32_t stride = vmeta->stride[0]; uint32_t stride = vinfo.stride[0];
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
for (int32_t j = 0; j < height; ++j) { for (int32_t j = 0; j < height; ++j) {
for (int32_t i = 0; i < width; ++i) { for (int32_t i = 0; i < width; ++i) {
for (int32_t k = 0; k < channels; ++k) { for (int32_t k = 0; k < channels; ++k) {
@ -389,58 +417,17 @@ template < typename T > std::vector < GstMlBoundingBox >
std::vector < Ort::Value > inputTensors; std::vector < Ort::Value > inputTensors;
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo, inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
dest, inputTensorSize, inputDims.data (), inputDims.size ())); dest, inputTensorSize, inputDims.data (), inputDims.size ()));
std::vector < const char *>inputNames { inputName.get () }; std::vector < const char *>inputNames
{
inputName.get ()};
std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr}, modelOutput = session->Run (Ort::RunOptions {
nullptr},
inputNames.data (), inputNames.data (),
inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ()); inputTensors.data (), 1, outputNamesRaw.data (),
outputNamesRaw.size ());
auto numDetections = return true;
modelOutput[getOutputNodeIndex
(GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
auto bboxes =
modelOutput[getOutputNodeIndex
(GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
auto scores =
modelOutput[getOutputNodeIndex
(GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
T *labelIndex = nullptr;
if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
GST_ML_NODE_INDEX_DISABLED) {
labelIndex =
modelOutput[getOutputNodeIndex
(GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
}
if (labels.empty () && !labelPath.empty ())
labels = ReadLabels (labelPath);
for (int i = 0; i < numDetections[0]; ++i) {
if (scores[i] > scoreThreshold) {
std::string label = "";
if (labelIndex && !labels.empty ())
label = labels[labelIndex[i] - 1];
auto score = scores[i];
auto y0 = bboxes[i * 4] * height;
auto x0 = bboxes[i * 4 + 1] * width;
auto bheight = bboxes[i * 4 + 2] * height - y0;
auto bwidth = bboxes[i * 4 + 3] * width - x0;
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
bheight));
}
}
return boundingBoxes;
}
std::vector < std::string >
GstOnnxClient::ReadLabels (const std::string & labelsFile)
{
std::vector < std::string > labels;
std::string line;
std::ifstream fp (labelsFile);
while (std::getline (fp, line))
labels.push_back (line);
return labels;
} }
} }

View file

@ -25,45 +25,11 @@
#include <gst/gst.h> #include <gst/gst.h>
#include <onnxruntime_cxx_api.h> #include <onnxruntime_cxx_api.h>
#include <gst/video/video.h> #include <gst/video/video.h>
#include "gstonnxelement.h" #include "gstml.h"
#include <string> #include "gstonnxenums.h"
#include <vector> #include "tensor/gsttensormeta.h"
namespace GstOnnxNamespace { namespace GstOnnxNamespace {
enum GstMlOutputNodeFunction {
GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
GST_ML_OUTPUT_NODE_NUMBER_OF,
};
const gint GST_ML_NODE_INDEX_DISABLED = -1;
struct GstMlOutputNodeInfo {
GstMlOutputNodeInfo(void);
gint index;
ONNXTensorElementDataType type;
};
struct GstMlBoundingBox {
GstMlBoundingBox(std::string lbl,
float score,
float _x0,
float _y0,
float _width,
float _height):label(lbl),
score(score), x0(_x0), y0(_y0), width(_width), height(_height) {
}
GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) {
}
std::string label;
float score;
float x0;
float y0;
float width;
float height;
};
class GstOnnxClient { class GstOnnxClient {
public: public:
@ -72,30 +38,18 @@ namespace GstOnnxNamespace {
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim, bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
GstOnnxExecutionProvider provider); GstOnnxExecutionProvider provider);
bool hasSession(void); bool hasSession(void);
void setInputImageFormat(GstMlModelInputImageFormat format); void setInputImageFormat(GstMlInputImageFormat format);
GstMlModelInputImageFormat getInputImageFormat(void); GstMlInputImageFormat getInputImageFormat(void);
void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index); std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType); std::vector < const char *> genOutputNamesRaw(void);
void setOutputNodeType(GstMlOutputNodeFunction nodeType,
ONNXTensorElementDataType type);
ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type);
std::string getOutputNodeName(GstMlOutputNodeFunction nodeType);
std::vector < GstMlBoundingBox > run(uint8_t * img_data,
GstVideoMeta * vmeta,
std::string labelPath,
float scoreThreshold);
std::vector < GstMlBoundingBox > &getBoundingBoxes(void);
std::vector < const char *>getOutputNodeNames(void);
bool isFixedInputImageSize(void); bool isFixedInputImageSize(void);
int32_t getWidth(void); int32_t getWidth(void);
int32_t getHeight(void); int32_t getHeight(void);
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
void parseDimensions(GstVideoInfo vinfo);
private: private:
void parseDimensions(GstVideoMeta * vmeta); bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
template < typename T > std::vector < GstMlBoundingBox > Ort::Env env;
doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath,
float scoreThreshold);
std::vector < std::string > ReadLabels(const std::string & labelsFile);
Ort::Env & getEnv(void);
Ort::Session * session; Ort::Session * session;
int32_t width; int32_t width;
int32_t height; int32_t height;
@ -104,13 +58,10 @@ namespace GstOnnxNamespace {
GstOnnxExecutionProvider m_provider; GstOnnxExecutionProvider m_provider;
std::vector < Ort::Value > modelOutput; std::vector < Ort::Value > modelOutput;
std::vector < std::string > labels; std::vector < std::string > labels;
// !! indexed by function
GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
// !! indexed by array index
size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
std::vector < const char *> outputNamesRaw; std::vector < const char *> outputNamesRaw;
std::vector < Ort::AllocatedStringPtr > outputNames; std::vector < Ort::AllocatedStringPtr > outputNames;
GstMlModelInputImageFormat inputImageFormat; std::vector < GQuark > outputIds;
GstMlInputImageFormat inputImageFormat;
bool fixedInputImageSize; bool fixedInputImageSize;
}; };
} }

View file

@ -1,104 +0,0 @@
/*
* GStreamer gstreamer-onnxelement
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxelement.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstonnxelement.h"
GType
gst_onnx_optimization_level_get_type (void)
{
static GType onnx_optimization_type = 0;
if (g_once_init_enter (&onnx_optimization_type)) {
static GEnumValue optimization_level_types[] = {
{GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization",
"disable-all"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
"Enable basic optimizations (redundant node removals))",
"enable-basic"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
"Enable extended optimizations (redundant node removals + node fusions)",
"enable-extended"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
"Enable all possible optimizations", "enable-all"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstOnnxOptimizationLevel",
optimization_level_types);
g_once_init_leave (&onnx_optimization_type, temp);
}
return onnx_optimization_type;
}
GType
gst_onnx_execution_provider_get_type (void)
{
static GType onnx_execution_type = 0;
if (g_once_init_enter (&onnx_execution_type)) {
static GEnumValue execution_provider_types[] = {
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
"cpu"},
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
"CUDA execution provider",
"cuda"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstOnnxExecutionProvider",
execution_provider_types);
g_once_init_leave (&onnx_execution_type, temp);
}
return onnx_execution_type;
}
GType
gst_ml_model_input_image_format_get_type (void)
{
static GType ml_model_input_image_format = 0;
if (g_once_init_enter (&ml_model_input_image_format)) {
static GEnumValue ml_model_input_image_format_types[] = {
{GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
"Height Width Channel (HWC) a.k.a. interleaved image data format",
"hwc"},
{GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
"Channel Height Width (CHW) a.k.a. planar image data format",
"chw"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstMlModelInputImageFormat",
ml_model_input_image_format_types);
g_once_init_leave (&ml_model_input_image_format, temp);
}
return ml_model_input_image_format;
}

View file

@ -1,8 +1,8 @@
/* /*
* GStreamer gstreamer-onnxelement * GStreamer gstreamer-onnxenums
* Copyright (C) 2021 Collabora Ltd * Copyright (C) 2021 Collabora Ltd
* *
* gstonnxelement.h * gstonnxenums.h
* *
* This library is free software; you can redistribute it and/or * This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public * modify it under the terms of the GNU Library General Public
@ -20,10 +20,8 @@
* Boston, MA 02110-1301, USA. * Boston, MA 02110-1301, USA.
*/ */
#ifndef __GST_ONNX_ELEMENT_H__ #ifndef __GST_ONNX_ENUMS_H__
#define __GST_ONNX_ELEMENT_H__ #define __GST_ONNX_ENUMS_H__
#include <gst/gst.h>
typedef enum typedef enum
{ {
@ -39,26 +37,5 @@ typedef enum
GST_ONNX_EXECUTION_PROVIDER_CUDA, GST_ONNX_EXECUTION_PROVIDER_CUDA,
} GstOnnxExecutionProvider; } GstOnnxExecutionProvider;
typedef enum {
/* Height Width Channel (a.k.a. interleaved) format */
GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
/* Channel Height Width (a.k.a. planar) format */ #endif /* __GST_ONNX_ENUMS_H__ */
GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
} GstMlModelInputImageFormat;
G_BEGIN_DECLS
GType gst_onnx_optimization_level_get_type (void);
#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ())
GType gst_onnx_execution_provider_get_type (void);
#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ())
GType gst_ml_model_input_image_format_get_type (void);
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
G_END_DECLS
#endif

View file

@ -0,0 +1,539 @@
/*
* GStreamer gstreamer-onnxinference
* Copyright (C) 2023 Collabora Ltd.
*
* gstonnxinference.cpp
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-onnxinference
* @short_description: Run ONNX inference model on video buffers
*
* This element can apply an ONNX model to video buffers. It attaches
* the tensor output to the buffer as a @ref GstTensorMeta.
*
* To install ONNX on your system, recursively clone this repository
* https://github.com/microsoft/onnxruntime.git
*
* and build and install with cmake:
*
* CPU:
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
* $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install
*
*
* CUDA :
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
* -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
* $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install
*
*
* where :
*
* 1. $SRC_DIR and $BUILD_DIR are local source and build directories
* 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
* $CUDA_PATH is an environment variable set to the CUDA root path.
* On Linux, it would be /usr/local/cuda
*
*
* ## Example launch command:
*
* GST_DEBUG=onnxinference:5 gst-launch-1.0 multifilesrc location=bus.jpg ! \
* jpegdec ! videoconvert ! \
* onnxinference execution-provider=cpu model-file=model.onnx \
* videoconvert ! autovideosink
*
*
* Note: in order for downstream tensor decoders to correctly parse the tensor
* data in the GstTensorMeta, meta data must be attached to the ONNX model
* assigning a unique string id to each output layer. These unique string ids
* and corresponding GQuark ids are currently stored in the ONNX plugin source
* in the file 'gsttensorid.h'. For an output layer with name Foo and with context
* unique string id Gst.Model.Bar, a meta data key/value pair must be added
* to the ONNX model with "Foo" mapped to "Gst.Model.Bar" in order for a downstream
* decoder to make use of this model. If the meta data is absent, the pipeline will
* fail.
*
* As a convenience, there is a python script
* currently stored at
* https://gitlab.collabora.com/gstreamer/onnx-models/-/blob/master/scripts/modify_onnx_metadata.py
* to enable users to easily add and remove meta data from json files. It can also dump
* the names of all output layers, which can then be used to craft the json meta data file.
*
*
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include <gst/gst.h>
#include "gstonnxinference.h"
#include "gstonnxclient.h"
GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug);
#define GST_CAT_DEFAULT onnx_inference_debug
#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
GST_RANK_PRIMARY, GST_TYPE_ONNX_INFERENCE);
/* GstOnnxInference properties */
enum
{
PROP_0,
PROP_MODEL_FILE,
PROP_INPUT_IMAGE_FORMAT,
PROP_OPTIMIZATION_LEVEL,
PROP_EXECUTION_PROVIDER
};
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
#define GST_ONNX_INFERENCE_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
static GstStaticPadTemplate gst_onnx_inference_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
);
static GstStaticPadTemplate gst_onnx_inference_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
);
static void gst_onnx_inference_set_property (GObject * object,
guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_onnx_inference_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_onnx_inference_finalize (GObject * object);
static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
GstBuffer * buf);
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
GType gst_onnx_optimization_level_get_type (void);
#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ())
GType gst_onnx_execution_provider_get_type (void);
#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ())
GType gst_ml_model_input_image_format_get_type (void);
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
GType
gst_onnx_optimization_level_get_type (void)
{
static GType onnx_optimization_type = 0;
if (g_once_init_enter (&onnx_optimization_type)) {
static GEnumValue optimization_level_types[] = {
{GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization",
"disable-all"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
"Enable basic optimizations (redundant node removals))",
"enable-basic"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
"Enable extended optimizations (redundant node removals + node fusions)",
"enable-extended"},
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
"Enable all possible optimizations", "enable-all"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstOnnxOptimizationLevel",
optimization_level_types);
g_once_init_leave (&onnx_optimization_type, temp);
}
return onnx_optimization_type;
}
GType
gst_onnx_execution_provider_get_type (void)
{
static GType onnx_execution_type = 0;
if (g_once_init_enter (&onnx_execution_type)) {
static GEnumValue execution_provider_types[] = {
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
"cpu"},
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
"CUDA execution provider",
"cuda"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstOnnxExecutionProvider",
execution_provider_types);
g_once_init_leave (&onnx_execution_type, temp);
}
return onnx_execution_type;
}
GType
gst_ml_model_input_image_format_get_type (void)
{
static GType ml_model_input_image_format = 0;
if (g_once_init_enter (&ml_model_input_image_format)) {
static GEnumValue ml_model_input_image_format_types[] = {
{GST_ML_INPUT_IMAGE_FORMAT_HWC,
"Height Width Channel (HWC) a.k.a. interleaved image data format",
"hwc"},
{GST_ML_INPUT_IMAGE_FORMAT_CHW,
"Channel Height Width (CHW) a.k.a. planar image data format",
"chw"},
{0, NULL, NULL},
};
GType temp = g_enum_register_static ("GstMlInputImageFormat",
ml_model_input_image_format_types);
g_once_init_leave (&ml_model_input_image_format, temp);
}
return ml_model_input_image_format;
}
static void
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
{
GObjectClass *gobject_class = (GObjectClass *) klass;
GstElementClass *element_class = (GstElementClass *) klass;
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
GST_DEBUG_CATEGORY_INIT (onnx_inference_debug, "onnxinference",
0, "onnx_inference");
gobject_class->set_property = gst_onnx_inference_set_property;
gobject_class->get_property = gst_onnx_inference_get_property;
gobject_class->finalize = gst_onnx_inference_finalize;
/**
* GstOnnxInference:model-file
*
* ONNX model file
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
g_param_spec_string ("model-file",
"ONNX model file", "ONNX model file", NULL, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxInference:input-image-format
*
* Model input image format
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_INPUT_IMAGE_FORMAT,
g_param_spec_enum ("input-image-format",
"Input image format",
"Input image format",
GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
GST_ML_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxInference:optimization-level
*
* ONNX optimization level
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_OPTIMIZATION_LEVEL,
g_param_spec_enum ("optimization-level",
"Optimization level",
"ONNX optimization level",
GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxInference:execution-provider
*
* ONNX execution provider
*
* Since: 1.24
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_EXECUTION_PROVIDER,
g_param_spec_enum ("execution-provider",
"Execution provider",
"ONNX execution provider",
GST_TYPE_ONNX_EXECUTION_PROVIDER,
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "onnxinference",
"Filter/Effect/Video",
"Apply neural network to video frames and create tensor output",
"Aaron Boxer <aaron.boxer@collabora.com>");
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_onnx_inference_sink_template));
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_onnx_inference_src_template));
basetransform_class->transform_ip =
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_ip);
basetransform_class->transform_caps =
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
basetransform_class->set_caps =
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
}
static void
gst_onnx_inference_init (GstOnnxInference * self)
{
self->onnx_client = new GstOnnxNamespace::GstOnnxClient ();
self->onnx_disabled = TRUE;
}
static void
gst_onnx_inference_finalize (GObject * object)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
g_free (self->model_file);
delete GST_ONNX_CLIENT_MEMBER (self);
G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object);
}
static void
gst_onnx_inference_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
const gchar *filename;
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
switch (prop_id) {
case PROP_MODEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
if (self->model_file)
g_free (self->model_file);
self->model_file = g_strdup (filename);
self->onnx_disabled = FALSE;
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
}
break;
case PROP_OPTIMIZATION_LEVEL:
self->optimization_level =
(GstOnnxOptimizationLevel) g_value_get_enum (value);
break;
case PROP_EXECUTION_PROVIDER:
self->execution_provider =
(GstOnnxExecutionProvider) g_value_get_enum (value);
break;
case PROP_INPUT_IMAGE_FORMAT:
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
g_value_get_enum (value));
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static void
gst_onnx_inference_get_property (GObject * object, guint prop_id,
GValue * value, GParamSpec * pspec)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
switch (prop_id) {
case PROP_MODEL_FILE:
g_value_set_string (value, self->model_file);
break;
case PROP_OPTIMIZATION_LEVEL:
g_value_set_enum (value, self->optimization_level);
break;
case PROP_EXECUTION_PROVIDER:
g_value_set_enum (value, self->execution_provider);
break;
case PROP_INPUT_IMAGE_FORMAT:
g_value_set_enum (value, onnxClient->getInputImageFormat ());
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static gboolean
gst_onnx_inference_create_session (GstBaseTransform * trans)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GST_OBJECT_LOCK (self);
if (self->onnx_disabled) {
GST_OBJECT_UNLOCK (self);
return FALSE;
}
if (onnxClient->hasSession ()) {
GST_OBJECT_UNLOCK (self);
return TRUE;
}
if (self->model_file) {
gboolean ret =
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
self->optimization_level,
self->execution_provider);
if (!ret) {
GST_ERROR_OBJECT (self,
"Unable to create ONNX session. Model is disabled.");
self->onnx_disabled = TRUE;
}
} else {
self->onnx_disabled = TRUE;
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
}
GST_OBJECT_UNLOCK (self);
if (self->onnx_disabled) {
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
}
return TRUE;
}
static GstCaps *
gst_onnx_inference_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GstCaps *other_caps;
guint i;
if (!gst_onnx_inference_create_session (trans))
return NULL;
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
if (gst_base_transform_is_passthrough (trans)
|| (!onnxClient->isFixedInputImageSize ()))
return gst_caps_ref (caps);
other_caps = gst_caps_new_empty ();
for (i = 0; i < gst_caps_get_size (caps); ++i) {
GstStructure *structure, *new_structure;
structure = gst_caps_get_structure (caps, i);
new_structure = gst_structure_copy (structure);
gst_structure_set (new_structure, "width", G_TYPE_INT,
onnxClient->getWidth (), "height", G_TYPE_INT,
onnxClient->getHeight (), NULL);
GST_LOG_OBJECT (self,
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
GST_PTR_FORMAT, i, structure, new_structure);
gst_caps_append_structure (other_caps, new_structure);
}
if (!gst_caps_is_empty (other_caps) && filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
GST_CAPS_INTERSECT_FIRST);
gst_caps_replace (&other_caps, tmp);
gst_caps_unref (tmp);
}
return other_caps;
}
static gboolean
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps)
{
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
if (!gst_video_info_from_caps (&self->video_info, incaps)) {
GST_ERROR_OBJECT (self, "Failed to parse caps");
return FALSE;
}
onnxClient->parseDimensions (self->video_info);
return TRUE;
}
static GstFlowReturn
gst_onnx_inference_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
{
if (!gst_base_transform_is_passthrough (trans)
&& !gst_onnx_inference_process (trans, buf)) {
GST_ELEMENT_ERROR (trans, STREAM, FAILED,
(NULL), ("ONNX inference failed"));
return GST_FLOW_ERROR;
}
return GST_FLOW_OK;
}
static gboolean
gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstMapInfo info;
GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
if (!vmeta) {
GST_WARNING_OBJECT (trans, "missing video meta");
return FALSE;
}
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
try {
auto client = GST_ONNX_CLIENT_MEMBER (self);
auto outputs = client->run (info.data, self->video_info);
auto meta = client->copy_tensors_to_meta (outputs, buf);
if (!meta)
return FALSE;
meta->batch_size = 1;
}
catch (Ort::Exception & ortex) {
GST_ERROR_OBJECT (self, "%s", ortex.what ());
return FALSE;
}
gst_buffer_unmap (buf, &info);
}
return TRUE;
}

View file

@ -0,0 +1,64 @@
/*
* GStreamer gstreamer-onnxinference
* Copyright (C) 2023 Collabora Ltd
*
* gstonnxinference.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ONNX_INFERENCE_H__
#define __GST_ONNX_INFERENCE_H__
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideofilter.h>
#include "gstonnxenums.h"
G_BEGIN_DECLS
#define GST_TYPE_ONNX_INFERENCE (gst_onnx_inference_get_type())
G_DECLARE_FINAL_TYPE (GstOnnxInference, gst_onnx_inference, GST,
ONNX_INFERENCE, GstBaseTransform)
/**
* GstOnnxInference:
*
* @model_file model file
* @optimization_level: ONNX session optimization level
* @execution_provider: ONNX execution provider
* @onnx_client opaque pointer to ONNX client
* @onnx_disabled true if inference is disabled
* @video_info @ref GstVideoInfo of sink caps
*
* Since: 1.24
*/
struct _GstOnnxInference
{
GstBaseTransform basetransform;
gchar *model_file;
GstOnnxOptimizationLevel optimization_level;
GstOnnxExecutionProvider execution_provider;
gpointer onnx_client;
gboolean onnx_disabled;
GstVideoInfo video_info;
};
GST_ELEMENT_REGISTER_DECLARE (onnx_inference)
G_END_DECLS
#endif /* __GST_ONNX_INFERENCE_H__ */

View file

@ -1,684 +0,0 @@
/*
* GStreamer gstreamer-onnxobjectdetector
* Copyright (C) 2021 Collabora Ltd.
*
* gstonnxobjectdetector.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
/**
* SECTION:element-onnxobjectdetector
* @short_description: Detect objects in video frame
*
* This element can apply a generic ONNX object detection model such as YOLO or SSD
* to each video frame.
*
* To install ONNX on your system, recursively clone this repository
* https://github.com/microsoft/onnxruntime.git
*
* and build and install with cmake:
*
* CPU:
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
*
*
* GPU :
*
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
* -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
*
*
* where :
*
* 1. $SRC_DIR and $BUILD_DIR are local source and build directories
* 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
* $CUDA_PATH is an environment variable set to the CUDA root path.
* On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA.
*
*
* ## Example launch command:
*
* (note: an object detection model has 3 or 4 output nodes, but there is no
* naming convention to indicate which node outputs the bounding box, which
* node outputs the label, etc. So, the `onnxobjectdetector` element has
* properties to map each node's functionality to its respective node index in
* the specified model. Image resolution also need to be adapted to the model.
* The videoscale in the pipeline below will scale the image, using padding if
* required, to 640x383 resolution required by the model.)
*
* model.onnx can be found here:
* https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx
*
* ```
* GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
* location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
* videoconvert ! \
* videoscale ! \
* 'video/x-raw,width=640,height=383' ! \
* onnxobjectdetector \
* box-node-index=0 \
* class-node-index=1 \
* score-node-index=2 \
* detection-node-index=3 \
* execution-provider=cpu \
* model-file=model.onnx \
* label-file=COCO_classes.txt ! \
* videoconvert ! \
* autovideosink
* ```
*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#include "gstonnxobjectdetector.h"
#include "gstonnxclient.h"
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideometa.h>
#include <stdlib.h>
#include <string.h>
#include <glib.h>
GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug);
#define GST_CAT_DEFAULT onnx_object_detector_debug
#define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr))
GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector",
GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR);
/* GstOnnxObjectDetector properties */
enum
{
PROP_0,
PROP_MODEL_FILE,
PROP_LABEL_FILE,
PROP_SCORE_THRESHOLD,
PROP_DETECTION_NODE_INDEX,
PROP_BOUNDING_BOX_NODE_INDEX,
PROP_SCORE_NODE_INDEX,
PROP_CLASS_NODE_INDEX,
PROP_INPUT_IMAGE_FORMAT,
PROP_OPTIMIZATION_LEVEL,
PROP_EXECUTION_PROVIDER
};
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
static GstStaticPadTemplate gst_onnx_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
GST_PAD_SRC,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
);
static GstStaticPadTemplate gst_onnx_object_detector_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
GST_PAD_SINK,
GST_PAD_ALWAYS,
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
);
static void gst_onnx_object_detector_set_property (GObject * object,
guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_onnx_object_detector_get_property (GObject * object,
guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_onnx_object_detector_finalize (GObject * object);
static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform *
trans, GstBuffer * buf);
static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans,
GstBuffer * buf);
static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans);
static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector,
GST_TYPE_BASE_TRANSFORM);
static void
gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass)
{
GObjectClass *gobject_class = (GObjectClass *) klass;
GstElementClass *element_class = (GstElementClass *) klass;
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector",
0, "onnx_objectdetector");
gobject_class->set_property = gst_onnx_object_detector_set_property;
gobject_class->get_property = gst_onnx_object_detector_get_property;
gobject_class->finalize = gst_onnx_object_detector_finalize;
/**
* GstOnnxObjectDetector:model-file
*
* ONNX model file
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
g_param_spec_string ("model-file",
"ONNX model file", "ONNX model file", NULL, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:label-file
*
* Label file for ONNX model
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
g_param_spec_string ("label-file",
"Label file", "Label file associated with model", NULL, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:detection-node-index
*
* Index of model detection node
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_DETECTION_NODE_INDEX,
g_param_spec_int ("detection-node-index",
"Detection node index",
"Index of neural network output node corresponding to number of detected objects",
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:bounding-box-node-index
*
* Index of model bounding box node
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_BOUNDING_BOX_NODE_INDEX,
g_param_spec_int ("box-node-index",
"Bounding box node index",
"Index of neural network output node corresponding to bounding box",
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:score-node-index
*
* Index of model score node
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_SCORE_NODE_INDEX,
g_param_spec_int ("score-node-index",
"Score node index",
"Index of neural network output node corresponding to score",
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:class-node-index
*
* Index of model class (label) node
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_CLASS_NODE_INDEX,
g_param_spec_int ("class-node-index",
"Class node index",
"Index of neural network output node corresponding to class (label)",
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:score-threshold
*
* Threshold for deciding when to remove boxes based on score
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
g_param_spec_float ("score-threshold",
"Score threshold",
"Threshold for deciding when to remove boxes based on score",
0.0, 1.0,
GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:input-image-format
*
* Model input image format
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_INPUT_IMAGE_FORMAT,
g_param_spec_enum ("input-image-format",
"Input image format",
"Input image format",
GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:optimization-level
*
* ONNX optimization level
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_OPTIMIZATION_LEVEL,
g_param_spec_enum ("optimization-level",
"Optimization level",
"ONNX optimization level",
GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
/**
* GstOnnxObjectDetector:execution-provider
*
* ONNX execution provider
*
* Since: 1.20
*/
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_EXECUTION_PROVIDER,
g_param_spec_enum ("execution-provider",
"Execution provider",
"ONNX execution provider",
GST_TYPE_ONNX_EXECUTION_PROVIDER,
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "onnxobjectdetector",
"Filter/Effect/Video",
"Apply neural network to detect objects in video frames",
"Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_onnx_object_detector_sink_template));
gst_element_class_add_pad_template (element_class,
gst_static_pad_template_get (&gst_onnx_object_detector_src_template));
basetransform_class->transform_ip =
GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip);
basetransform_class->transform_caps =
GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps);
}
static void
gst_onnx_object_detector_init (GstOnnxObjectDetector * self)
{
self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient ();
self->onnx_disabled = false;
}
static void
gst_onnx_object_detector_finalize (GObject * object)
{
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
g_free (self->model_file);
delete GST_ONNX_MEMBER (self);
G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object);
}
static void
gst_onnx_object_detector_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
{
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
const gchar *filename;
auto onnxClient = GST_ONNX_MEMBER (self);
switch (prop_id) {
case PROP_MODEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
if (self->model_file)
g_free (self->model_file);
self->model_file = g_strdup (filename);
} else {
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
}
break;
case PROP_LABEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
if (self->label_file)
g_free (self->label_file);
self->label_file = g_strdup (filename);
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
self->score_threshold = g_value_get_float (value);
GST_OBJECT_UNLOCK (self);
break;
case PROP_OPTIMIZATION_LEVEL:
self->optimization_level =
(GstOnnxOptimizationLevel) g_value_get_enum (value);
break;
case PROP_EXECUTION_PROVIDER:
self->execution_provider =
(GstOnnxExecutionProvider) g_value_get_enum (value);
break;
case PROP_DETECTION_NODE_INDEX:
onnxClient->setOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
g_value_get_int (value));
break;
case PROP_BOUNDING_BOX_NODE_INDEX:
onnxClient->setOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
g_value_get_int (value));
break;
break;
case PROP_SCORE_NODE_INDEX:
onnxClient->setOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
g_value_get_int (value));
break;
break;
case PROP_CLASS_NODE_INDEX:
onnxClient->setOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
g_value_get_int (value));
break;
case PROP_INPUT_IMAGE_FORMAT:
onnxClient->setInputImageFormat ((GstMlModelInputImageFormat)
g_value_get_enum (value));
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static void
gst_onnx_object_detector_get_property (GObject * object, guint prop_id,
GValue * value, GParamSpec * pspec)
{
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
auto onnxClient = GST_ONNX_MEMBER (self);
switch (prop_id) {
case PROP_MODEL_FILE:
g_value_set_string (value, self->model_file);
break;
case PROP_LABEL_FILE:
g_value_set_string (value, self->label_file);
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
g_value_set_float (value, self->score_threshold);
GST_OBJECT_UNLOCK (self);
break;
case PROP_OPTIMIZATION_LEVEL:
g_value_set_enum (value, self->optimization_level);
break;
case PROP_EXECUTION_PROVIDER:
g_value_set_enum (value, self->execution_provider);
break;
case PROP_DETECTION_NODE_INDEX:
g_value_set_int (value,
onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION));
break;
case PROP_BOUNDING_BOX_NODE_INDEX:
g_value_set_int (value,
onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX));
break;
break;
case PROP_SCORE_NODE_INDEX:
g_value_set_int (value,
onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE));
break;
break;
case PROP_CLASS_NODE_INDEX:
g_value_set_int (value,
onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS));
break;
case PROP_INPUT_IMAGE_FORMAT:
g_value_set_enum (value, onnxClient->getInputImageFormat ());
break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
}
}
static gboolean
gst_onnx_object_detector_create_session (GstBaseTransform * trans)
{
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
auto onnxClient = GST_ONNX_MEMBER (self);
GST_OBJECT_LOCK (self);
if (self->onnx_disabled || onnxClient->hasSession ()) {
GST_OBJECT_UNLOCK (self);
return TRUE;
}
if (self->model_file) {
gboolean ret = GST_ONNX_MEMBER (self)->createSession (self->model_file,
self->optimization_level,
self->execution_provider);
if (!ret) {
GST_ERROR_OBJECT (self,
"Unable to create ONNX session. Detection disabled.");
} else {
auto outputNames = onnxClient->getOutputNodeNames ();
for (size_t i = 0; i < outputNames.size (); ++i)
GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i,
outputNames[i]);
if (outputNames.size () < 3) {
GST_ERROR_OBJECT (self,
"Number of output tensor nodes %d does not match the 3 or 4 nodes "
"required for an object detection model. Detection is disabled.",
(gint) outputNames.size ());
self->onnx_disabled = TRUE;
}
// sanity check on output node indices
if (onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) ==
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
GST_ERROR_OBJECT (self,
"Output detection node index not set. Detection disabled.");
self->onnx_disabled = TRUE;
}
if (onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) ==
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
GST_ERROR_OBJECT (self,
"Output bounding box node index not set. Detection disabled.");
self->onnx_disabled = TRUE;
}
if (onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) ==
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
GST_ERROR_OBJECT (self,
"Output score node index not set. Detection disabled.");
self->onnx_disabled = TRUE;
}
if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) ==
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
GST_ERROR_OBJECT (self,
"Output class node index not set. Detection disabled.");
self->onnx_disabled = TRUE;
}
// model is not usable, so fail
if (self->onnx_disabled) {
GST_ELEMENT_WARNING (self, RESOURCE, FAILED,
("ONNX model cannot be used for object detection"), (NULL));
return FALSE;
}
}
} else {
self->onnx_disabled = TRUE;
}
GST_OBJECT_UNLOCK (self);
if (self->onnx_disabled){
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
}
return TRUE;
}
static GstCaps *
gst_onnx_object_detector_transform_caps (GstBaseTransform *
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
{
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
auto onnxClient = GST_ONNX_MEMBER (self);
GstCaps *other_caps;
guint i;
if ( !gst_onnx_object_detector_create_session (trans) )
return NULL;
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
if (gst_base_transform_is_passthrough (trans)
|| (!onnxClient->isFixedInputImageSize ()))
return gst_caps_ref (caps);
other_caps = gst_caps_new_empty ();
for (i = 0; i < gst_caps_get_size (caps); ++i) {
GstStructure *structure, *new_structure;
structure = gst_caps_get_structure (caps, i);
new_structure = gst_structure_copy (structure);
gst_structure_set (new_structure, "width", G_TYPE_INT,
onnxClient->getWidth (), "height", G_TYPE_INT,
onnxClient->getHeight (), NULL);
GST_LOG_OBJECT (self,
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
GST_PTR_FORMAT, i, structure, new_structure);
gst_caps_append_structure (other_caps, new_structure);
}
if (!gst_caps_is_empty (other_caps) && filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (other_caps,filter_caps,
GST_CAPS_INTERSECT_FIRST);
gst_caps_replace (&other_caps, tmp);
gst_caps_unref (tmp);
}
return other_caps;
}
static GstFlowReturn
gst_onnx_object_detector_transform_ip (GstBaseTransform * trans,
GstBuffer * buf)
{
if (!gst_base_transform_is_passthrough (trans)
&& !gst_onnx_object_detector_process (trans, buf)){
GST_ELEMENT_WARNING (trans, STREAM, FAILED,
("ONNX object detection failed"), (NULL));
return GST_FLOW_ERROR;
}
return GST_FLOW_OK;
}
static gboolean
gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstMapInfo info;
GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
if (!vmeta) {
GST_WARNING_OBJECT (trans, "missing video meta");
return FALSE;
}
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
std::vector < GstOnnxNamespace::GstMlBoundingBox > boxes;
try {
boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
self->label_file ? self->label_file : "", self->score_threshold);
}
catch (Ort::Exception & ortex) {
GST_ERROR_OBJECT (self, "%s", ortex.what ());
return FALSE;
}
for (auto & b:boxes) {
auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
GST_ONNX_OBJECT_DETECTOR_META_NAME,
b.x0, b.y0,
b.width,
b.height);
if (!vroi_meta) {
GST_WARNING_OBJECT (trans,
"Unable to attach GstVideoRegionOfInterestMeta to buffer");
return FALSE;
}
auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME,
GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL,
G_TYPE_STRING,
b.label.c_str (),
GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE,
G_TYPE_DOUBLE,
b.score,
NULL);
gst_video_region_of_interest_meta_add_param (vroi_meta, s);
GST_DEBUG_OBJECT (self,
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
b.label.c_str (), b.score, b.x0, b.y0,
b.x0 + b.width, b.y0 + b.height);
}
gst_buffer_unmap (buf, &info);
}
return TRUE;
}

View file

@ -1,93 +0,0 @@
/*
* GStreamer gstreamer-onnxobjectdetector
* Copyright (C) 2021 Collabora Ltd
*
* gstonnxobjectdetector.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_ONNX_OBJECT_DETECTOR_H__
#define __GST_ONNX_OBJECT_DETECTOR_H__
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideofilter.h>
#include "gstonnxelement.h"
G_BEGIN_DECLS
#define GST_TYPE_ONNX_OBJECT_DETECTOR (gst_onnx_object_detector_get_type())
G_DECLARE_FINAL_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector, GST, ONNX_OBJECT_DETECTOR, GstBaseTransform)
#define GST_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_CAST((obj),GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetector))
#define GST_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_CAST((klass), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
#define GST_ONNX_OBJECT_DETECTOR_GET_CLASS(obj) (G_TYPE_INSTANCE_GET_CLASS((obj), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
#define GST_IS_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_TYPE((obj),GST_TYPE_ONNX_OBJECT_DETECTOR))
#define GST_IS_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_TYPE((klass), GST_TYPE_ONNX_OBJECT_DETECTOR))
#define GST_ONNX_OBJECT_DETECTOR_META_NAME "onnx-object_detector"
#define GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME "extra-data"
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE "score"
/**
* GstOnnxObjectDetector:
*
* @model_file model file
* @label_file label file
* @score_threshold score threshold
* @confidence_threshold confidence threshold
* @iou_threhsold iou threshold
* @optimization_level ONNX optimization level
* @execution_provider: ONNX execution provider
* @onnx_ptr opaque pointer to ONNX implementation
*
* Since: 1.20
*/
struct _GstOnnxObjectDetector
{
GstBaseTransform basetransform;
gchar *model_file;
gchar *label_file;
gfloat score_threshold;
gfloat confidence_threshold;
gfloat iou_threshold;
GstOnnxOptimizationLevel optimization_level;
GstOnnxExecutionProvider execution_provider;
gpointer onnx_ptr;
gboolean onnx_disabled;
void (*process) (GstOnnxObjectDetector * onnx_object_detector,
GstVideoFrame * inframe, GstVideoFrame * outframe);
};
/**
* GstOnnxObjectDetectorClass:
*
* @parent_class base transform base class
*
* Since: 1.20
*/
struct _GstOnnxObjectDetectorClass
{
GstBaseTransformClass parent_class;
};
GST_ELEMENT_REGISTER_DECLARE (onnx_object_detector)
G_END_DECLS
#endif /* __GST_ONNX_OBJECT_DETECTOR_H__ */

View file

@ -3,27 +3,31 @@ if get_option('onnx').disabled()
endif endif
onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx')) onnxrt_dep = dependency('libonnxruntime', version : '>= 1.15.1', required : get_option('onnx'))
extra_args = []
extra_deps = []
if gstcuda_dep.found()
extra_args += ['-DHAVE_CUDA']
extra_deps += [gstcuda_dep]
endif
if onnxrt_dep.found() if onnxrt_dep.found()
onnxrt_include_root = onnxrt_dep.get_variable('includedir') onnxrt_include_root = onnxrt_dep.get_variable('includedir')
onnxrt_includes = [onnxrt_include_root / 'core/session', onnxrt_include_root / 'core'] onnxrt_includes = [onnxrt_include_root / 'core/session', onnxrt_include_root / 'core']
onnxrt_dep_args = []
compiler = meson.get_compiler('cpp')
if compiler.has_header(onnxrt_include_root / 'core/providers/cuda/cuda_provider_factory.h')
onnxrt_dep_args = ['-DGST_ML_ONNX_RUNTIME_HAVE_CUDA']
endif
gstonnx = library('gstonnx', gstonnx = library('gstonnx',
'gstonnx.c', 'gstonnx.c',
'gstonnxelement.c', 'decoders/gstobjectdetectorutils.cpp',
'gstonnxobjectdetector.cpp', 'decoders/gstssdobjectdetector.cpp',
'gstonnxinference.cpp',
'gstonnxclient.cpp', 'gstonnxclient.cpp',
c_args : gst_plugins_bad_args, 'tensor/gsttensorid.cpp',
cpp_args: onnxrt_dep_args, 'tensor/gsttensormeta.c',
c_args : gst_plugins_bad_args + extra_args,
cpp_args : gst_plugins_bad_args + extra_args,
link_args : noseh_link_args, link_args : noseh_link_args,
include_directories : [configinc, libsinc, onnxrt_includes], include_directories : [configinc, libsinc, onnxrt_includes, cuda_stubinc],
dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm], dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm] + extra_deps,
install : true, install : true,
install_dir : plugins_install_dir, install_dir : plugins_install_dir,
) )

View file

@ -0,0 +1,69 @@
/*
* GStreamer gstreamer-tensor
* Copyright (C) 2023 Collabora Ltd
*
* gsttensor.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_TENSOR_H__
#define __GST_TENSOR_H__
/**
* GstTensorType:
*
* @GST_TENSOR_TYPE_INT8 8 bit integer tensor data
* @GST_TENSOR_TYPE_INT16 16 bit integer tensor data
* @GST_TENSOR_TYPE_INT32 32 bit integer tensor data
* @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data
* @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data
*
* Since: 1.24
*/
typedef enum _GstTensorType
{
GST_TENSOR_TYPE_INT8,
GST_TENSOR_TYPE_INT16,
GST_TENSOR_TYPE_INT32,
GST_TENSOR_TYPE_FLOAT16,
GST_TENSOR_TYPE_FLOAT32
} GstTensorType;
/**
* GstTensor:
*
* @id unique tensor identifier
* @num_dims number of tensor dimensions
* @dims tensor dimensions
* @type @ref GstTensorType of tensor data
* @data @ref GstBuffer holding tensor data
*
* Since: 1.24
*/
typedef struct _GstTensor
{
GQuark id;
gint num_dims;
int64_t *dims;
GstTensorType type;
GstBuffer *data;
} GstTensor;
#define GST_TENSOR_MISSING_ID -1
#endif

View file

@ -0,0 +1,82 @@
/*
* GStreamer gstreamer-tensorid
* Copyright (C) 2023 Collabora Ltd
*
* gsttensorid.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include <gst/gst.h>
#include "gsttensorid.h"
/* Structure to encapsulate a string and its associated GQuark */
struct TensorQuark
{
const char *string;
GQuark quark_id;
};
class TensorId
{
public:
TensorId (void):tensor_quarks_array (g_array_new (FALSE, FALSE,
sizeof (TensorQuark)))
{
}
~TensorId (void)
{
if (tensor_quarks_array) {
for (guint i = 0; i < tensor_quarks_array->len; i++) {
TensorQuark *quark =
&g_array_index (tensor_quarks_array, TensorQuark, i);
g_free ((gpointer) quark->string); // free the duplicated string
}
g_array_free (tensor_quarks_array, TRUE);
}
}
GQuark get_quark (const char *str)
{
for (guint i = 0; i < tensor_quarks_array->len; i++) {
TensorQuark *quark = &g_array_index (tensor_quarks_array, TensorQuark, i);
if (g_strcmp0 (quark->string, str) == 0) {
return quark->quark_id; // already registered
}
}
// Register the new quark and append to the GArray
TensorQuark new_quark;
new_quark.string = g_strdup (str); // create a copy of the string
new_quark.quark_id = g_quark_from_string (new_quark.string);
g_array_append_val (tensor_quarks_array, new_quark);
return new_quark.quark_id;
}
private:
GArray * tensor_quarks_array;
};
static TensorId tensorId;
G_BEGIN_DECLS
GQuark
gst_tensorid_get_quark (const char *tensor_id)
{
return tensorId.get_quark (tensor_id);
}
G_END_DECLS

View file

@ -0,0 +1,34 @@
/*
* GStreamer gstreamer-tensorid
* Copyright (C) 2023 Collabora Ltd
*
* gsttensorid.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_TENSOR_ID_H__
#define __GST_TENSOR_ID_H__
G_BEGIN_DECLS
/**
* gst_tensorid_get_quark get tensor id
*
* @param tensor_id unique string id for tensor node
*/
GQuark gst_tensorid_get_quark (const char *tensor_id);
G_END_DECLS
#endif

View file

@ -0,0 +1,107 @@
/*
* GStreamer gstreamer-tensormeta
* Copyright (C) 2023 Collabora Ltd
*
* gsttensormeta.c
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "gsttensormeta.h"
#include "gsttensor.h"
static gboolean
gst_tensor_meta_init (GstMeta * meta, gpointer params, GstBuffer * buffer)
{
GstTensorMeta *tmeta = (GstTensorMeta *) meta;
tmeta->num_tensors = 0;
tmeta->tensor = NULL;
return TRUE;
}
static void
gst_tensor_meta_free (GstMeta * meta, GstBuffer * buffer)
{
GstTensorMeta *tmeta = (GstTensorMeta *) meta;
for (int i = 0; i < tmeta->num_tensors; i++) {
g_free (tmeta->tensor[i].dims);
gst_buffer_unref (tmeta->tensor[i].data);
}
g_free (tmeta->tensor);
}
GType
gst_tensor_meta_api_get_type (void)
{
static GType type = 0;
static const gchar *tags[] = { NULL };
if (g_once_init_enter (&type)) {
type = gst_meta_api_type_register ("GstTensorMetaAPI", tags);
}
return type;
}
const GstMetaInfo *
gst_tensor_meta_get_info (void)
{
static const GstMetaInfo *tmeta_info = NULL;
if (g_once_init_enter (&tmeta_info)) {
const GstMetaInfo *meta =
gst_meta_register (gst_tensor_meta_api_get_type (),
"GstTensorMeta",
sizeof (GstTensorMeta),
gst_tensor_meta_init,
gst_tensor_meta_free,
NULL); /* tensor_meta_transform not implemented */
g_once_init_leave (&tmeta_info, meta);
}
return tmeta_info;
}
GList *
gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer)
{
GType tensor_meta_api_type = gst_tensor_meta_api_get_type ();
GList *tensor_metas = NULL;
gpointer state = NULL;
GstMeta *meta;
while ((meta = gst_buffer_iterate_meta (buffer, &state))) {
if (meta->info->api == tensor_meta_api_type) {
tensor_metas = g_list_append (tensor_metas, meta);
}
}
return tensor_metas;
}
gint
gst_tensor_meta_get_index_from_id (GstTensorMeta * meta, GQuark id)
{
for (int i = 0; i < meta->num_tensors; ++i) {
if ((meta->tensor + i)->id == id)
return i;
}
return GST_TENSOR_MISSING_ID;
}

View file

@ -0,0 +1,56 @@
/*
* GStreamer gstreamer-tensormeta
* Copyright (C) 2023 Collabora Ltd
*
* gsttensormeta.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_TENSOR_META_H__
#define __GST_TENSOR_META_H__
#include <gst/gst.h>
#include "gsttensor.h"
/**
* GstTensorMeta:
*
* @meta base GstMeta
* @num_tensors number of tensors
* @tensor @ref GstTensor for each tensor
* @batch_size model batch size
*
* Since: 1.24
*/
typedef struct _GstTensorMeta
{
GstMeta meta;
gint num_tensors;
GstTensor *tensor;
int batch_size;
} GstTensorMeta;
G_BEGIN_DECLS
GType gst_tensor_meta_api_get_type (void);
const GstMetaInfo *gst_tensor_meta_get_info (void);
GList *gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer);
gint gst_tensor_meta_get_index_from_id(GstTensorMeta *meta, GQuark id);
G_END_DECLS
#endif