onnx: Port SSD detector to C

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
Olivier Crête 2024-01-25 00:02:30 -05:00
parent 5e1291fd86
commit 3325a10f57
8 changed files with 262 additions and 378 deletions

View file

@ -1,228 +0,0 @@
/*
* GStreamer gstreamer-objectdetectorutils
* Copyright (C) 2023 Collabora Ltd
*
* gstobjectdetectorutils.cpp
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#include "gstobjectdetectorutils.h"
#include <gio/gio.h>
char **
read_labels (const char * labels_file)
{
GPtrArray *array;
GFile *file = g_file_new_for_path (labels_file);
GFileInputStream *file_stream;
GDataInputStream *data_stream;
GError *error = NULL;
gchar *line;
file_stream = g_file_read (file, NULL, &error);
g_object_unref (file);
if (!file_stream) {
GST_WARNING ("Could not open file %s: %s\n", labels_file,
error->message);
g_clear_error (&error);
return NULL;
}
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
g_object_unref (file_stream);
array = g_ptr_array_new();
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
&error)))
g_ptr_array_add (array, line);
g_object_unref (data_stream);
if (error) {
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
g_ptr_array_free (array, TRUE);
g_clear_error (&error);
return NULL;
}
if (array->len == 0) {
g_ptr_array_free (array, TRUE);
return NULL;
}
g_ptr_array_add (array, NULL);
return (char **) g_ptr_array_free (array, FALSE);
}
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
float _y0, float _width, float _height):
label (lbl),
score (score),
x0 (_x0),
y0 (_y0),
width (_width),
height (_height)
{
}
GstMlBoundingBox::GstMlBoundingBox ():
GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
{
}
namespace GstObjectDetectorUtils
{
GstObjectDetectorUtils::GstObjectDetectorUtils ()
{
}
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
int32_t h, GstTensorMeta * tmeta, gchar **labels,
float scoreThreshold)
{
auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (classIndex == GST_TENSOR_MISSING_ID) {
GST_ERROR ("Missing class tensor id");
return std::vector < GstMlBoundingBox > ();
}
auto type = tmeta->tensor[classIndex].type;
return (type == GST_TENSOR_TYPE_FLOAT32) ?
doRun < float >(w, h, tmeta, labels, scoreThreshold)
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
}
template < typename T > std::vector < GstMlBoundingBox >
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
{
std::vector < GstMlBoundingBox > boundingBoxes;
GstMapInfo map_info[GstObjectDetectorMaxNodes];
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
gint index;
T *numDetections = nullptr, *bboxes = nullptr, *scores =
nullptr, *labelIndex = nullptr;
// number of detections
index = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
if (index == GST_TENSOR_MISSING_ID) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_WARNING ("Failed to map tensor memory for index %d", index);
goto cleanup;
}
numDetections = (T *) map_info[index].data;
// bounding boxes
index =
gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string(GST_MODEL_OBJECT_DETECTOR_BOXES));
if (index == GST_TENSOR_MISSING_ID) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Failed to map tensor memory for index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
bboxes = (T *) map_info[index].data;
// scores
index =
gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
if (index == GST_TENSOR_MISSING_ID) {
GST_ERROR ("Missing scores tensor id");
goto cleanup;
}
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
scores = (T *) map_info[index].data;
// optional label
labelIndex = nullptr;
index =
gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (index != GST_TENSOR_MISSING_ID) {
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
if (!memory[index]) {
GST_WARNING ("Missing tensor data for tensor index %d", index);
goto cleanup;
}
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
GST_ERROR ("Failed to map GstMemory");
goto cleanup;
}
labelIndex = (T *) map_info[index].data;
}
for (int i = 0; i < numDetections[0]; ++i) {
if (scores[i] > scoreThreshold) {
std::string label = "";
if (labels && labelIndex)
label = labels[(int)labelIndex[i] - 1];
auto score = scores[i];
auto y0 = bboxes[i * 4] * h;
auto x0 = bboxes[i * 4 + 1] * w;
auto bheight = bboxes[i * 4 + 2] * h - y0;
auto bwidth = bboxes[i * 4 + 3] * w - x0;
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
bheight));
}
}
cleanup:
for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
if (memory[i])
gst_memory_unmap (memory[i], map_info + i);
}
return boundingBoxes;
}
}

View file

@ -1,83 +0,0 @@
/*
* GStreamer gstreamer-objectdetectorutils
* Copyright (C) 2023 Collabora Ltd
*
* gstobjectdetectorutils.h
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Library General Public
* License as published by the Free Software Foundation; either
* version 2 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Library General Public License for more details.
*
* You should have received a copy of the GNU Library General Public
* License along with this library; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301, USA.
*/
#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
#define __GST_OBJECT_DETECTOR_UTILS_H__
#include <gst/gst.h>
#include <string>
#include <vector>
#include "gstml.h"
#include "tensor/gsttensormeta.h"
char ** read_labels (const char * labels_file);
/* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
/**
* GstMlBoundingBox:
*
* @label label
* @score detection confidence
* @x0 top left hand x coordinate
* @y0 top left hand y coordinate
* @width width
* @height height
*
* Since: 1.20
*/
struct GstMlBoundingBox {
GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0,
float _width, float _height);
GstMlBoundingBox();
std::string label;
float score;
float x0;
float y0;
float width;
float height;
};
namespace GstObjectDetectorUtils {
const int GstObjectDetectorMaxNodes = 4;
class GstObjectDetectorUtils {
public:
GstObjectDetectorUtils(void);
~GstObjectDetectorUtils(void) = default;
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
GstTensorMeta *tmeta,
char **labels,
float scoreThreshold);
private:
template < typename T > std::vector < GstMlBoundingBox >
doRun(int32_t w, int32_t h,
GstTensorMeta *tmeta, char **labels,
float scoreThreshold);
};
}
#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */

