onnx: add gstonnxinference element

This change refactors functionality out of the existing ONNX object detector element, separating the ONNX inference step from the subsequent analysis. The new gstonnxinference element runs an ONNX model on each video frame and attaches a TensorMeta with the output tensor data to the buffer. That tensor data is then consumed by downstream elements such as gstobjectdetector. For the moment a provisional TensorMeta is used only inside the ONNX plugin; in the future this will be upgraded to a GStreamer API that other plugins can consume.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/4916>
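The split is visible in the example pipelines further down in this commit: onnxinference produces the tensors, and a decoder element turns them into region-of-interest metadata. An illustrative invocation (file names are placeholders) looks like:

    gst-launch-1.0 multifilesrc location=bus.jpg ! jpegdec ! videoconvert ! \
        videoscale ! 'video/x-raw,width=640,height=383' ! \
        onnxinference execution-provider=cpu model-file=model.onnx ! \
        ssdobjectdetector label-file=COCO_classes.txt ! \
        videoconvert ! autovideosink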
parent 6053dd0d1b
commit 1ff585233a

20 changed files with 1934 additions and 1197 deletions
@@ -0,0 +1,195 @@
/*
 * GStreamer gstreamer-objectdetectorutils
 * Copyright (C) 2023 Collabora Ltd
 *
 * gstobjectdetectorutils.cpp
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "gstobjectdetectorutils.h"

#include <fstream>
#include "tensor/gsttensorid.h"

GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
    float _y0, float _width, float _height):
    label (lbl),
    score (score),
    x0 (_x0),
    y0 (_y0),
    width (_width),
    height (_height)
{
}

GstMlBoundingBox::GstMlBoundingBox ():
    GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
{
}

namespace GstObjectDetectorUtils
{

GstObjectDetectorUtils::GstObjectDetectorUtils ()
{
}

std::vector < std::string >
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
{
  std::vector < std::string > labels;
  std::string line;
  std::ifstream fp (labelsFile);
  while (std::getline (fp, line))
    labels.push_back (line);

  return labels;
}

std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
    int32_t h, GstTensorMeta * tmeta, std::string labelPath,
    float scoreThreshold)
{
  auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
      gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
  if (classIndex == GST_TENSOR_MISSING_ID) {
    GST_ERROR ("Missing class tensor id");
    return std::vector < GstMlBoundingBox > ();
  }
  auto type = tmeta->tensor[classIndex].type;
  return (type == GST_TENSOR_TYPE_FLOAT32) ?
      doRun < float >(w, h, tmeta, labelPath, scoreThreshold)
      : doRun < int >(w, h, tmeta, labelPath, scoreThreshold);
}

template < typename T > std::vector < GstMlBoundingBox >
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
    GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold)
{
  std::vector < GstMlBoundingBox > boundingBoxes;
  GstMapInfo map_info[GstObjectDetectorMaxNodes];
  GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
  std::vector < std::string > labels;
  gint index;
  T *numDetections = nullptr, *bboxes = nullptr, *scores =
      nullptr, *labelIndex = nullptr;

  // number of detections
  index = gst_tensor_meta_get_index_from_id (tmeta,
      gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
  if (index == GST_TENSOR_MISSING_ID) {
    GST_WARNING ("Missing tensor data for tensor index %d", index);
    goto cleanup;
  }
  memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
  if (!memory[index]) {
    GST_WARNING ("Missing tensor data for tensor index %d", index);
    goto cleanup;
  }
  if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
    GST_WARNING ("Failed to map tensor memory for index %d", index);
    goto cleanup;
  }
  numDetections = (T *) map_info[index].data;

  // bounding boxes
  index =
      gst_tensor_meta_get_index_from_id (tmeta,
      gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES));
  if (index == GST_TENSOR_MISSING_ID) {
    GST_WARNING ("Missing tensor data for tensor index %d", index);
    goto cleanup;
  }
  memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
  if (!memory[index]) {
    GST_WARNING ("Failed to map tensor memory for index %d", index);
    goto cleanup;
  }
  if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
    GST_ERROR ("Failed to map GstMemory");
    goto cleanup;
  }
  bboxes = (T *) map_info[index].data;

  // scores
  index =
      gst_tensor_meta_get_index_from_id (tmeta,
      gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES));
  if (index == GST_TENSOR_MISSING_ID) {
    GST_ERROR ("Missing scores tensor id");
    goto cleanup;
  }
  memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
  if (!memory[index]) {
    GST_WARNING ("Missing tensor data for tensor index %d", index);
    goto cleanup;
  }
  if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
    GST_ERROR ("Failed to map GstMemory");
    goto cleanup;
  }
  scores = (T *) map_info[index].data;

  // optional label
  labelIndex = nullptr;
  index =
      gst_tensor_meta_get_index_from_id (tmeta,
      gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
  if (index != GST_TENSOR_MISSING_ID) {
    memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
    if (!memory[index]) {
      GST_WARNING ("Missing tensor data for tensor index %d", index);
      goto cleanup;
    }
    if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
      GST_ERROR ("Failed to map GstMemory");
      goto cleanup;
    }
    labelIndex = (T *) map_info[index].data;
  }

  if (!labelPath.empty ())
    labels = ReadLabels (labelPath);

  for (int i = 0; i < numDetections[0]; ++i) {
    if (scores[i] > scoreThreshold) {
      std::string label = "";

      if (labelIndex && !labels.empty ())
        label = labels[labelIndex[i] - 1];
      auto score = scores[i];
      auto y0 = bboxes[i * 4] * h;
      auto x0 = bboxes[i * 4 + 1] * w;
      auto bheight = bboxes[i * 4 + 2] * h - y0;
      auto bwidth = bboxes[i * 4 + 3] * w - x0;
      boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
              bheight));
    }
  }

cleanup:
  for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
    if (memory[i])
      gst_memory_unmap (memory[i], map_info + i);
  }

  return boundingBoxes;
}

}

@@ -0,0 +1,82 @@
/*
 * GStreamer gstreamer-objectdetectorutils
 * Copyright (C) 2023 Collabora Ltd
 *
 * gstobjectdetectorutils.h
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
#define __GST_OBJECT_DETECTOR_UTILS_H__

#include <gst/gst.h>
#include <string>
#include <vector>

#include "gstml.h"
#include "tensor/gsttensormeta.h"

/* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"


/**
 * GstMlBoundingBox:
 *
 * @label label
 * @score detection confidence
 * @x0 top left hand x coordinate
 * @y0 top left hand y coordinate
 * @width width
 * @height height
 *
 * Since: 1.20
 */
struct GstMlBoundingBox {
  GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0,
                   float _width, float _height);
  GstMlBoundingBox();
  std::string label;
  float score;
  float x0;
  float y0;
  float width;
  float height;
};

namespace GstObjectDetectorUtils {
  const int GstObjectDetectorMaxNodes = 4;
  class GstObjectDetectorUtils {
  public:
    GstObjectDetectorUtils(void);
    ~GstObjectDetectorUtils(void) = default;
    std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
                                         GstTensorMeta *tmeta,
                                         std::string labelPath,
                                         float scoreThreshold);
  private:
    template < typename T > std::vector < GstMlBoundingBox >
        doRun(int32_t w, int32_t h,
              GstTensorMeta *tmeta, std::string labelPath,
              float scoreThreshold);
    std::vector < std::string > ReadLabels(const std::string & labelsFile);
  };
}

#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */

@@ -0,0 +1,348 @@
/*
 * GStreamer gstreamer-ssdobjectdetector
 * Copyright (C) 2021 Collabora Ltd.
 *
 * gstssdobjectdetector.cpp
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

/**
 * SECTION:element-ssdobjectdetector
 * @short_description: Detect objects in video buffers using SSD neural network
 *
 * This element can parse per-buffer inference tensor meta data generated by an
 * upstream inference element.
 *
 *
 * ## Example launch command:
 *
 * Note: the image resolution may need to be adapted to the model, if the model
 * expects a certain input resolution. The `videoscale` element in the pipeline
 * below will scale the image, using padding if required, to the 640x383
 * resolution required by the model.
 *
 *
 * GST_DEBUG=ssdobjectdetector:5 gst-launch-1.0 multifilesrc \
 * location=bus.jpg ! jpegdec ! videoconvert ! \
 * videoscale ! 'video/x-raw,width=640,height=383' ! \
 * onnxinference execution-provider=cpu model-file=model.onnx ! \
 * ssdobjectdetector label-file=COCO_classes.txt ! \
 * videoconvert ! autovideosink
 *
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif


#include "gstssdobjectdetector.h"
#include "gstobjectdetectorutils.h"

#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideometa.h>
#include "tensor/gsttensormeta.h"
#include "tensor/gsttensorid.h"

GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
#define GST_CAT_DEFAULT ssd_object_detector_debug
#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
    GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);

/* GstSsdObjectDetector properties */
enum
{
  PROP_0,
  PROP_LABEL_FILE,
  PROP_SCORE_THRESHOLD,
};

#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */

static GstStaticPadTemplate gst_ssd_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
    GST_PAD_SRC,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS ("video/x-raw")
    );

static GstStaticPadTemplate gst_ssd_object_detector_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS ("video/x-raw")
    );

static void gst_ssd_object_detector_set_property (GObject * object,
    guint prop_id, const GValue * value, GParamSpec * pspec);
static void gst_ssd_object_detector_get_property (GObject * object,
    guint prop_id, GValue * value, GParamSpec * pspec);
static void gst_ssd_object_detector_finalize (GObject * object);
static GstFlowReturn gst_ssd_object_detector_transform_ip (GstBaseTransform *
    trans, GstBuffer * buf);
static gboolean gst_ssd_object_detector_process (GstBaseTransform * trans,
    GstBuffer * buf);
static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
    GstCaps * outcaps);

G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);

static void
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
{
  GObjectClass *gobject_class = (GObjectClass *) klass;
  GstElementClass *element_class = (GstElementClass *) klass;
  GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;

  GST_DEBUG_CATEGORY_INIT (ssd_object_detector_debug, "ssdobjectdetector",
      0, "ssdobjectdetector");
  gobject_class->set_property = gst_ssd_object_detector_set_property;
  gobject_class->get_property = gst_ssd_object_detector_get_property;
  gobject_class->finalize = gst_ssd_object_detector_finalize;

  /**
   * GstSsdObjectDetector:label-file
   *
   * Label file
   *
   * Since: 1.24
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
      g_param_spec_string ("label-file",
          "Label file", "Label file", NULL, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  /**
   * GstSsdObjectDetector:score-threshold
   *
   * Threshold for deciding when to remove boxes based on score
   *
   * Since: 1.24
   */
  g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
      g_param_spec_float ("score-threshold",
          "Score threshold",
          "Threshold for deciding when to remove boxes based on score",
          0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
          (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));

  gst_element_class_set_static_metadata (element_class, "objectdetector",
      "Filter/Effect/Video",
      "Apply tensor output from inference to detect objects in video frames",
      "Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
  gst_element_class_add_pad_template (element_class,
      gst_static_pad_template_get (&gst_ssd_object_detector_sink_template));
  gst_element_class_add_pad_template (element_class,
      gst_static_pad_template_get (&gst_ssd_object_detector_src_template));
  basetransform_class->transform_ip =
      GST_DEBUG_FUNCPTR (gst_ssd_object_detector_transform_ip);
  basetransform_class->set_caps =
      GST_DEBUG_FUNCPTR (gst_ssd_object_detector_set_caps);
}

static void
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
{
  self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
}

static void
gst_ssd_object_detector_finalize (GObject * object)
{
  GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);

  delete GST_ODUTILS_MEMBER (self);
  G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
}

static void
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
    const GValue * value, GParamSpec * pspec)
{
  GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
  const gchar *filename;

  switch (prop_id) {
    case PROP_LABEL_FILE:
      filename = g_value_get_string (value);
      if (filename
          && g_file_test (filename,
              (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
        g_free (self->label_file);
        self->label_file = g_strdup (filename);
      } else {
        GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
      }
      break;
    case PROP_SCORE_THRESHOLD:
      GST_OBJECT_LOCK (self);
      self->score_threshold = g_value_get_float (value);
      GST_OBJECT_UNLOCK (self);
      break;
    default:
      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
      break;
  }
}

static void
gst_ssd_object_detector_get_property (GObject * object, guint prop_id,
    GValue * value, GParamSpec * pspec)
{
  GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);

  switch (prop_id) {
    case PROP_LABEL_FILE:
      g_value_set_string (value, self->label_file);
      break;
    case PROP_SCORE_THRESHOLD:
      GST_OBJECT_LOCK (self);
      g_value_set_float (value, self->score_threshold);
      GST_OBJECT_UNLOCK (self);
      break;
    default:
      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
      break;
  }
}

static GstTensorMeta *
gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
    GstBuffer * buf)
{
  GstTensorMeta *tmeta = NULL;
  GList *tensor_metas;
  GList *iter;

  // get all tensor metas
  tensor_metas = gst_tensor_meta_get_all_from_buffer (buf);
  if (!tensor_metas) {
    GST_TRACE_OBJECT (object_detector,
        "missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
    goto cleanup;
  }

  // find object detector meta
  for (iter = tensor_metas; iter != NULL; iter = g_list_next (iter)) {
    GstTensorMeta *tensor_meta = (GstTensorMeta *) iter->data;
    gint numTensors = tensor_meta->num_tensors;
    /* SSD model must have either 3 or 4 output tensor nodes: 4 if there is a label node,
     * and only 3 if there is no label */
    if (numTensors != 3 && numTensors != 4)
      continue;

    gint boxesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
        gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES));
    gint scoresIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
        gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES));
    gint numDetectionsIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
        gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
    gint classesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
        gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));

    if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
        || numDetectionsIndex == GST_TENSOR_MISSING_ID)
      continue;

    if (numTensors == 4 && classesIndex == GST_TENSOR_MISSING_ID)
      continue;

    tmeta = tensor_meta;
    break;
  }

cleanup:
  g_list_free (tensor_metas);

  return tmeta;
}

static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
    GstCaps * outcaps)
{
  GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);

  if (!gst_video_info_from_caps (&self->video_info, incaps)) {
    GST_ERROR_OBJECT (self, "Failed to parse caps");
    return FALSE;
  }

  return TRUE;
}

static GstFlowReturn
gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
{
  if (!gst_base_transform_is_passthrough (trans)) {
    if (!gst_ssd_object_detector_process (trans, buf)) {
      GST_ELEMENT_ERROR (trans, STREAM, FAILED,
          (NULL), ("ssd object detection failed"));
      return GST_FLOW_ERROR;
    }
  }

  return GST_FLOW_OK;
}

static gboolean
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
  GstTensorMeta *tmeta = NULL;
  GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);

  // get all tensor metas
  tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
  if (!tmeta) {
    GST_WARNING_OBJECT (trans, "missing tensor meta");
    return TRUE;
  }

  std::vector < GstMlBoundingBox > boxes =
      GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
      self->video_info.height, tmeta, self->label_file ? self->label_file : "",
      self->score_threshold);

  for (auto & b:boxes) {
    auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
        GST_SSD_OBJECT_DETECTOR_META_NAME,
        b.x0, b.y0,
        b.width,
        b.height);
    if (!vroi_meta) {
      GST_WARNING_OBJECT (trans,
          "Unable to attach GstVideoRegionOfInterestMeta to buffer");
      return FALSE;
    }
    auto s = gst_structure_new (GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME,
        GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL,
        G_TYPE_STRING,
        b.label.c_str (),
        GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE,
        G_TYPE_DOUBLE,
        b.score,
        NULL);
    gst_video_region_of_interest_meta_add_param (vroi_meta, s);
    GST_DEBUG_OBJECT (self,
        "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
        b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
  }

  return TRUE;
}

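Downstream consumers do not need any plugin-specific API to read these detections back: they are ordinary GstVideoRegionOfInterestMeta entries whose "extra-data" structure carries the "label" and "score" fields set above. A minimal illustrative sketch of a consumer, assuming it is called from a pad probe or appsink callback that supplies the buffer (not part of this patch):

#include <gst/gst.h>
#include <gst/video/gstvideometa.h>

/* Illustrative only: walk the ROI metas attached by ssdobjectdetector. */
static void
print_detections (GstBuffer * buf)
{
  gpointer state = NULL;
  GstMeta *meta;

  while ((meta = gst_buffer_iterate_meta_filtered (buf, &state,
              GST_VIDEO_REGION_OF_INTEREST_META_API_TYPE))) {
    GstVideoRegionOfInterestMeta *roi = (GstVideoRegionOfInterestMeta *) meta;
    GstStructure *s =
        gst_video_region_of_interest_meta_get_param (roi, "extra-data");
    const gchar *label = s ? gst_structure_get_string (s, "label") : NULL;
    gdouble score = 0.0;

    if (s)
      gst_structure_get_double (s, "score", &score);

    g_print ("%s (%.2f): x=%u y=%u w=%u h=%u\n",
        label ? label : "(no label)", score, roi->x, roi->y, roi->w, roi->h);
  }
}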
@@ -0,0 +1,78 @@
/*
 * GStreamer gstreamer-ssdobjectdetector
 * Copyright (C) 2021 Collabora Ltd
 *
 * gstssdobjectdetector.h
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#ifndef __GST_SSD_OBJECT_DETECTOR_H__
#define __GST_SSD_OBJECT_DETECTOR_H__

#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/video/gstvideofilter.h>

G_BEGIN_DECLS

#define GST_TYPE_SSD_OBJECT_DETECTOR (gst_ssd_object_detector_get_type())
G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OBJECT_DETECTOR, GstBaseTransform)

#define GST_SSD_OBJECT_DETECTOR_META_NAME "ssd-object-detector"
#define GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME "extra-data"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"

/**
 * GstSsdObjectDetector:
 *
 * @label_file label file
 * @score_threshold score threshold
 * @confidence_threshold confidence threshold
 * @iou_threshold iou threshold
 * @odutils opaque pointer to the object detection implementation
 *
 * Since: 1.20
 */
struct _GstSsdObjectDetector
{
  GstBaseTransform basetransform;
  gchar *label_file;
  gfloat score_threshold;
  gfloat confidence_threshold;
  gfloat iou_threshold;
  gpointer odutils;
  GstVideoInfo video_info;
};

/**
 * GstSsdObjectDetectorClass:
 *
 * @parent_class base transform base class
 *
 * Since: 1.20
 */
struct _GstSsdObjectDetectorClass
{
  GstBaseTransformClass parent_class;
};

GST_ELEMENT_REGISTER_DECLARE (ssd_object_detector)

G_END_DECLS

#endif /* __GST_SSD_OBJECT_DETECTOR_H__ */

41  subprojects/gst-plugins-bad/ext/onnx/gstml.h  (Normal file)

@@ -0,0 +1,41 @@
/*
 * GStreamer gstreamer-ml
 * Copyright (C) 2021 Collabora Ltd
 *
 * gstml.h
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
#ifndef __GST_ML_H__
#define __GST_ML_H__


/**
 * GstMlInputImageFormat:
 *
 * @GST_ML_INPUT_IMAGE_FORMAT_HWC Height Width Channel (a.k.a. interleaved) format
 * @GST_ML_INPUT_IMAGE_FORMAT_CHW Channel Height Width (a.k.a. planar) format
 *
 * Since: 1.20
 */
typedef enum {
  GST_ML_INPUT_IMAGE_FORMAT_HWC,
  GST_ML_INPUT_IMAGE_FORMAT_CHW,
} GstMlInputImageFormat;


#endif
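For reference, the two layouts only differ in where a pixel's channel samples land in the flattened input tensor. An illustrative helper showing the index arithmetic for a W x H x C image (not part of the patch):

#include <stddef.h>

/* Illustrative only: flat-buffer index of sample (row h, column w, channel c). */
static inline size_t
hwc_index (size_t h, size_t w, size_t c, size_t W, size_t C)
{
  return (h * W + w) * C + c;   /* interleaved, e.g. RGBRGB... */
}

static inline size_t
chw_index (size_t h, size_t w, size_t c, size_t W, size_t H)
{
  return c * H * W + h * W + w; /* planar, e.g. RR..GG..BB.. */
}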
@@ -1,3 +1,4 @@

/*
 * GStreamer gstreamer-onnx
 * Copyright (C) 2021 Collabora Ltd

@@ -23,14 +24,17 @@
#include "config.h"
#endif

#include "gstonnxobjectdetector.h"
#include "decoders/gstssdobjectdetector.h"
#include "gstonnxinference.h"
#include "tensor/gsttensormeta.h"

static gboolean
plugin_init (GstPlugin * plugin)
{
  GST_ELEMENT_REGISTER (onnx_object_detector, plugin);
  gboolean success = GST_ELEMENT_REGISTER (ssd_object_detector, plugin);
  success |= GST_ELEMENT_REGISTER (onnx_inference, plugin);

  return TRUE;
  return success;
}

GST_PLUGIN_DEFINE (GST_VERSION_MAJOR,
|
|
@ -21,23 +21,15 @@
|
|||
*/
|
||||
|
||||
#include "gstonnxclient.h"
|
||||
#include <tensor/gsttensorid.h>
|
||||
#include <providers/cpu/cpu_provider_factory.h>
|
||||
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
|
||||
#include <providers/cuda/cuda_provider_factory.h>
|
||||
#endif
|
||||
#include <exception>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
namespace GstOnnxNamespace
|
||||
{
|
||||
template < typename T >
|
||||
std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
|
||||
{
|
||||
template < typename T >
|
||||
std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
|
||||
{
|
||||
os << "[";
|
||||
for (size_t i = 0; i < v.size (); ++i)
|
||||
{
|
||||
|
@ -50,13 +42,7 @@ template < typename T >
|
|||
os << "]";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
|
||||
(GST_ML_NODE_INDEX_DISABLED),
|
||||
type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||
width (0),
|
||||
|
@ -64,123 +50,59 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
|||
channels (0),
|
||||
dest (nullptr),
|
||||
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
|
||||
inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
|
||||
fixedInputImageSize (true)
|
||||
{
|
||||
for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
|
||||
outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
|
||||
}
|
||||
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
||||
fixedInputImageSize (false) {
|
||||
}
|
||||
|
||||
GstOnnxClient::~GstOnnxClient ()
|
||||
{
|
||||
outputNames.clear();
|
||||
GstOnnxClient::~GstOnnxClient () {
|
||||
delete session;
|
||||
delete[]dest;
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Env & GstOnnxClient::getEnv (void)
|
||||
{
|
||||
static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
|
||||
"GstOnnxNamespace");
|
||||
|
||||
return env;
|
||||
}
|
||||
|
||||
int32_t GstOnnxClient::getWidth (void)
|
||||
{
|
||||
int32_t GstOnnxClient::getWidth (void)
|
||||
{
|
||||
return width;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t GstOnnxClient::getHeight (void)
|
||||
{
|
||||
int32_t GstOnnxClient::getHeight (void)
|
||||
{
|
||||
return height;
|
||||
}
|
||||
}
|
||||
|
||||
bool GstOnnxClient::isFixedInputImageSize (void)
|
||||
{
|
||||
bool GstOnnxClient::isFixedInputImageSize (void)
|
||||
{
|
||||
return fixedInputImageSize;
|
||||
}
|
||||
}
|
||||
|
||||
std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
|
||||
{
|
||||
switch (nodeType) {
|
||||
case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
|
||||
return "detection";
|
||||
break;
|
||||
case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
|
||||
return "bounding box";
|
||||
break;
|
||||
case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
|
||||
return "score";
|
||||
break;
|
||||
case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
|
||||
return "label";
|
||||
break;
|
||||
case GST_ML_OUTPUT_NODE_NUMBER_OF:
|
||||
g_assert_not_reached();
|
||||
GST_WARNING("Invalid parameter");
|
||||
break;
|
||||
};
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
|
||||
{
|
||||
void GstOnnxClient::setInputImageFormat (GstMlInputImageFormat format)
|
||||
{
|
||||
inputImageFormat = format;
|
||||
}
|
||||
}
|
||||
|
||||
GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
|
||||
{
|
||||
GstMlInputImageFormat GstOnnxClient::getInputImageFormat (void)
|
||||
{
|
||||
return inputImageFormat;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector< const char *> GstOnnxClient::getOutputNodeNames (void)
|
||||
{
|
||||
if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) {
|
||||
outputNamesRaw.resize(outputNames.size());
|
||||
for (size_t i = 0; i < outputNamesRaw.size(); i++) {
|
||||
outputNamesRaw[i] = outputNames[i].get();
|
||||
}
|
||||
std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void)
|
||||
{
|
||||
if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) {
|
||||
outputNamesRaw.resize (outputNames.size ());
|
||||
for (size_t i = 0; i < outputNamesRaw.size (); i++)
|
||||
outputNamesRaw[i] = outputNames[i].get ();
|
||||
}
|
||||
|
||||
return outputNamesRaw;
|
||||
}
|
||||
}
|
||||
|
||||
void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
|
||||
gint index)
|
||||
{
|
||||
g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
|
||||
outputNodeInfo[node].index = index;
|
||||
if (index != GST_ML_NODE_INDEX_DISABLED)
|
||||
outputNodeIndexToFunction[index] = node;
|
||||
}
|
||||
|
||||
gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
|
||||
{
|
||||
return outputNodeInfo[node].index;
|
||||
}
|
||||
|
||||
void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
|
||||
ONNXTensorElementDataType type)
|
||||
{
|
||||
outputNodeInfo[node].type = type;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType
|
||||
GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
|
||||
{
|
||||
return outputNodeInfo[node].type;
|
||||
}
|
||||
|
||||
bool GstOnnxClient::hasSession (void)
|
||||
{
|
||||
bool GstOnnxClient::hasSession (void)
|
||||
{
|
||||
return session != nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
bool GstOnnxClient::createSession (std::string modelFile,
|
||||
bool GstOnnxClient::createSession (std::string modelFile,
|
||||
GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
|
||||
{
|
||||
{
|
||||
if (session)
|
||||
return true;
|
||||
|
||||
|
@ -205,30 +127,43 @@ bool GstOnnxClient::createSession (std::string modelFile,
|
|||
|
||||
try {
|
||||
Ort::SessionOptions sessionOptions;
|
||||
const auto & api = Ort::GetApi ();
|
||||
// for debugging
|
||||
//sessionOptions.SetIntraOpNumThreads (1);
|
||||
sessionOptions.SetGraphOptimizationLevel (onnx_optim);
|
||||
m_provider = provider;
|
||||
switch (m_provider) {
|
||||
case GST_ONNX_EXECUTION_PROVIDER_CUDA:
|
||||
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
|
||||
(sessionOptions, 0));
|
||||
#else
|
||||
GST_ERROR ("ONNX CUDA execution provider not supported");
|
||||
return false;
|
||||
#endif
|
||||
try {
|
||||
OrtCUDAProviderOptionsV2 *cuda_options = nullptr;
|
||||
Ort::ThrowOnError (api.CreateCUDAProviderOptions (&cuda_options));
|
||||
std::unique_ptr < OrtCUDAProviderOptionsV2,
|
||||
decltype (api.ReleaseCUDAProviderOptions) >
|
||||
rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions);
|
||||
Ort::ThrowOnError (api.SessionOptionsAppendExecutionProvider_CUDA_V2
|
||||
(static_cast < OrtSessionOptions * >(sessionOptions),
|
||||
rel_cuda_options.get ()));
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_WARNING
|
||||
("Failed to create CUDA provider - dropping back to CPU");
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
(sessionOptions, 1));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU
|
||||
(sessionOptions, 1));
|
||||
break;
|
||||
|
||||
};
|
||||
session =
|
||||
new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
|
||||
env =
|
||||
Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
|
||||
"GstOnnxNamespace");
|
||||
session = new Ort::Session (env, modelFile.c_str (), sessionOptions);
|
||||
auto inputTypeInfo = session->GetInputTypeInfo (0);
|
||||
std::vector < int64_t > inputDims =
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
|
||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||
height = inputDims[1];
|
||||
width = inputDims[2];
|
||||
channels = inputDims[3];
|
||||
|
@ -250,14 +185,37 @@ bool GstOnnxClient::createSession (std::string modelFile,
|
|||
auto output_name = session->GetOutputNameAllocated (i, allocator);
|
||||
GST_DEBUG ("Output name %lu:%s", i, output_name.get ());
|
||||
outputNames.push_back (std::move (output_name));
|
||||
auto type_info = session->GetOutputTypeInfo (i);
|
||||
auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
|
||||
}
|
||||
genOutputNamesRaw ();
|
||||
|
||||
if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
|
||||
auto function = outputNodeIndexToFunction[i];
|
||||
outputNodeInfo[function].type = tensor_info.GetElementType ();
|
||||
// look up tensor ids
|
||||
auto metaData = session->GetModelMetadata ();
|
||||
OrtAllocator *ortAllocator;
|
||||
auto status =
|
||||
Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator);
|
||||
if (status) {
|
||||
// Handle the error case
|
||||
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
|
||||
GST_WARNING ("Failed to get allocator: %s", errorString);
|
||||
|
||||
// Clean up the error status
|
||||
Ort::GetApi ().ReleaseStatus (status);
|
||||
|
||||
return false;
|
||||
} else {
|
||||
for (auto & name:outputNamesRaw) {
|
||||
Ort::AllocatedStringPtr res =
|
||||
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
|
||||
if (res) {
|
||||
GQuark quark = gst_tensorid_get_quark (res.get ());
|
||||
outputIds.push_back (quark);
|
||||
} else {
|
||||
GST_ERROR ("Failed to look up id for key %s", name);
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_ERROR ("%s", ortex.what ());
|
||||
|
@ -265,40 +223,110 @@ bool GstOnnxClient::createSession (std::string modelFile,
|
|||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
|
||||
GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
|
||||
{
|
||||
auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
|
||||
return (type ==
|
||||
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
|
||||
doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
|
||||
: doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
|
||||
}
|
||||
|
||||
void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
|
||||
{
|
||||
int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
|
||||
int32_t newHeight = fixedInputImageSize ? height : vmeta->height;
|
||||
void GstOnnxClient::parseDimensions (GstVideoInfo vinfo)
|
||||
{
|
||||
int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
|
||||
int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
|
||||
|
||||
if (!dest || width * height < newWidth * newHeight) {
|
||||
delete[] dest;
|
||||
delete[]dest;
|
||||
dest = new uint8_t[newWidth * newHeight * channels];
|
||||
}
|
||||
width = newWidth;
|
||||
height = newHeight;
|
||||
}
|
||||
}
|
||||
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
|
||||
std::string labelPath, float scoreThreshold)
|
||||
{
|
||||
std::vector < GstMlBoundingBox > boundingBoxes;
|
||||
// copy tensor data to a GstTensorMeta
|
||||
GstTensorMeta *GstOnnxClient::copy_tensors_to_meta (std::vector < Ort::Value >
|
||||
&outputs, GstBuffer * buffer)
|
||||
{
|
||||
size_t num_tensors = outputNamesRaw.size ();
|
||||
GstTensorMeta *tmeta = (GstTensorMeta *) gst_buffer_add_meta (buffer,
|
||||
gst_tensor_meta_get_info (),
|
||||
NULL);
|
||||
tmeta->num_tensors = num_tensors;
|
||||
tmeta->tensor = (GstTensor *) g_malloc (num_tensors * sizeof (GstTensor));
|
||||
bool hasIds = outputIds.size () == num_tensors;
|
||||
for (size_t i = 0; i < num_tensors; i++) {
|
||||
Ort::Value outputTensor = std::move (outputs[i]);
|
||||
|
||||
ONNXTensorElementDataType tensorType =
|
||||
outputTensor.GetTensorTypeAndShapeInfo ().GetElementType ();
|
||||
|
||||
GstTensor *tensor = &tmeta->tensor[i];
|
||||
if (hasIds)
|
||||
tensor->id = outputIds[i];
|
||||
tensor->data = gst_buffer_new ();
|
||||
auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
tensor->num_dims = tensorShape.size ();
|
||||
tensor->dims = g_new (int64_t, tensor->num_dims);
|
||||
|
||||
for (size_t j = 0; j < tensorShape.size (); ++j) {
|
||||
tensor->dims[j] = tensorShape[j];
|
||||
}
|
||||
|
||||
size_t numElements =
|
||||
outputTensor.GetTensorTypeAndShapeInfo ().GetElementCount ();
|
||||
|
||||
size_t buffer_size = 0;
|
||||
guint8 *buffer_data = NULL;
|
||||
if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
|
||||
buffer_size = numElements * sizeof (float);
|
||||
|
||||
// Allocate memory for the buffer data
|
||||
buffer_data = (guint8 *) malloc (buffer_size);
|
||||
if (buffer_data == NULL) {
|
||||
GST_ERROR ("Failed to allocate memory");
|
||||
return NULL;
|
||||
}
|
||||
// Copy the data from the source buffer to the allocated memory
|
||||
memcpy (buffer_data, outputTensor.GetTensorData < float >(),
|
||||
buffer_size);
|
||||
tensor->type = GST_TENSOR_TYPE_FLOAT32;
|
||||
} else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
|
||||
buffer_size = numElements * sizeof (int);
|
||||
|
||||
// Allocate memory for the buffer data
|
||||
guint8 *buffer_data = (guint8 *) malloc (buffer_size);
|
||||
if (buffer_data == NULL) {
|
||||
GST_ERROR ("Failed to allocate memory");
|
||||
return NULL;
|
||||
}
|
||||
// Copy the data from the source buffer to the allocated memory
|
||||
memcpy (buffer_data, outputTensor.GetTensorData < int >(),
|
||||
buffer_size);
|
||||
tensor->type = GST_TENSOR_TYPE_INT32;
|
||||
}
|
||||
if (buffer_data) {
|
||||
|
||||
// Create a GstMemory object from the allocated memory
|
||||
GstMemory *memory = gst_memory_new_wrapped ((GstMemoryFlags) 0,
|
||||
buffer_data, buffer_size, 0, buffer_size, NULL, NULL);
|
||||
|
||||
// Append the GstMemory object to the GstBuffer
|
||||
gst_buffer_append_memory (tmeta->tensor[i].data, memory);
|
||||
}
|
||||
}
|
||||
|
||||
return tmeta;
|
||||
|
||||
}
|
||||
|
||||
std::vector < Ort::Value > GstOnnxClient::run (uint8_t * img_data,
|
||||
GstVideoInfo vinfo) {
|
||||
std::vector < Ort::Value > modelOutput;
|
||||
doRun (img_data, vinfo, modelOutput);
|
||||
|
||||
return modelOutput;
|
||||
}
|
||||
|
||||
bool GstOnnxClient::doRun (uint8_t * img_data, GstVideoInfo vinfo,
|
||||
std::vector < Ort::Value > &modelOutput)
|
||||
{
|
||||
if (!img_data)
|
||||
return boundingBoxes;
|
||||
|
||||
parseDimensions (vmeta);
|
||||
return false;
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto inputName = session->GetInputNameAllocated (0, allocator);
|
||||
|
@ -306,7 +334,7 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
std::vector < int64_t > inputDims =
|
||||
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
|
||||
inputDims[0] = 1;
|
||||
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
|
||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||
inputDims[1] = height;
|
||||
inputDims[2] = width;
|
||||
} else {
|
||||
|
@ -321,7 +349,7 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
// copy video frame
|
||||
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
|
||||
uint32_t srcSamplesPerPixel = 3;
|
||||
switch (vmeta->format) {
|
||||
switch (vinfo.finfo->format) {
|
||||
case GST_VIDEO_FORMAT_RGBA:
|
||||
srcSamplesPerPixel = 4;
|
||||
break;
|
||||
|
@ -352,8 +380,8 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
break;
|
||||
}
|
||||
size_t destIndex = 0;
|
||||
uint32_t stride = vmeta->stride[0];
|
||||
if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
|
||||
uint32_t stride = vinfo.stride[0];
|
||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||
for (int32_t j = 0; j < height; ++j) {
|
||||
for (int32_t i = 0; i < width; ++i) {
|
||||
for (int32_t k = 0; k < channels; ++k) {
|
||||
|
@ -389,58 +417,17 @@ template < typename T > std::vector < GstMlBoundingBox >
|
|||
std::vector < Ort::Value > inputTensors;
|
||||
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
|
||||
dest, inputTensorSize, inputDims.data (), inputDims.size ()));
|
||||
std::vector < const char *>inputNames { inputName.get () };
|
||||
std::vector < const char *>inputNames
|
||||
{
|
||||
inputName.get ()};
|
||||
|
||||
std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
|
||||
modelOutput = session->Run (Ort::RunOptions {
|
||||
nullptr},
|
||||
inputNames.data (),
|
||||
inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ());
|
||||
inputTensors.data (), 1, outputNamesRaw.data (),
|
||||
outputNamesRaw.size ());
|
||||
|
||||
auto numDetections =
|
||||
modelOutput[getOutputNodeIndex
|
||||
(GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
|
||||
auto bboxes =
|
||||
modelOutput[getOutputNodeIndex
|
||||
(GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
|
||||
auto scores =
|
||||
modelOutput[getOutputNodeIndex
|
||||
(GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
|
||||
T *labelIndex = nullptr;
|
||||
if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
|
||||
GST_ML_NODE_INDEX_DISABLED) {
|
||||
labelIndex =
|
||||
modelOutput[getOutputNodeIndex
|
||||
(GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
|
||||
}
|
||||
if (labels.empty () && !labelPath.empty ())
|
||||
labels = ReadLabels (labelPath);
|
||||
|
||||
for (int i = 0; i < numDetections[0]; ++i) {
|
||||
if (scores[i] > scoreThreshold) {
|
||||
std::string label = "";
|
||||
|
||||
if (labelIndex && !labels.empty ())
|
||||
label = labels[labelIndex[i] - 1];
|
||||
auto score = scores[i];
|
||||
auto y0 = bboxes[i * 4] * height;
|
||||
auto x0 = bboxes[i * 4 + 1] * width;
|
||||
auto bheight = bboxes[i * 4 + 2] * height - y0;
|
||||
auto bwidth = bboxes[i * 4 + 3] * width - x0;
|
||||
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
|
||||
bheight));
|
||||
}
|
||||
}
|
||||
return boundingBoxes;
|
||||
}
|
||||
|
||||
std::vector < std::string >
|
||||
GstOnnxClient::ReadLabels (const std::string & labelsFile)
|
||||
{
|
||||
std::vector < std::string > labels;
|
||||
std::string line;
|
||||
std::ifstream fp (labelsFile);
|
||||
while (std::getline (fp, line))
|
||||
labels.push_back (line);
|
||||
|
||||
return labels;
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -25,45 +25,11 @@
|
|||
#include <gst/gst.h>
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
#include <gst/video/video.h>
|
||||
#include "gstonnxelement.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "gstml.h"
|
||||
#include "gstonnxenums.h"
|
||||
#include "tensor/gsttensormeta.h"
|
||||
|
||||
namespace GstOnnxNamespace {
|
||||
enum GstMlOutputNodeFunction {
|
||||
GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
|
||||
GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
|
||||
GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
|
||||
GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
|
||||
GST_ML_OUTPUT_NODE_NUMBER_OF,
|
||||
};
|
||||
|
||||
const gint GST_ML_NODE_INDEX_DISABLED = -1;
|
||||
|
||||
struct GstMlOutputNodeInfo {
|
||||
GstMlOutputNodeInfo(void);
|
||||
gint index;
|
||||
ONNXTensorElementDataType type;
|
||||
};
|
||||
|
||||
struct GstMlBoundingBox {
|
||||
GstMlBoundingBox(std::string lbl,
|
||||
float score,
|
||||
float _x0,
|
||||
float _y0,
|
||||
float _width,
|
||||
float _height):label(lbl),
|
||||
score(score), x0(_x0), y0(_y0), width(_width), height(_height) {
|
||||
}
|
||||
GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) {
|
||||
}
|
||||
std::string label;
|
||||
float score;
|
||||
float x0;
|
||||
float y0;
|
||||
float width;
|
||||
float height;
|
||||
};
|
||||
|
||||
class GstOnnxClient {
|
||||
public:
|
||||
|
@ -72,30 +38,18 @@ namespace GstOnnxNamespace {
|
|||
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
|
||||
GstOnnxExecutionProvider provider);
|
||||
bool hasSession(void);
|
||||
void setInputImageFormat(GstMlModelInputImageFormat format);
|
||||
GstMlModelInputImageFormat getInputImageFormat(void);
|
||||
void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index);
|
||||
gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType);
|
||||
void setOutputNodeType(GstMlOutputNodeFunction nodeType,
|
||||
ONNXTensorElementDataType type);
|
||||
ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type);
|
||||
std::string getOutputNodeName(GstMlOutputNodeFunction nodeType);
|
||||
std::vector < GstMlBoundingBox > run(uint8_t * img_data,
|
||||
GstVideoMeta * vmeta,
|
||||
std::string labelPath,
|
||||
float scoreThreshold);
|
||||
std::vector < GstMlBoundingBox > &getBoundingBoxes(void);
|
||||
std::vector < const char *>getOutputNodeNames(void);
|
||||
void setInputImageFormat(GstMlInputImageFormat format);
|
||||
GstMlInputImageFormat getInputImageFormat(void);
|
||||
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
|
||||
std::vector < const char *> genOutputNamesRaw(void);
|
||||
bool isFixedInputImageSize(void);
|
||||
int32_t getWidth(void);
|
||||
int32_t getHeight(void);
|
||||
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
|
||||
void parseDimensions(GstVideoInfo vinfo);
|
||||
private:
|
||||
void parseDimensions(GstVideoMeta * vmeta);
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath,
|
||||
float scoreThreshold);
|
||||
std::vector < std::string > ReadLabels(const std::string & labelsFile);
|
||||
Ort::Env & getEnv(void);
|
||||
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
||||
Ort::Env env;
|
||||
Ort::Session * session;
|
||||
int32_t width;
|
||||
int32_t height;
|
||||
|
@ -104,13 +58,10 @@ namespace GstOnnxNamespace {
|
|||
GstOnnxExecutionProvider m_provider;
|
||||
std::vector < Ort::Value > modelOutput;
|
||||
std::vector < std::string > labels;
|
||||
// !! indexed by function
|
||||
GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF];
|
||||
// !! indexed by array index
|
||||
size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF];
|
||||
std::vector < const char *> outputNamesRaw;
|
||||
std::vector < Ort::AllocatedStringPtr > outputNames;
|
||||
GstMlModelInputImageFormat inputImageFormat;
|
||||
std::vector < GQuark > outputIds;
|
||||
GstMlInputImageFormat inputImageFormat;
|
||||
bool fixedInputImageSize;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxelement
|
||||
* Copyright (C) 2021 Collabora Ltd
|
||||
*
|
||||
* gstonnxelement.c
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
#ifdef HAVE_CONFIG_H
|
||||
#include "config.h"
|
||||
#endif
|
||||
|
||||
#include "gstonnxelement.h"
|
||||
|
||||
GType
|
||||
gst_onnx_optimization_level_get_type (void)
|
||||
{
|
||||
static GType onnx_optimization_type = 0;
|
||||
|
||||
if (g_once_init_enter (&onnx_optimization_type)) {
|
||||
static GEnumValue optimization_level_types[] = {
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization",
|
||||
"disable-all"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
|
||||
"Enable basic optimizations (redundant node removals))",
|
||||
"enable-basic"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
|
||||
"Enable extended optimizations (redundant node removals + node fusions)",
|
||||
"enable-extended"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
|
||||
"Enable all possible optimizations", "enable-all"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstOnnxOptimizationLevel",
|
||||
optimization_level_types);
|
||||
|
||||
g_once_init_leave (&onnx_optimization_type, temp);
|
||||
}
|
||||
|
||||
return onnx_optimization_type;
|
||||
}
|
||||
|
||||
GType
|
||||
gst_onnx_execution_provider_get_type (void)
|
||||
{
|
||||
static GType onnx_execution_type = 0;
|
||||
|
||||
if (g_once_init_enter (&onnx_execution_type)) {
|
||||
static GEnumValue execution_provider_types[] = {
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
|
||||
"cpu"},
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
"CUDA execution provider",
|
||||
"cuda"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstOnnxExecutionProvider",
|
||||
execution_provider_types);
|
||||
|
||||
g_once_init_leave (&onnx_execution_type, temp);
|
||||
}
|
||||
|
||||
return onnx_execution_type;
|
||||
}
|
||||
|
||||
GType
|
||||
gst_ml_model_input_image_format_get_type (void)
|
||||
{
|
||||
static GType ml_model_input_image_format = 0;
|
||||
|
||||
if (g_once_init_enter (&ml_model_input_image_format)) {
|
||||
static GEnumValue ml_model_input_image_format_types[] = {
|
||||
{GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
|
||||
"Height Width Channel (HWC) a.k.a. interleaved image data format",
|
||||
"hwc"},
|
||||
{GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
|
||||
"Channel Height Width (CHW) a.k.a. planar image data format",
|
||||
"chw"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstMlModelInputImageFormat",
|
||||
ml_model_input_image_format_types);
|
||||
|
||||
g_once_init_leave (&ml_model_input_image_format, temp);
|
||||
}
|
||||
|
||||
return ml_model_input_image_format;
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxelement
|
||||
* GStreamer gstreamer-onnxenums
|
||||
* Copyright (C) 2021 Collabora Ltd
|
||||
*
|
||||
* gstonnxelement.h
|
||||
* gstonnxenums.h
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
|
@ -20,10 +20,8 @@
|
|||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
#ifndef __GST_ONNX_ELEMENT_H__
|
||||
#define __GST_ONNX_ELEMENT_H__
|
||||
|
||||
#include <gst/gst.h>
|
||||
#ifndef __GST_ONNX_ENUMS_H__
|
||||
#define __GST_ONNX_ENUMS_H__
|
||||
|
||||
typedef enum
|
||||
{
|
||||
|
@ -39,26 +37,5 @@ typedef enum
|
|||
GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
} GstOnnxExecutionProvider;
|
||||
|
||||
typedef enum {
|
||||
/* Height Width Channel (a.k.a. interleaved) format */
|
||||
GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC,
|
||||
|
||||
/* Channel Height Width (a.k.a. planar) format */
|
||||
GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW,
|
||||
} GstMlModelInputImageFormat;
|
||||
|
||||
|
||||
G_BEGIN_DECLS
|
||||
|
||||
GType gst_onnx_optimization_level_get_type (void);
|
||||
#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ())
|
||||
|
||||
GType gst_onnx_execution_provider_get_type (void);
|
||||
#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ())
|
||||
|
||||
GType gst_ml_model_input_image_format_get_type (void);
|
||||
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
||||
|
||||
G_END_DECLS
|
||||
|
||||
#endif
|
||||
#endif /* __GST_ONNX_ENUMS_H__ */
|
539
subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp
Normal file
|
@ -0,0 +1,539 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxinference
|
||||
* Copyright (C) 2023 Collabora Ltd.
|
||||
*
|
||||
* gstonnxinference.cpp
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
/**
|
||||
* SECTION:element-onnxinference
|
||||
* @short_description: Run ONNX inference model on video buffers
|
||||
*
|
||||
* This element can apply an ONNX model to video buffers. It attaches
|
||||
* the tensor output to the buffer as a @ref GstTensorMeta.
|
||||
*
|
||||
* To install ONNX Runtime on your system, recursively clone this repository
|
||||
* https://github.com/microsoft/onnxruntime.git
|
||||
*
|
||||
* and build and install with cmake:
|
||||
*
|
||||
* CPU:
|
||||
*
|
||||
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
|
||||
* $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install
|
||||
*
|
||||
*
|
||||
* CUDA:
|
||||
*
|
||||
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
|
||||
* -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
|
||||
* $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install
|
||||
*
|
||||
*
|
||||
* where :
|
||||
*
|
||||
* 1. $SRC_DIR and $BUILD_DIR are local source and build directories
|
||||
* 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
|
||||
* $CUDA_PATH is an environment variable set to the CUDA root path.
|
||||
* On Linux, it would be /usr/local/cuda
|
||||
*
|
||||
*
|
||||
* ## Example launch command:
|
||||
*
|
||||
* GST_DEBUG=onnxinference:5 gst-launch-1.0 multifilesrc location=bus.jpg ! \
|
||||
* jpegdec ! videoconvert ! \
|
||||
* onnxinference execution-provider=cpu model-file=model.onnx ! \
|
||||
* videoconvert ! autovideosink
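*
* The same pipeline can also be built programmatically. A minimal C sketch,
* using the same placeholder file names as above and omitting error handling:
*
* ```
* GError *error = NULL;
* GstElement *pipeline = gst_parse_launch (
*     "multifilesrc location=bus.jpg ! jpegdec ! videoconvert ! "
*     "onnxinference execution-provider=cpu model-file=model.onnx ! "
*     "videoconvert ! autovideosink", &error);
*
* if (pipeline)
*   gst_element_set_state (pipeline, GST_STATE_PLAYING);
* ```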
|
||||
*
|
||||
*
|
||||
* Note: in order for downstream tensor decoders to correctly parse the tensor
|
||||
* data in the GstTensorMeta, meta data must be attached to the ONNX model
|
||||
* assigning a unique string id to each output layer. These unique string ids
|
||||
* and corresponding GQuark ids are currently stored in the ONNX plugin source
|
||||
* in the file 'gsttensorid.h'. For an output layer with name Foo and a corresponding
|
||||
* unique string id Gst.Model.Bar, a meta data key/value pair must be added
|
||||
* to the ONNX model with "Foo" mapped to "Gst.Model.Bar" in order for a downstream
|
||||
* decoder to make use of this model. If the meta data is absent, the pipeline will
|
||||
* fail.
|
||||
*
|
||||
* As a convenience, there is a python script
|
||||
* currently stored at
|
||||
* https://gitlab.collabora.com/gstreamer/onnx-models/-/blob/master/scripts/modify_onnx_metadata.py
|
||||
* to enable users to easily add meta data to, and remove it from, ONNX models using JSON files. It can also dump
|
||||
* the names of all output layers, which can then be used to craft the json meta data file.
|
||||
*
|
||||
*
|
||||
*/
|
||||
#ifdef HAVE_CONFIG_H
|
||||
#include "config.h"
|
||||
#endif
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include "gstonnxinference.h"
|
||||
#include "gstonnxclient.h"
|
||||
|
||||
GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug);
|
||||
#define GST_CAT_DEFAULT onnx_inference_debug
|
||||
#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
|
||||
GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
|
||||
GST_RANK_PRIMARY, GST_TYPE_ONNX_INFERENCE);
|
||||
|
||||
/* GstOnnxInference properties */
|
||||
enum
|
||||
{
|
||||
PROP_0,
|
||||
PROP_MODEL_FILE,
|
||||
PROP_INPUT_IMAGE_FORMAT,
|
||||
PROP_OPTIMIZATION_LEVEL,
|
||||
PROP_EXECUTION_PROVIDER
|
||||
};
|
||||
|
||||
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
||||
#define GST_ONNX_INFERENCE_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
|
||||
|
||||
static GstStaticPadTemplate gst_onnx_inference_src_template =
|
||||
GST_STATIC_PAD_TEMPLATE ("src",
|
||||
GST_PAD_SRC,
|
||||
GST_PAD_ALWAYS,
|
||||
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
|
||||
);
|
||||
|
||||
static GstStaticPadTemplate gst_onnx_inference_sink_template =
|
||||
GST_STATIC_PAD_TEMPLATE ("sink",
|
||||
GST_PAD_SINK,
|
||||
GST_PAD_ALWAYS,
|
||||
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
|
||||
);
|
||||
|
||||
static void gst_onnx_inference_set_property (GObject * object,
|
||||
guint prop_id, const GValue * value, GParamSpec * pspec);
|
||||
static void gst_onnx_inference_get_property (GObject * object,
|
||||
guint prop_id, GValue * value, GParamSpec * pspec);
|
||||
static void gst_onnx_inference_finalize (GObject * object);
|
||||
static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform *
|
||||
trans, GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_process (GstBaseTransform * trans,
|
||||
GstBuffer * buf);
|
||||
static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans);
|
||||
static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps);
|
||||
|
||||
G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM);
|
||||
|
||||
GType gst_onnx_optimization_level_get_type (void);
|
||||
#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ())
|
||||
|
||||
GType gst_onnx_execution_provider_get_type (void);
|
||||
#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ())
|
||||
|
||||
GType gst_ml_model_input_image_format_get_type (void);
|
||||
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
||||
|
||||
GType
|
||||
gst_onnx_optimization_level_get_type (void)
|
||||
{
|
||||
static GType onnx_optimization_type = 0;
|
||||
|
||||
if (g_once_init_enter (&onnx_optimization_type)) {
|
||||
static GEnumValue optimization_level_types[] = {
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization",
|
||||
"disable-all"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC,
"Enable basic optimizations (redundant node removals)",
|
||||
"enable-basic"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED,
|
||||
"Enable extended optimizations (redundant node removals + node fusions)",
|
||||
"enable-extended"},
|
||||
{GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL,
|
||||
"Enable all possible optimizations", "enable-all"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstOnnxOptimizationLevel",
|
||||
optimization_level_types);
|
||||
|
||||
g_once_init_leave (&onnx_optimization_type, temp);
|
||||
}
|
||||
|
||||
return onnx_optimization_type;
|
||||
}
|
||||
|
||||
GType
|
||||
gst_onnx_execution_provider_get_type (void)
|
||||
{
|
||||
static GType onnx_execution_type = 0;
|
||||
|
||||
if (g_once_init_enter (&onnx_execution_type)) {
|
||||
static GEnumValue execution_provider_types[] = {
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider",
|
||||
"cpu"},
|
||||
{GST_ONNX_EXECUTION_PROVIDER_CUDA,
|
||||
"CUDA execution provider",
|
||||
"cuda"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstOnnxExecutionProvider",
|
||||
execution_provider_types);
|
||||
|
||||
g_once_init_leave (&onnx_execution_type, temp);
|
||||
}
|
||||
|
||||
return onnx_execution_type;
|
||||
}
|
||||
|
||||
GType
|
||||
gst_ml_model_input_image_format_get_type (void)
|
||||
{
|
||||
static GType ml_model_input_image_format = 0;
|
||||
|
||||
if (g_once_init_enter (&ml_model_input_image_format)) {
|
||||
static GEnumValue ml_model_input_image_format_types[] = {
|
||||
{GST_ML_INPUT_IMAGE_FORMAT_HWC,
|
||||
"Height Width Channel (HWC) a.k.a. interleaved image data format",
|
||||
"hwc"},
|
||||
{GST_ML_INPUT_IMAGE_FORMAT_CHW,
|
||||
"Channel Height Width (CHW) a.k.a. planar image data format",
|
||||
"chw"},
|
||||
{0, NULL, NULL},
|
||||
};
|
||||
|
||||
GType temp = g_enum_register_static ("GstMlInputImageFormat",
|
||||
ml_model_input_image_format_types);
|
||||
|
||||
g_once_init_leave (&ml_model_input_image_format, temp);
|
||||
}
|
||||
|
||||
return ml_model_input_image_format;
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||
{
|
||||
GObjectClass *gobject_class = (GObjectClass *) klass;
|
||||
GstElementClass *element_class = (GstElementClass *) klass;
|
||||
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
|
||||
|
||||
GST_DEBUG_CATEGORY_INIT (onnx_inference_debug, "onnxinference",
|
||||
0, "onnx_inference");
|
||||
gobject_class->set_property = gst_onnx_inference_set_property;
|
||||
gobject_class->get_property = gst_onnx_inference_get_property;
|
||||
gobject_class->finalize = gst_onnx_inference_finalize;
|
||||
|
||||
/**
|
||||
* GstOnnxInference:model-file
|
||||
*
|
||||
* ONNX model file
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
|
||||
g_param_spec_string ("model-file",
|
||||
"ONNX model file", "ONNX model file", NULL, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxInference:input-image-format
|
||||
*
|
||||
* Model input image format
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_INPUT_IMAGE_FORMAT,
|
||||
g_param_spec_enum ("input-image-format",
|
||||
"Input image format",
|
||||
"Input image format",
|
||||
GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
|
||||
GST_ML_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxInference:optimization-level
|
||||
*
|
||||
* ONNX optimization level
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_OPTIMIZATION_LEVEL,
|
||||
g_param_spec_enum ("optimization-level",
|
||||
"Optimization level",
|
||||
"ONNX optimization level",
|
||||
GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
|
||||
GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxInference:execution-provider
|
||||
*
|
||||
* ONNX execution provider
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_EXECUTION_PROVIDER,
|
||||
g_param_spec_enum ("execution-provider",
|
||||
"Execution provider",
|
||||
"ONNX execution provider",
|
||||
GST_TYPE_ONNX_EXECUTION_PROVIDER,
|
||||
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
||||
"Filter/Effect/Video",
|
||||
"Apply neural network to video frames and create tensor output",
|
||||
"Aaron Boxer <aaron.boxer@collabora.com>");
|
||||
gst_element_class_add_pad_template (element_class,
|
||||
gst_static_pad_template_get (&gst_onnx_inference_sink_template));
|
||||
gst_element_class_add_pad_template (element_class,
|
||||
gst_static_pad_template_get (&gst_onnx_inference_src_template));
|
||||
basetransform_class->transform_ip =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_ip);
|
||||
basetransform_class->transform_caps =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps);
|
||||
basetransform_class->set_caps =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps);
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_init (GstOnnxInference * self)
|
||||
{
|
||||
self->onnx_client = new GstOnnxNamespace::GstOnnxClient ();
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_finalize (GObject * object)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
|
||||
|
||||
g_free (self->model_file);
|
||||
delete GST_ONNX_CLIENT_MEMBER (self);
|
||||
G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object);
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||
const GValue * value, GParamSpec * pspec)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
|
||||
const gchar *filename;
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
switch (prop_id) {
|
||||
case PROP_MODEL_FILE:
|
||||
filename = g_value_get_string (value);
|
||||
if (filename
|
||||
&& g_file_test (filename,
|
||||
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
|
||||
if (self->model_file)
|
||||
g_free (self->model_file);
|
||||
self->model_file = g_strdup (filename);
|
||||
self->onnx_disabled = FALSE;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
|
||||
}
|
||||
break;
|
||||
case PROP_OPTIMIZATION_LEVEL:
|
||||
self->optimization_level =
|
||||
(GstOnnxOptimizationLevel) g_value_get_enum (value);
|
||||
break;
|
||||
case PROP_EXECUTION_PROVIDER:
|
||||
self->execution_provider =
|
||||
(GstOnnxExecutionProvider) g_value_get_enum (value);
|
||||
break;
|
||||
case PROP_INPUT_IMAGE_FORMAT:
|
||||
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
|
||||
g_value_get_enum (value));
|
||||
break;
|
||||
default:
|
||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||
GValue * value, GParamSpec * pspec)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (object);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
switch (prop_id) {
|
||||
case PROP_MODEL_FILE:
|
||||
g_value_set_string (value, self->model_file);
|
||||
break;
|
||||
case PROP_OPTIMIZATION_LEVEL:
|
||||
g_value_set_enum (value, self->optimization_level);
|
||||
break;
|
||||
case PROP_EXECUTION_PROVIDER:
|
||||
g_value_set_enum (value, self->execution_provider);
|
||||
break;
|
||||
case PROP_INPUT_IMAGE_FORMAT:
|
||||
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
||||
break;
|
||||
default:
|
||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_create_session (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
if (onnxClient->hasSession ()) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
if (self->model_file) {
|
||||
gboolean ret =
|
||||
GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider);
|
||||
if (!ret) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Model is disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
} else {
|
||||
self->onnx_disabled = TRUE;
|
||||
GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
|
||||
}
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
if (self->onnx_disabled) {
|
||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static GstCaps *
|
||||
gst_onnx_inference_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
GstCaps *other_caps;
|
||||
guint i;
|
||||
|
||||
if (!gst_onnx_inference_create_session (trans))
|
||||
return NULL;
|
||||
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
|
||||
|
||||
if (gst_base_transform_is_passthrough (trans)
|
||||
|| (!onnxClient->isFixedInputImageSize ()))
|
||||
return gst_caps_ref (caps);
|
||||
|
||||
other_caps = gst_caps_new_empty ();
|
||||
for (i = 0; i < gst_caps_get_size (caps); ++i) {
|
||||
GstStructure *structure, *new_structure;
|
||||
|
||||
structure = gst_caps_get_structure (caps, i);
|
||||
new_structure = gst_structure_copy (structure);
|
||||
gst_structure_set (new_structure, "width", G_TYPE_INT,
|
||||
onnxClient->getWidth (), "height", G_TYPE_INT,
|
||||
onnxClient->getHeight (), NULL);
|
||||
GST_LOG_OBJECT (self,
|
||||
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
|
||||
GST_PTR_FORMAT, i, structure, new_structure);
|
||||
gst_caps_append_structure (other_caps, new_structure);
|
||||
}
|
||||
|
||||
if (!gst_caps_is_empty (other_caps) && filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_replace (&other_caps, tmp);
|
||||
gst_caps_unref (tmp);
|
||||
}
|
||||
|
||||
return other_caps;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps)
|
||||
{
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
|
||||
if (!gst_video_info_from_caps (&self->video_info, incaps)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to parse caps");
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
onnxClient->parseDimensions (self->video_info);
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static GstFlowReturn
|
||||
gst_onnx_inference_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
|
||||
{
|
||||
if (!gst_base_transform_is_passthrough (trans)
|
||||
&& !gst_onnx_inference_process (trans, buf)) {
|
||||
GST_ELEMENT_ERROR (trans, STREAM, FAILED,
|
||||
(NULL), ("ONNX inference failed"));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
return GST_FLOW_OK;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||
{
|
||||
GstMapInfo info;
|
||||
GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
|
||||
|
||||
if (!vmeta) {
|
||||
GST_WARNING_OBJECT (trans, "missing video meta");
|
||||
return FALSE;
|
||||
}
|
||||
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
|
||||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
try {
|
||||
auto client = GST_ONNX_CLIENT_MEMBER (self);
|
||||
auto outputs = client->run (info.data, self->video_info);
|
||||
auto meta = client->copy_tensors_to_meta (outputs, buf);
|
||||
      if (!meta) {
        gst_buffer_unmap (buf, &info);
        return FALSE;
      }
|
||||
meta->batch_size = 1;
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
      GST_ERROR_OBJECT (self, "%s", ortex.what ());
      gst_buffer_unmap (buf, &info);
      return FALSE;
|
||||
}
|
||||
|
||||
gst_buffer_unmap (buf, &info);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
64
subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.h
Normal file
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxinference
|
||||
* Copyright (C) 2023 Collabora Ltd
|
||||
*
|
||||
* gstonnxinference.h
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
#ifndef __GST_ONNX_INFERENCE_H__
|
||||
#define __GST_ONNX_INFERENCE_H__
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include <gst/video/video.h>
|
||||
#include <gst/video/gstvideofilter.h>
|
||||
#include "gstonnxenums.h"
|
||||
|
||||
G_BEGIN_DECLS
|
||||
|
||||
#define GST_TYPE_ONNX_INFERENCE (gst_onnx_inference_get_type())
|
||||
G_DECLARE_FINAL_TYPE (GstOnnxInference, gst_onnx_inference, GST,
|
||||
ONNX_INFERENCE, GstBaseTransform)
|
||||
|
||||
/**
|
||||
* GstOnnxInference:
|
||||
*
|
||||
* @model_file: model file
* @optimization_level: ONNX session optimization level
* @execution_provider: ONNX execution provider
* @onnx_client: opaque pointer to ONNX client
* @onnx_disabled: true if inference is disabled
* @video_info: @ref GstVideoInfo of sink caps
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
struct _GstOnnxInference
|
||||
{
|
||||
GstBaseTransform basetransform;
|
||||
gchar *model_file;
|
||||
GstOnnxOptimizationLevel optimization_level;
|
||||
GstOnnxExecutionProvider execution_provider;
|
||||
gpointer onnx_client;
|
||||
gboolean onnx_disabled;
|
||||
GstVideoInfo video_info;
|
||||
};
|
||||
|
||||
GST_ELEMENT_REGISTER_DECLARE (onnx_inference)
|
||||
|
||||
G_END_DECLS
|
||||
|
||||
#endif /* __GST_ONNX_INFERENCE_H__ */
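For illustration, a hedged sketch of creating and configuring the element from application code using the `model-file` property registered in gstonnxinference.cpp; the "model.onnx" path is a placeholder and the plugin must be present in the GStreamer registry.

```
#include <gst/gst.h>

/* Sketch: instantiate onnxinference and point it at a model file.
 * "model.onnx" is a placeholder path; returns NULL if the plugin is missing. */
static GstElement *
make_onnx_inference (void)
{
  GstElement *infer = gst_element_factory_make ("onnxinference", NULL);

  if (infer == NULL)
    return NULL;

  g_object_set (infer, "model-file", "model.onnx", NULL);

  return infer;
}
```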
|
|
@ -1,684 +0,0 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxobjectdetector
|
||||
* Copyright (C) 2021 Collabora Ltd.
|
||||
*
|
||||
* gstonnxobjectdetector.c
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
/**
|
||||
* SECTION:element-onnxobjectdetector
|
||||
* @short_description: Detect objects in video frame
|
||||
*
|
||||
* This element can apply a generic ONNX object detection model such as YOLO or SSD
|
||||
* to each video frame.
|
||||
*
|
||||
* To install ONNX on your system, recursively clone this repository
|
||||
* https://github.com/microsoft/onnxruntime.git
|
||||
*
|
||||
* and build and install with cmake:
|
||||
*
|
||||
* CPU:
|
||||
*
|
||||
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \
|
||||
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
|
||||
*
|
||||
*
|
||||
* GPU :
|
||||
*
|
||||
* cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \
|
||||
* -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \
|
||||
* $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install
|
||||
*
|
||||
*
|
||||
* where :
|
||||
*
|
||||
* 1. $SRC_DIR and $BUILD_DIR are local source and build directories
|
||||
* 2. To run with CUDA, both CUDA and cuDNN libraries must be installed.
|
||||
* $CUDA_PATH is an environment variable set to the CUDA root path.
|
||||
* On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA.
|
||||
*
|
||||
*
|
||||
* ## Example launch command:
|
||||
*
|
||||
* (note: an object detection model has 3 or 4 output nodes, but there is no
|
||||
* naming convention to indicate which node outputs the bounding box, which
|
||||
* node outputs the label, etc. So, the `onnxobjectdetector` element has
|
||||
* properties to map each node's functionality to its respective node index in
|
||||
* the specified model. Image resolution also need to be adapted to the model.
|
||||
* The videoscale in the pipeline below will scale the image, using padding if
|
||||
* required, to 640x383 resolution required by the model.)
|
||||
*
|
||||
* model.onnx can be found here:
|
||||
* https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx
|
||||
*
|
||||
* ```
|
||||
* GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \
|
||||
* location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \
|
||||
* videoconvert ! \
|
||||
* videoscale ! \
|
||||
* 'video/x-raw,width=640,height=383' ! \
|
||||
* onnxobjectdetector \
|
||||
* box-node-index=0 \
|
||||
* class-node-index=1 \
|
||||
* score-node-index=2 \
|
||||
* detection-node-index=3 \
|
||||
* execution-provider=cpu \
|
||||
* model-file=model.onnx \
|
||||
* label-file=COCO_classes.txt ! \
|
||||
* videoconvert ! \
|
||||
* autovideosink
|
||||
* ```
|
||||
*/
|
||||
|
||||
#ifdef HAVE_CONFIG_H
|
||||
#include "config.h"
|
||||
#endif
|
||||
|
||||
#include "gstonnxobjectdetector.h"
|
||||
#include "gstonnxclient.h"
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include <gst/video/video.h>
|
||||
#include <gst/video/gstvideometa.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <glib.h>
|
||||
|
||||
GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug);
|
||||
#define GST_CAT_DEFAULT onnx_object_detector_debug
|
||||
#define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr))
|
||||
GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector",
|
||||
GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR);
|
||||
|
||||
/* GstOnnxObjectDetector properties */
|
||||
enum
|
||||
{
|
||||
PROP_0,
|
||||
PROP_MODEL_FILE,
|
||||
PROP_LABEL_FILE,
|
||||
PROP_SCORE_THRESHOLD,
|
||||
PROP_DETECTION_NODE_INDEX,
|
||||
PROP_BOUNDING_BOX_NODE_INDEX,
|
||||
PROP_SCORE_NODE_INDEX,
|
||||
PROP_CLASS_NODE_INDEX,
|
||||
PROP_INPUT_IMAGE_FORMAT,
|
||||
PROP_OPTIMIZATION_LEVEL,
|
||||
PROP_EXECUTION_PROVIDER
|
||||
};
|
||||
|
||||
|
||||
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
||||
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED
|
||||
#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
|
||||
|
||||
static GstStaticPadTemplate gst_onnx_object_detector_src_template =
|
||||
GST_STATIC_PAD_TEMPLATE ("src",
|
||||
GST_PAD_SRC,
|
||||
GST_PAD_ALWAYS,
|
||||
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
|
||||
);
|
||||
|
||||
static GstStaticPadTemplate gst_onnx_object_detector_sink_template =
|
||||
GST_STATIC_PAD_TEMPLATE ("sink",
|
||||
GST_PAD_SINK,
|
||||
GST_PAD_ALWAYS,
|
||||
GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }"))
|
||||
);
|
||||
|
||||
static void gst_onnx_object_detector_set_property (GObject * object,
|
||||
guint prop_id, const GValue * value, GParamSpec * pspec);
|
||||
static void gst_onnx_object_detector_get_property (GObject * object,
|
||||
guint prop_id, GValue * value, GParamSpec * pspec);
|
||||
static void gst_onnx_object_detector_finalize (GObject * object);
|
||||
static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform *
|
||||
trans, GstBuffer * buf);
|
||||
static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans,
|
||||
GstBuffer * buf);
|
||||
static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans);
|
||||
static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps);
|
||||
|
||||
G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector,
|
||||
GST_TYPE_BASE_TRANSFORM);
|
||||
|
||||
static void
|
||||
gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass)
|
||||
{
|
||||
GObjectClass *gobject_class = (GObjectClass *) klass;
|
||||
GstElementClass *element_class = (GstElementClass *) klass;
|
||||
GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
|
||||
|
||||
GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector",
|
||||
0, "onnx_objectdetector");
|
||||
gobject_class->set_property = gst_onnx_object_detector_set_property;
|
||||
gobject_class->get_property = gst_onnx_object_detector_get_property;
|
||||
gobject_class->finalize = gst_onnx_object_detector_finalize;
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:model-file
|
||||
*
|
||||
* ONNX model file
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE,
|
||||
g_param_spec_string ("model-file",
|
||||
"ONNX model file", "ONNX model file", NULL, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:label-file
|
||||
*
|
||||
* Label file for ONNX model
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE,
|
||||
g_param_spec_string ("label-file",
|
||||
"Label file", "Label file associated with model", NULL, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:detection-node-index
|
||||
*
|
||||
* Index of model detection node
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_DETECTION_NODE_INDEX,
|
||||
g_param_spec_int ("detection-node-index",
|
||||
"Detection node index",
|
||||
"Index of neural network output node corresponding to number of detected objects",
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
|
||||
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:bounding-box-node-index
|
||||
*
|
||||
* Index of model bounding box node
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_BOUNDING_BOX_NODE_INDEX,
|
||||
g_param_spec_int ("box-node-index",
|
||||
"Bounding box node index",
|
||||
"Index of neural network output node corresponding to bounding box",
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
|
||||
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:score-node-index
|
||||
*
|
||||
* Index of model score node
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_SCORE_NODE_INDEX,
|
||||
g_param_spec_int ("score-node-index",
|
||||
"Score node index",
|
||||
"Index of neural network output node corresponding to score",
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
|
||||
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:class-node-index
|
||||
*
|
||||
* Index of model class (label) node
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_CLASS_NODE_INDEX,
|
||||
g_param_spec_int ("class-node-index",
|
||||
"Class node index",
|
||||
"Index of neural network output node corresponding to class (label)",
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED,
|
||||
GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1,
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:score-threshold
|
||||
*
|
||||
* Threshold for deciding when to remove boxes based on score
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD,
|
||||
g_param_spec_float ("score-threshold",
|
||||
"Score threshold",
|
||||
"Threshold for deciding when to remove boxes based on score",
|
||||
0.0, 1.0,
|
||||
GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:input-image-format
|
||||
*
|
||||
* Model input image format
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_INPUT_IMAGE_FORMAT,
|
||||
g_param_spec_enum ("input-image-format",
|
||||
"Input image format",
|
||||
"Input image format",
|
||||
GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT,
|
||||
GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:optimization-level
|
||||
*
|
||||
* ONNX optimization level
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_OPTIMIZATION_LEVEL,
|
||||
g_param_spec_enum ("optimization-level",
|
||||
"Optimization level",
|
||||
"ONNX optimization level",
|
||||
GST_TYPE_ONNX_OPTIMIZATION_LEVEL,
|
||||
GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:execution-provider
|
||||
*
|
||||
* ONNX execution provider
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||
PROP_EXECUTION_PROVIDER,
|
||||
g_param_spec_enum ("execution-provider",
|
||||
"Execution provider",
|
||||
"ONNX execution provider",
|
||||
GST_TYPE_ONNX_EXECUTION_PROVIDER,
|
||||
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
gst_element_class_set_static_metadata (element_class, "onnxobjectdetector",
|
||||
"Filter/Effect/Video",
|
||||
"Apply neural network to detect objects in video frames",
|
||||
"Aaron Boxer <aaron.boxer@collabora.com>, Marcus Edel <marcus.edel@collabora.com>");
|
||||
gst_element_class_add_pad_template (element_class,
|
||||
gst_static_pad_template_get (&gst_onnx_object_detector_sink_template));
|
||||
gst_element_class_add_pad_template (element_class,
|
||||
gst_static_pad_template_get (&gst_onnx_object_detector_src_template));
|
||||
basetransform_class->transform_ip =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip);
|
||||
basetransform_class->transform_caps =
|
||||
GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps);
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_object_detector_init (GstOnnxObjectDetector * self)
|
||||
{
|
||||
self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient ();
|
||||
self->onnx_disabled = false;
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_object_detector_finalize (GObject * object)
|
||||
{
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
|
||||
|
||||
g_free (self->model_file);
|
||||
delete GST_ONNX_MEMBER (self);
|
||||
G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object);
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_object_detector_set_property (GObject * object, guint prop_id,
|
||||
const GValue * value, GParamSpec * pspec)
|
||||
{
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
|
||||
const gchar *filename;
|
||||
auto onnxClient = GST_ONNX_MEMBER (self);
|
||||
|
||||
switch (prop_id) {
|
||||
case PROP_MODEL_FILE:
|
||||
filename = g_value_get_string (value);
|
||||
if (filename
|
||||
&& g_file_test (filename,
|
||||
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
|
||||
if (self->model_file)
|
||||
g_free (self->model_file);
|
||||
self->model_file = g_strdup (filename);
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename);
|
||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
|
||||
}
|
||||
break;
|
||||
case PROP_LABEL_FILE:
|
||||
filename = g_value_get_string (value);
|
||||
if (filename
|
||||
&& g_file_test (filename,
|
||||
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
|
||||
if (self->label_file)
|
||||
g_free (self->label_file);
|
||||
self->label_file = g_strdup (filename);
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||
}
|
||||
break;
|
||||
case PROP_SCORE_THRESHOLD:
|
||||
GST_OBJECT_LOCK (self);
|
||||
self->score_threshold = g_value_get_float (value);
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
break;
|
||||
case PROP_OPTIMIZATION_LEVEL:
|
||||
self->optimization_level =
|
||||
(GstOnnxOptimizationLevel) g_value_get_enum (value);
|
||||
break;
|
||||
case PROP_EXECUTION_PROVIDER:
|
||||
self->execution_provider =
|
||||
(GstOnnxExecutionProvider) g_value_get_enum (value);
|
||||
break;
|
||||
case PROP_DETECTION_NODE_INDEX:
|
||||
onnxClient->setOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION,
|
||||
g_value_get_int (value));
|
||||
break;
|
||||
case PROP_BOUNDING_BOX_NODE_INDEX:
|
||||
onnxClient->setOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX,
|
||||
g_value_get_int (value));
|
||||
break;
|
||||
break;
|
||||
case PROP_SCORE_NODE_INDEX:
|
||||
onnxClient->setOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE,
|
||||
g_value_get_int (value));
|
||||
break;
|
||||
break;
|
||||
case PROP_CLASS_NODE_INDEX:
|
||||
onnxClient->setOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS,
|
||||
g_value_get_int (value));
|
||||
break;
|
||||
case PROP_INPUT_IMAGE_FORMAT:
|
||||
onnxClient->setInputImageFormat ((GstMlModelInputImageFormat)
|
||||
g_value_get_enum (value));
|
||||
break;
|
||||
default:
|
||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
gst_onnx_object_detector_get_property (GObject * object, guint prop_id,
|
||||
GValue * value, GParamSpec * pspec)
|
||||
{
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object);
|
||||
auto onnxClient = GST_ONNX_MEMBER (self);
|
||||
|
||||
switch (prop_id) {
|
||||
case PROP_MODEL_FILE:
|
||||
g_value_set_string (value, self->model_file);
|
||||
break;
|
||||
case PROP_LABEL_FILE:
|
||||
g_value_set_string (value, self->label_file);
|
||||
break;
|
||||
case PROP_SCORE_THRESHOLD:
|
||||
GST_OBJECT_LOCK (self);
|
||||
g_value_set_float (value, self->score_threshold);
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
break;
|
||||
case PROP_OPTIMIZATION_LEVEL:
|
||||
g_value_set_enum (value, self->optimization_level);
|
||||
break;
|
||||
case PROP_EXECUTION_PROVIDER:
|
||||
g_value_set_enum (value, self->execution_provider);
|
||||
break;
|
||||
case PROP_DETECTION_NODE_INDEX:
|
||||
g_value_set_int (value,
|
||||
onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION));
|
||||
break;
|
||||
case PROP_BOUNDING_BOX_NODE_INDEX:
|
||||
g_value_set_int (value,
|
||||
onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX));
|
||||
break;
|
||||
break;
|
||||
case PROP_SCORE_NODE_INDEX:
|
||||
g_value_set_int (value,
|
||||
onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE));
|
||||
break;
|
||||
break;
|
||||
case PROP_CLASS_NODE_INDEX:
|
||||
g_value_set_int (value,
|
||||
onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS));
|
||||
break;
|
||||
case PROP_INPUT_IMAGE_FORMAT:
|
||||
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
||||
break;
|
||||
default:
|
||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_object_detector_create_session (GstBaseTransform * trans)
|
||||
{
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
|
||||
auto onnxClient = GST_ONNX_MEMBER (self);
|
||||
|
||||
GST_OBJECT_LOCK (self);
|
||||
if (self->onnx_disabled || onnxClient->hasSession ()) {
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
if (self->model_file) {
|
||||
gboolean ret = GST_ONNX_MEMBER (self)->createSession (self->model_file,
|
||||
self->optimization_level,
|
||||
self->execution_provider);
|
||||
if (!ret) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Unable to create ONNX session. Detection disabled.");
|
||||
} else {
|
||||
auto outputNames = onnxClient->getOutputNodeNames ();
|
||||
|
||||
for (size_t i = 0; i < outputNames.size (); ++i)
|
||||
GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i,
|
||||
outputNames[i]);
|
||||
if (outputNames.size () < 3) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Number of output tensor nodes %d does not match the 3 or 4 nodes "
|
||||
"required for an object detection model. Detection is disabled.",
|
||||
(gint) outputNames.size ());
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
// sanity check on output node indices
|
||||
if (onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) ==
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Output detection node index not set. Detection disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
if (onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) ==
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Output bounding box node index not set. Detection disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
if (onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) ==
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Output score node index not set. Detection disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex
|
||||
(GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) ==
|
||||
GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) {
|
||||
GST_ERROR_OBJECT (self,
|
||||
"Output class node index not set. Detection disabled.");
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
// model is not usable, so fail
|
||||
if (self->onnx_disabled) {
|
||||
GST_ELEMENT_WARNING (self, RESOURCE, FAILED,
|
||||
("ONNX model cannot be used for object detection"), (NULL));
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self->onnx_disabled = TRUE;
|
||||
}
|
||||
GST_OBJECT_UNLOCK (self);
|
||||
if (self->onnx_disabled){
|
||||
gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
static GstCaps *
|
||||
gst_onnx_object_detector_transform_caps (GstBaseTransform *
|
||||
trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
|
||||
{
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
|
||||
auto onnxClient = GST_ONNX_MEMBER (self);
|
||||
GstCaps *other_caps;
|
||||
guint i;
|
||||
|
||||
if ( !gst_onnx_object_detector_create_session (trans) )
|
||||
return NULL;
|
||||
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
|
||||
|
||||
if (gst_base_transform_is_passthrough (trans)
|
||||
|| (!onnxClient->isFixedInputImageSize ()))
|
||||
return gst_caps_ref (caps);
|
||||
|
||||
other_caps = gst_caps_new_empty ();
|
||||
for (i = 0; i < gst_caps_get_size (caps); ++i) {
|
||||
GstStructure *structure, *new_structure;
|
||||
|
||||
structure = gst_caps_get_structure (caps, i);
|
||||
new_structure = gst_structure_copy (structure);
|
||||
gst_structure_set (new_structure, "width", G_TYPE_INT,
|
||||
onnxClient->getWidth (), "height", G_TYPE_INT,
|
||||
onnxClient->getHeight (), NULL);
|
||||
GST_LOG_OBJECT (self,
|
||||
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
|
||||
GST_PTR_FORMAT, i, structure, new_structure);
|
||||
gst_caps_append_structure (other_caps, new_structure);
|
||||
}
|
||||
|
||||
if (!gst_caps_is_empty (other_caps) && filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (other_caps,filter_caps,
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_replace (&other_caps, tmp);
|
||||
gst_caps_unref (tmp);
|
||||
}
|
||||
|
||||
return other_caps;
|
||||
}
|
||||
|
||||
|
||||
static GstFlowReturn
|
||||
gst_onnx_object_detector_transform_ip (GstBaseTransform * trans,
|
||||
GstBuffer * buf)
|
||||
{
|
||||
if (!gst_base_transform_is_passthrough (trans)
|
||||
&& !gst_onnx_object_detector_process (trans, buf)){
|
||||
GST_ELEMENT_WARNING (trans, STREAM, FAILED,
|
||||
("ONNX object detection failed"), (NULL));
|
||||
return GST_FLOW_ERROR;
|
||||
}
|
||||
|
||||
return GST_FLOW_OK;
|
||||
}
|
||||
|
||||
static gboolean
|
||||
gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||
{
|
||||
GstMapInfo info;
|
||||
GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
|
||||
|
||||
if (!vmeta) {
|
||||
GST_WARNING_OBJECT (trans, "missing video meta");
|
||||
return FALSE;
|
||||
}
|
||||
if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
|
||||
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
|
||||
std::vector < GstOnnxNamespace::GstMlBoundingBox > boxes;
|
||||
try {
|
||||
boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
|
||||
self->label_file ? self->label_file : "", self->score_threshold);
|
||||
}
|
||||
catch (Ort::Exception & ortex) {
|
||||
GST_ERROR_OBJECT (self, "%s", ortex.what ());
|
||||
return FALSE;
|
||||
}
|
||||
for (auto & b:boxes) {
|
||||
auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
|
||||
GST_ONNX_OBJECT_DETECTOR_META_NAME,
|
||||
b.x0, b.y0,
|
||||
b.width,
|
||||
b.height);
|
||||
if (!vroi_meta) {
|
||||
GST_WARNING_OBJECT (trans,
|
||||
"Unable to attach GstVideoRegionOfInterestMeta to buffer");
|
||||
return FALSE;
|
||||
}
|
||||
auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME,
|
||||
GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL,
|
||||
G_TYPE_STRING,
|
||||
b.label.c_str (),
|
||||
GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE,
|
||||
G_TYPE_DOUBLE,
|
||||
b.score,
|
||||
NULL);
|
||||
gst_video_region_of_interest_meta_add_param (vroi_meta, s);
|
||||
GST_DEBUG_OBJECT (self,
|
||||
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n",
|
||||
b.label.c_str (), b.score, b.x0, b.y0,
|
||||
b.x0 + b.width, b.y0 + b.height);
|
||||
}
|
||||
gst_buffer_unmap (buf, &info);
|
||||
}
|
||||
|
||||
return TRUE;
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
/*
|
||||
* GStreamer gstreamer-onnxobjectdetector
|
||||
* Copyright (C) 2021 Collabora Ltd
|
||||
*
|
||||
* gstonnxobjectdetector.h
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
#ifndef __GST_ONNX_OBJECT_DETECTOR_H__
|
||||
#define __GST_ONNX_OBJECT_DETECTOR_H__
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include <gst/video/video.h>
|
||||
#include <gst/video/gstvideofilter.h>
|
||||
#include "gstonnxelement.h"
|
||||
|
||||
G_BEGIN_DECLS
|
||||
|
||||
#define GST_TYPE_ONNX_OBJECT_DETECTOR (gst_onnx_object_detector_get_type())
|
||||
G_DECLARE_FINAL_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector, GST, ONNX_OBJECT_DETECTOR, GstBaseTransform)
|
||||
#define GST_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_CAST((obj),GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetector))
|
||||
#define GST_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_CAST((klass), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
|
||||
#define GST_ONNX_OBJECT_DETECTOR_GET_CLASS(obj) (G_TYPE_INSTANCE_GET_CLASS((obj), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass))
|
||||
#define GST_IS_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_TYPE((obj),GST_TYPE_ONNX_OBJECT_DETECTOR))
|
||||
#define GST_IS_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_TYPE((klass), GST_TYPE_ONNX_OBJECT_DETECTOR))
|
||||
|
||||
#define GST_ONNX_OBJECT_DETECTOR_META_NAME "onnx-object_detector"
|
||||
#define GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME "extra-data"
|
||||
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL "label"
|
||||
#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE "score"
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetector:
|
||||
*
|
||||
* @model_file model file
|
||||
* @label_file label file
|
||||
* @score_threshold score threshold
|
||||
* @confidence_threshold confidence threshold
|
||||
* @iou_threhsold iou threshold
|
||||
* @optimization_level ONNX optimization level
|
||||
* @execution_provider: ONNX execution provider
|
||||
* @onnx_ptr opaque pointer to ONNX implementation
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
struct _GstOnnxObjectDetector
|
||||
{
|
||||
GstBaseTransform basetransform;
|
||||
gchar *model_file;
|
||||
gchar *label_file;
|
||||
gfloat score_threshold;
|
||||
gfloat confidence_threshold;
|
||||
gfloat iou_threshold;
|
||||
GstOnnxOptimizationLevel optimization_level;
|
||||
GstOnnxExecutionProvider execution_provider;
|
||||
gpointer onnx_ptr;
|
||||
gboolean onnx_disabled;
|
||||
|
||||
void (*process) (GstOnnxObjectDetector * onnx_object_detector,
|
||||
GstVideoFrame * inframe, GstVideoFrame * outframe);
|
||||
};
|
||||
|
||||
/**
|
||||
* GstOnnxObjectDetectorClass:
|
||||
*
|
||||
* @parent_class base transform base class
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
struct _GstOnnxObjectDetectorClass
|
||||
{
|
||||
GstBaseTransformClass parent_class;
|
||||
};
|
||||
|
||||
GST_ELEMENT_REGISTER_DECLARE (onnx_object_detector)
|
||||
|
||||
G_END_DECLS
|
||||
|
||||
#endif /* __GST_ONNX_OBJECT_DETECTOR_H__ */
|
|
@ -3,27 +3,31 @@ if get_option('onnx').disabled()
|
|||
endif
|
||||
|
||||
|
||||
onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx'))
|
||||
onnxrt_dep = dependency('libonnxruntime', version : '>= 1.15.1', required : get_option('onnx'))
|
||||
|
||||
extra_args = []
|
||||
extra_deps = []
|
||||
if gstcuda_dep.found()
|
||||
extra_args += ['-DHAVE_CUDA']
|
||||
extra_deps += [gstcuda_dep]
|
||||
endif
|
||||
|
||||
if onnxrt_dep.found()
|
||||
onnxrt_include_root = onnxrt_dep.get_variable('includedir')
|
||||
onnxrt_includes = [onnxrt_include_root / 'core/session', onnxrt_include_root / 'core']
|
||||
onnxrt_dep_args = []
|
||||
|
||||
compiler = meson.get_compiler('cpp')
|
||||
if compiler.has_header(onnxrt_include_root / 'core/providers/cuda/cuda_provider_factory.h')
|
||||
onnxrt_dep_args = ['-DGST_ML_ONNX_RUNTIME_HAVE_CUDA']
|
||||
endif
|
||||
gstonnx = library('gstonnx',
|
||||
'gstonnx.c',
|
||||
'gstonnxelement.c',
|
||||
'gstonnxobjectdetector.cpp',
|
||||
'decoders/gstobjectdetectorutils.cpp',
|
||||
'decoders/gstssdobjectdetector.cpp',
|
||||
'gstonnxinference.cpp',
|
||||
'gstonnxclient.cpp',
|
||||
c_args : gst_plugins_bad_args,
|
||||
cpp_args: onnxrt_dep_args,
|
||||
'tensor/gsttensorid.cpp',
|
||||
'tensor/gsttensormeta.c',
|
||||
c_args : gst_plugins_bad_args + extra_args,
|
||||
cpp_args : gst_plugins_bad_args + extra_args,
|
||||
link_args : noseh_link_args,
|
||||
include_directories : [configinc, libsinc, onnxrt_includes],
|
||||
dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm],
|
||||
include_directories : [configinc, libsinc, onnxrt_includes, cuda_stubinc],
|
||||
dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm] + extra_deps,
|
||||
install : true,
|
||||
install_dir : plugins_install_dir,
|
||||
)
|
||||
|
|
69
subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensor.h
Normal file
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* GStreamer gstreamer-tensor
|
||||
* Copyright (C) 2023 Collabora Ltd
|
||||
*
|
||||
* gsttensor.h
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
#ifndef __GST_TENSOR_H__
|
||||
#define __GST_TENSOR_H__
|
||||
|
||||
|
||||
/**
|
||||
* GstTensorType:
|
||||
*
|
||||
* @GST_TENSOR_TYPE_INT8: 8 bit integer tensor data
* @GST_TENSOR_TYPE_INT16: 16 bit integer tensor data
* @GST_TENSOR_TYPE_INT32: 32 bit integer tensor data
* @GST_TENSOR_TYPE_FLOAT16: 16 bit floating point tensor data
* @GST_TENSOR_TYPE_FLOAT32: 32 bit floating point tensor data
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
typedef enum _GstTensorType
|
||||
{
|
||||
GST_TENSOR_TYPE_INT8,
|
||||
GST_TENSOR_TYPE_INT16,
|
||||
GST_TENSOR_TYPE_INT32,
|
||||
GST_TENSOR_TYPE_FLOAT16,
|
||||
GST_TENSOR_TYPE_FLOAT32
|
||||
} GstTensorType;
|
||||
|
||||
|
||||
/**
|
||||
* GstTensor:
|
||||
*
|
||||
* @id: unique tensor identifier
* @num_dims: number of tensor dimensions
* @dims: tensor dimensions
* @type: @ref GstTensorType of tensor data
* @data: @ref GstBuffer holding tensor data
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
typedef struct _GstTensor
|
||||
{
|
||||
GQuark id;
|
||||
gint num_dims;
|
||||
int64_t *dims;
|
||||
GstTensorType type;
|
||||
GstBuffer *data;
|
||||
} GstTensor;
|
||||
|
||||
#define GST_TENSOR_MISSING_ID -1
|
||||
|
||||
#endif
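As a consumer-side illustration, here is a minimal sketch of reading a float32 tensor's payload through the fields declared above. It assumes a valid GstTensor pointer obtained from the provisional tensor meta attached by onnxinference, and keeps error handling to the essentials.

```
#include <gst/gst.h>
#include "gsttensor.h"

/* Sketch: print the float32 payload of a tensor. The GstTensor is assumed to
 * come from the provisional tensor meta attached upstream. */
static gboolean
dump_float_tensor (const GstTensor * tensor)
{
  GstMapInfo map;
  gsize i, n;

  if (tensor->type != GST_TENSOR_TYPE_FLOAT32)
    return FALSE;
  if (!gst_buffer_map (tensor->data, &map, GST_MAP_READ))
    return FALSE;

  n = map.size / sizeof (gfloat);
  for (i = 0; i < n; i++)
    g_print ("%s[%" G_GSIZE_FORMAT "] = %f\n",
        g_quark_to_string (tensor->id), i, ((const gfloat *) map.data)[i]);

  gst_buffer_unmap (tensor->data, &map);
  return TRUE;
}
```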
|
82
subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.cpp
Normal file
|
@ -0,0 +1,82 @@
|
|||
/*
 * GStreamer gstreamer-tensorid
 * Copyright (C) 2023 Collabora Ltd
 *
 * gsttensorid.cpp
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include <gst/gst.h>
#include "gsttensorid.h"

/* Structure to encapsulate a string and its associated GQuark */
struct TensorQuark
{
  const char *string;
  GQuark quark_id;
};

class TensorId
{
public:
  TensorId (void):tensor_quarks_array (g_array_new (FALSE, FALSE,
          sizeof (TensorQuark)))
  {
  }
  ~TensorId (void)
  {
    if (tensor_quarks_array) {
      for (guint i = 0; i < tensor_quarks_array->len; i++) {
        TensorQuark *quark =
            &g_array_index (tensor_quarks_array, TensorQuark, i);
        g_free ((gpointer) quark->string); // free the duplicated string
      }
      g_array_free (tensor_quarks_array, TRUE);
    }
  }
  GQuark get_quark (const char *str)
  {
    for (guint i = 0; i < tensor_quarks_array->len; i++) {
      TensorQuark *quark = &g_array_index (tensor_quarks_array, TensorQuark, i);
      if (g_strcmp0 (quark->string, str) == 0) {
        return quark->quark_id; // already registered
      }
    }

    // Register the new quark and append to the GArray
    TensorQuark new_quark;
    new_quark.string = g_strdup (str); // create a copy of the string
    new_quark.quark_id = g_quark_from_string (new_quark.string);
    g_array_append_val (tensor_quarks_array, new_quark);

    return new_quark.quark_id;
  }
private:
  GArray * tensor_quarks_array;
};

static TensorId tensorId;

G_BEGIN_DECLS

GQuark
gst_tensorid_get_quark (const char *tensor_id)
{
  return tensorId.get_quark (tensor_id);
}

G_END_DECLS
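Because get_quark () is backed by g_quark_from_string (), repeated lookups of the same id string yield the same GQuark, which is what lets producers and consumers agree on tensor identity by name. A small check, not part of the patch, with the hypothetical id string "my-tensor-node":

static void
check_tensor_id_is_stable (void)
{
  /* "my-tensor-node" is a hypothetical id string. */
  GQuark a = gst_tensorid_get_quark ("my-tensor-node");
  GQuark b = gst_tensorid_get_quark ("my-tensor-node");

  g_assert (a == b);                                      /* stable across calls */
  g_assert (a == g_quark_from_string ("my-tensor-node")); /* backed by GLib quarks */
}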
subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.h (new file, 34 lines)
@@ -0,0 +1,34 @@
/*
 * GStreamer gstreamer-tensorid
 * Copyright (C) 2023 Collabora Ltd
 *
 * gsttensorid.h
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
#ifndef __GST_TENSOR_ID_H__
#define __GST_TENSOR_ID_H__

#include <glib.h>

G_BEGIN_DECLS

/**
 * gst_tensorid_get_quark:
 *
 * @tensor_id: unique string id for tensor node
 *
 * Get the #GQuark for a tensor id string.
 */
GQuark gst_tensorid_get_quark (const char *tensor_id);

G_END_DECLS

#endif /* __GST_TENSOR_ID_H__ */
subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.c (new file, 107 lines)
@@ -0,0 +1,107 @@
/*
 * GStreamer gstreamer-tensormeta
 * Copyright (C) 2023 Collabora Ltd
 *
 * gsttensormeta.c
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "gsttensormeta.h"

#include "gsttensor.h"

static gboolean
gst_tensor_meta_init (GstMeta * meta, gpointer params, GstBuffer * buffer)
{
  GstTensorMeta *tmeta = (GstTensorMeta *) meta;

  tmeta->num_tensors = 0;
  tmeta->tensor = NULL;

  return TRUE;
}

static void
gst_tensor_meta_free (GstMeta * meta, GstBuffer * buffer)
{
  GstTensorMeta *tmeta = (GstTensorMeta *) meta;

  for (int i = 0; i < tmeta->num_tensors; i++) {
    g_free (tmeta->tensor[i].dims);
    gst_buffer_unref (tmeta->tensor[i].data);
  }
  g_free (tmeta->tensor);
}

GType
gst_tensor_meta_api_get_type (void)
{
  static GType type = 0;
  static const gchar *tags[] = { NULL };

  if (g_once_init_enter (&type)) {
    GType _type = gst_meta_api_type_register ("GstTensorMetaAPI", tags);
    g_once_init_leave (&type, _type);
  }
  return type;
}

const GstMetaInfo *
gst_tensor_meta_get_info (void)
{
  static const GstMetaInfo *tmeta_info = NULL;

  if (g_once_init_enter (&tmeta_info)) {
    const GstMetaInfo *meta =
        gst_meta_register (gst_tensor_meta_api_get_type (),
        "GstTensorMeta",
        sizeof (GstTensorMeta),
        gst_tensor_meta_init,
        gst_tensor_meta_free,
        NULL);                  /* tensor_meta_transform not implemented */
    g_once_init_leave (&tmeta_info, meta);
  }
  return tmeta_info;
}

GList *
gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer)
{
  GType tensor_meta_api_type = gst_tensor_meta_api_get_type ();
  GList *tensor_metas = NULL;
  gpointer state = NULL;
  GstMeta *meta;

  while ((meta = gst_buffer_iterate_meta (buffer, &state))) {
    if (meta->info->api == tensor_meta_api_type) {
      tensor_metas = g_list_append (tensor_metas, meta);
    }
  }

  return tensor_metas;
}

gint
gst_tensor_meta_get_index_from_id (GstTensorMeta * meta, GQuark id)
{
  for (int i = 0; i < meta->num_tensors; ++i) {
    if ((meta->tensor + i)->id == id)
      return i;
  }

  return GST_TENSOR_MISSING_ID;
}
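A minimal sketch, not part of the patch, of how an element could attach this meta to an outgoing buffer via the registered GstMetaInfo. It reuses the hypothetical fill_example_tensor () helper sketched earlier; names and the single-tensor layout are assumptions for illustration.

/* Illustrative only: attach a GstTensorMeta carrying one tensor.
 * Assumes the buffer is writable and fill_example_tensor () is the
 * hypothetical helper shown above. */
static void
attach_example_tensor_meta (GstBuffer * buffer, float * model_output,
    gsize n_floats)
{
  GstTensorMeta *tmeta = (GstTensorMeta *)
      gst_buffer_add_meta (buffer, gst_tensor_meta_get_info (), NULL);

  tmeta->num_tensors = 1;
  tmeta->batch_size = 1;
  tmeta->tensor = g_new (GstTensor, 1);
  fill_example_tensor (&tmeta->tensor[0], model_output, n_floats);
}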
subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h (new file, 56 lines)
@@ -0,0 +1,56 @@
/*
 * GStreamer gstreamer-tensormeta
 * Copyright (C) 2023 Collabora Ltd
 *
 * gsttensormeta.h
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
#ifndef __GST_TENSOR_META_H__
#define __GST_TENSOR_META_H__

#include <gst/gst.h>
#include "gsttensor.h"

/**
 * GstTensorMeta:
 *
 * @meta: base GstMeta
 * @num_tensors: number of tensors
 * @tensor: #GstTensor for each tensor
 * @batch_size: model batch size
 *
 * Since: 1.24
 */
typedef struct _GstTensorMeta
{
  GstMeta meta;

  gint num_tensors;
  GstTensor *tensor;
  int batch_size;
} GstTensorMeta;

G_BEGIN_DECLS

GType gst_tensor_meta_api_get_type (void);
const GstMetaInfo *gst_tensor_meta_get_info (void);
GList *gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer);
gint gst_tensor_meta_get_index_from_id (GstTensorMeta * meta, GQuark id);

G_END_DECLS

#endif /* __GST_TENSOR_META_H__ */
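A consumer-side sketch, not part of the patch, showing how a downstream analysis element could locate a tensor by id on an incoming buffer using the declarations above. The id string "detection-classes" and the helper name find_example_tensor are hypothetical.

/* Illustrative only: walk the tensor metas on a buffer and return the
 * tensor registered under a hypothetical id, or NULL if absent. */
static GstTensor *
find_example_tensor (GstBuffer * buffer)
{
  GList *metas = gst_tensor_meta_get_all_from_buffer (buffer);
  GstTensor *found = NULL;
  GList *l;

  for (l = metas; l != NULL; l = l->next) {
    GstTensorMeta *tmeta = (GstTensorMeta *) l->data;
    gint index = gst_tensor_meta_get_index_from_id (tmeta,
        g_quark_from_string ("detection-classes"));

    if (index != GST_TENSOR_MISSING_ID) {
      found = &tmeta->tensor[index];
      break;
    }
  }
  g_list_free (metas);
  return found;
}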