View file

@ -44,19 +44,23 @@
#include "config.h"
#endif
#include "gstssdobjectdetector.h"
#include "gstobjectdetectorutils.h"
#include <gio/gio.h>
#include <gst/gst.h>
#include <gst/video/video.h>
#include <gst/analytics/analytics.h>
#include "tensor/gsttensormeta.h"
/* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
#define GST_CAT_DEFAULT ssd_object_detector_debug
#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
@ -68,7 +72,7 @@ enum
PROP_SCORE_THRESHOLD,
};
#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
static GstStaticPadTemplate gst_ssd_object_detector_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
@ -97,7 +101,8 @@ static gboolean
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
GstCaps * outcaps);
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector,
GST_TYPE_BASE_TRANSFORM);
static void
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
@ -135,7 +140,8 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
g_param_spec_float ("score-threshold",
"Score threshold",
"Threshold for deciding when to remove boxes based on score",
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD,
(GParamFlags)
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "objectdetector",
@ -155,7 +161,6 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
static void
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
{
self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
}
static void
@ -164,12 +169,58 @@ gst_ssd_object_detector_finalize (GObject * object)
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
g_free (self->label_file);
g_strfreev (self->labels);
delete GST_ODUTILS_MEMBER (self);
g_clear_pointer (&self->labels, g_array_unref);
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
}
static GArray *
read_labels (const char *labels_file)
{
GArray *array;
GFile *file = g_file_new_for_path (labels_file);
GFileInputStream *file_stream;
GDataInputStream *data_stream;
GError *error = NULL;
gchar *line;
file_stream = g_file_read (file, NULL, &error);
g_object_unref (file);
if (!file_stream) {
GST_WARNING ("Could not open file %s: %s\n", labels_file, error->message);
g_clear_error (&error);
return NULL;
}
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
g_object_unref (file_stream);
array = g_array_new (FALSE, FALSE, sizeof (GQuark));
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
&error))) {
GQuark label = g_quark_from_string (line);
g_array_append_val (array, label);
g_free (line);
}
g_object_unref (data_stream);
if (error) {
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
g_array_free (array, TRUE);
g_clear_error (&error);
return NULL;
}
if (array->len == 0) {
g_array_free (array, TRUE);
return NULL;
}
return array;
}
static void
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
const GValue * value, GParamSpec * pspec)
@ -179,21 +230,21 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
switch (prop_id) {
case PROP_LABEL_FILE:
{
gchar **labels;
{
GArray *labels;
filename = g_value_get_string (value);
labels = read_labels (filename);
filename = g_value_get_string (value);
labels = read_labels (filename);
if (labels) {
g_free (self->label_file);
self->label_file = g_strdup (filename);
g_strfreev (self->labels);
self->labels = labels;
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
if (labels) {
g_free (self->label_file);
self->label_file = g_strdup (filename);
g_clear_pointer (&self->labels, g_array_unref);
self->labels = labels;
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
}
break;
case PROP_SCORE_THRESHOLD:
GST_OBJECT_LOCK (self);
@ -259,7 +310,8 @@ gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
if (boxesIndex == GST_TENSOR_MISSING_ID
|| scoresIndex == GST_TENSOR_MISSING_ID
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
continue;
@ -300,13 +352,175 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
return GST_FLOW_OK;
}
#define DEFINE_GET_FUNC(TYPE, MAX) \
static gboolean \
get_ ## TYPE ## _at_index (GstTensor *tensor, GstMapInfo *map, \
guint index, TYPE * out) \
{ \
switch (tensor->type) { \
case GST_TENSOR_TYPE_FLOAT32: { \
float *f = (float *) map->data; \
if (sizeof(*f) * (index + 1) > map->size) \
return FALSE; \
*out = f[index]; \
break; \
} \
case GST_TENSOR_TYPE_UINT32: { \
guint32 *u = (guint32 *) map->data; \
if (sizeof(*u) * (index + 1) > map->size) \
return FALSE; \
*out = u[index]; \
break; \
} \
default: \
GST_ERROR ("Only float32 and int32 tensors are understood"); \
return FALSE; \
} \
return TRUE; \
}
DEFINE_GET_FUNC (guint32, UINT32_MAX)
DEFINE_GET_FUNC (float, FLOAT_MAX)
#undef DEFINE_GET_FUNC
static void
extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
{
gint classes_index;
gint boxes_index;
gint scores_index;
gint numdetect_index;
GstMapInfo boxes_map = GST_MAP_INFO_INIT;
GstMapInfo numdetect_map = GST_MAP_INFO_INIT;
GstMapInfo scores_map = GST_MAP_INFO_INIT;
GstMapInfo classes_map = GST_MAP_INFO_INIT;
guint num_detections = 0;
classes_index = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
numdetect_index = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
scores_index = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
boxes_index = gst_tensor_meta_get_index_from_id (tmeta,
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
if (numdetect_index == GST_TENSOR_MISSING_ID
|| scores_index == GST_TENSOR_MISSING_ID
|| numdetect_index == GST_TENSOR_MISSING_ID) {
GST_WARNING ("Missing tensor data expected for SSD model");
return;
}
if (!gst_buffer_map (tmeta->tensor[numdetect_index].data, &numdetect_map,
GST_MAP_READ)) {
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
numdetect_index);
goto cleanup;
}
if (!gst_buffer_map (tmeta->tensor[boxes_index].data, &boxes_map,
GST_MAP_READ)) {
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
boxes_index);
goto cleanup;
}
if (!gst_buffer_map (tmeta->tensor[scores_index].data, &scores_map,
GST_MAP_READ)) {
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
scores_index);
goto cleanup;
}
if (classes_index != GST_TENSOR_MISSING_ID &&
!gst_buffer_map (tmeta->tensor[classes_index].data, &classes_map,
GST_MAP_READ)) {
GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d",
classes_index);
}
if (!get_guint32_at_index (&tmeta->tensor[numdetect_index], &numdetect_map,
0, &num_detections)) {
GST_ERROR_OBJECT (self, "Failed to get the number of detections");
goto cleanup;
}
GST_LOG_OBJECT (self, "Model claims %d detections", num_detections);
for (int i = 0; i < num_detections; i++) {
float score;
float x, y, bwidth, bheight;
gint x_i, y_i, bwidth_i, bheight_i;
guint32 bclass;
GQuark label = 0;
GstAnalyticsODMtd odmtd;
if (!get_float_at_index (&tmeta->tensor[numdetect_index], &scores_map,
i, &score))
continue;
GST_LOG_OBJECT (self, "Detection %u score is %f", i, score);
if (score < self->score_threshold)
continue;
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
i * 4, &y))
continue;
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
i * 4 + 1, &x))
continue;
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
i * 4 + 2, &bheight))
continue;
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
i * 4 + 3, &bwidth))
continue;
if (self->labels && classes_map.memory &&
get_guint32_at_index (&tmeta->tensor[classes_index], &classes_map,
i, &bclass)) {
if (bclass < self->labels->len)
label = g_array_index (self->labels, GQuark, bclass);
}
x_i = x * w;
y_i = y * h;
bheight_i = (bheight * h) - y_i;
bwidth_i = (bwidth * w) - x_i;
if (gst_analytics_relation_meta_add_od_mtd (rmeta, label,
x_i, y_i, bwidth_i, bheight_i, score, &odmtd))
GST_DEBUG_OBJECT (self,
"Object detected with label : %s, score: %f, bound box: %dx%d at (%d,%d)",
g_quark_to_string (label), score, bwidth_i, bheight_i, x_i, y_i);
else
GST_WARNING_OBJECT (self, "Could not add detection to meta");
}
cleanup:
if (numdetect_map.memory)
gst_buffer_unmap (tmeta->tensor[numdetect_index].data, &numdetect_map);
if (classes_map.memory)
gst_buffer_unmap (tmeta->tensor[classes_index].data, &classes_map);
if (scores_map.memory)
gst_buffer_unmap (tmeta->tensor[scores_index].data, &scores_map);
if (boxes_map.memory)
gst_buffer_unmap (tmeta->tensor[boxes_index].data, &boxes_map);
}
static gboolean
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
{
GstTensorMeta *tmeta = NULL;
GstAnalyticsODMtd odmtd;
GstAnalyticsRelationMeta *rmeta;
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
GstTensorMeta *tmeta;
GstAnalyticsRelationMeta *rmeta;
// get all tensor metas
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
@ -315,25 +529,11 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
return TRUE;
} else {
rmeta = gst_buffer_add_analytics_relation_meta (buf);
g_return_val_if_fail (rmeta != NULL, FALSE);
g_assert (rmeta);
}
std::vector < GstMlBoundingBox > boxes =
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
self->video_info.height, tmeta, self->labels,
self->score_threshold);
for (auto & b:boxes) {
if (gst_analytics_relation_meta_add_od_mtd (rmeta,
g_quark_from_string(b.label.c_str ()), b.x0, b.y0, b.width, b.height,
b.score, &odmtd)) {
GST_DEBUG_OBJECT (self,
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f)",
b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
} else {
GST_ERROR_OBJECT (self, "Failed to add object detection analytics-meta");
}
}
extract_bounding_boxes (self, self->video_info.width,
self->video_info.height, rmeta, tmeta);
return TRUE;
}

View file

@ -37,14 +37,11 @@ G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OB
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
/**
/*
* GstSsdObjectDetector:
*
* @label_file label file
* @score_threshold score threshold
* @confidence_threshold confidence threshold
* @iou_threhsold iou threshold
* @od_ptr opaque pointer to GstOd object detection implementation
*
* Since: 1.20
*/
@ -52,11 +49,8 @@ struct _GstSsdObjectDetector
{
GstBaseTransform basetransform;
gchar *label_file;
gchar **labels;
GArray *labels;
gfloat score_threshold;
gfloat confidence_threshold;
gfloat iou_threshold;
gpointer odutils;
GstVideoInfo video_info;
};

View file

@ -96,7 +96,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return inputImageFormat;
}
void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
void GstOnnxClient::setInputImageDatatype(GstTensorDataType datatype)
{
inputDatatype = datatype;
switch (inputDatatype) {
@ -144,7 +144,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return inputTensorScale;
}
GstTensorType GstOnnxClient::getInputImageDatatype(void)
GstTensorDataType GstOnnxClient::getInputImageDatatype(void)
{
return inputDatatype;
}

View file

@ -56,7 +56,7 @@ namespace GstOnnxNamespace {
bool hasSession(void);
void setInputImageFormat(GstMlInputImageFormat format);
GstMlInputImageFormat getInputImageFormat(void);
GstTensorType getInputImageDatatype(void);
GstTensorDataType getInputImageDatatype(void);
void setInputImageOffset (float offset);
float getInputImageOffset ();
void setInputImageScale (float offset);
@ -73,7 +73,7 @@ namespace GstOnnxNamespace {
private:
GstElement *debug_parent;
void setInputImageDatatype (GstTensorType datatype);
void setInputImageDatatype (GstTensorDataType datatype);
template < typename T>
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
@ -91,7 +91,7 @@ namespace GstOnnxNamespace {
std::vector < Ort::AllocatedStringPtr > outputNames;
std::vector < GQuark > outputIds;
GstMlInputImageFormat inputImageFormat;
GstTensorType inputDatatype;
GstTensorDataType inputDatatype;
size_t inputDatatypeSize;
bool fixedInputImageSize;
float inputTensorOffset;

View file

@ -15,8 +15,7 @@ endif
if onnxrt_dep.found()
gstonnx = library('gstonnx',
'gstonnx.c',
'decoders/gstobjectdetectorutils.cpp',
'decoders/gstssdobjectdetector.cpp',
'decoders/gstssdobjectdetector.c',
'gstonnxinference.cpp',
'gstonnxclient.cpp',
'tensor/gsttensormeta.c',

View file

@ -25,7 +25,7 @@
#include <gst/gst.h>
/**
* GstTensorType:
* GstTensorDataType:
*
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
@ -42,9 +42,11 @@
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
*
* Describe the type of data contain in the tensor.
*
* Since: 1.24
*/
typedef enum _GstTensorType
typedef enum _GstTensorDataType
{
GST_TENSOR_TYPE_INT4,
GST_TENSOR_TYPE_INT8,
@ -60,17 +62,17 @@ typedef enum _GstTensorType
GST_TENSOR_TYPE_FLOAT32,
GST_TENSOR_TYPE_FLOAT64,
GST_TENSOR_TYPE_BFLOAT16,
} GstTensorType;
} GstTensorDataType;
/**
* GstTensor:
*
* @id unique tensor identifier
* @num_dims number of tensor dimensions
* @dims tensor dimensions
* @type @ref GstTensorType of tensor data
* @data @ref GstBuffer holding tensor data
* @id: semantically identify the contents of the tensor
* @num_dims: number of tensor dimensions
* @dims: tensor dimensions
* @type: #GstTensorDataType of tensor data
* @data: #GstBuffer holding tensor data
*
* Since: 1.24
*/
@ -79,7 +81,7 @@ typedef struct _GstTensor
GQuark id;
gint num_dims;
int64_t *dims;
GstTensorType type;
GstTensorDataType data_type;
GstBuffer *data;
} GstTensor;