mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-01-10 17:35:59 +00:00
onnx: Port SSD detector to C
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
parent
5e1291fd86
commit
3325a10f57
8 changed files with 262 additions and 378 deletions
|
@ -1,228 +0,0 @@
|
|||
/*
|
||||
* GStreamer gstreamer-objectdetectorutils
|
||||
* Copyright (C) 2023 Collabora Ltd
|
||||
*
|
||||
* gstobjectdetectorutils.cpp
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
#include "gstobjectdetectorutils.h"
|
||||
|
||||
#include <gio/gio.h>
|
||||
|
||||
|
||||
char **
|
||||
read_labels (const char * labels_file)
|
||||
{
|
||||
GPtrArray *array;
|
||||
GFile *file = g_file_new_for_path (labels_file);
|
||||
GFileInputStream *file_stream;
|
||||
GDataInputStream *data_stream;
|
||||
GError *error = NULL;
|
||||
gchar *line;
|
||||
|
||||
file_stream = g_file_read (file, NULL, &error);
|
||||
g_object_unref (file);
|
||||
if (!file_stream) {
|
||||
GST_WARNING ("Could not open file %s: %s\n", labels_file,
|
||||
error->message);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
||||
g_object_unref (file_stream);
|
||||
|
||||
array = g_ptr_array_new();
|
||||
|
||||
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
||||
&error)))
|
||||
g_ptr_array_add (array, line);
|
||||
|
||||
g_object_unref (data_stream);
|
||||
|
||||
if (error) {
|
||||
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
||||
g_ptr_array_free (array, TRUE);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (array->len == 0) {
|
||||
g_ptr_array_free (array, TRUE);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
g_ptr_array_add (array, NULL);
|
||||
|
||||
return (char **) g_ptr_array_free (array, FALSE);
|
||||
}
|
||||
|
||||
|
||||
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
|
||||
float _y0, float _width, float _height):
|
||||
label (lbl),
|
||||
score (score),
|
||||
x0 (_x0),
|
||||
y0 (_y0),
|
||||
width (_width),
|
||||
height (_height)
|
||||
{
|
||||
}
|
||||
|
||||
GstMlBoundingBox::GstMlBoundingBox ():
|
||||
GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
|
||||
{
|
||||
}
|
||||
|
||||
namespace GstObjectDetectorUtils
|
||||
{
|
||||
|
||||
GstObjectDetectorUtils::GstObjectDetectorUtils ()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
|
||||
int32_t h, GstTensorMeta * tmeta, gchar **labels,
|
||||
float scoreThreshold)
|
||||
{
|
||||
|
||||
auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
if (classIndex == GST_TENSOR_MISSING_ID) {
|
||||
GST_ERROR ("Missing class tensor id");
|
||||
return std::vector < GstMlBoundingBox > ();
|
||||
}
|
||||
auto type = tmeta->tensor[classIndex].type;
|
||||
return (type == GST_TENSOR_TYPE_FLOAT32) ?
|
||||
doRun < float >(w, h, tmeta, labels, scoreThreshold)
|
||||
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
|
||||
}
|
||||
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
|
||||
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
|
||||
{
|
||||
std::vector < GstMlBoundingBox > boundingBoxes;
|
||||
GstMapInfo map_info[GstObjectDetectorMaxNodes];
|
||||
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
|
||||
gint index;
|
||||
T *numDetections = nullptr, *bboxes = nullptr, *scores =
|
||||
nullptr, *labelIndex = nullptr;
|
||||
|
||||
// number of detections
|
||||
index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
||||
if (index == GST_TENSOR_MISSING_ID) {
|
||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
||||
if (!memory[index]) {
|
||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
||||
GST_WARNING ("Failed to map tensor memory for index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
numDetections = (T *) map_info[index].data;
|
||||
|
||||
// bounding boxes
|
||||
index =
|
||||
gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string(GST_MODEL_OBJECT_DETECTOR_BOXES));
|
||||
if (index == GST_TENSOR_MISSING_ID) {
|
||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
||||
if (!memory[index]) {
|
||||
GST_WARNING ("Failed to map tensor memory for index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
||||
GST_ERROR ("Failed to map GstMemory");
|
||||
goto cleanup;
|
||||
}
|
||||
bboxes = (T *) map_info[index].data;
|
||||
|
||||
// scores
|
||||
index =
|
||||
gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
||||
if (index == GST_TENSOR_MISSING_ID) {
|
||||
GST_ERROR ("Missing scores tensor id");
|
||||
goto cleanup;
|
||||
}
|
||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
||||
if (!memory[index]) {
|
||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
||||
GST_ERROR ("Failed to map GstMemory");
|
||||
goto cleanup;
|
||||
}
|
||||
scores = (T *) map_info[index].data;
|
||||
|
||||
// optional label
|
||||
labelIndex = nullptr;
|
||||
index =
|
||||
gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
if (index != GST_TENSOR_MISSING_ID) {
|
||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
||||
if (!memory[index]) {
|
||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
||||
goto cleanup;
|
||||
}
|
||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
||||
GST_ERROR ("Failed to map GstMemory");
|
||||
goto cleanup;
|
||||
}
|
||||
labelIndex = (T *) map_info[index].data;
|
||||
}
|
||||
|
||||
for (int i = 0; i < numDetections[0]; ++i) {
|
||||
if (scores[i] > scoreThreshold) {
|
||||
std::string label = "";
|
||||
|
||||
if (labels && labelIndex)
|
||||
label = labels[(int)labelIndex[i] - 1];
|
||||
auto score = scores[i];
|
||||
auto y0 = bboxes[i * 4] * h;
|
||||
auto x0 = bboxes[i * 4 + 1] * w;
|
||||
auto bheight = bboxes[i * 4 + 2] * h - y0;
|
||||
auto bwidth = bboxes[i * 4 + 3] * w - x0;
|
||||
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
|
||||
bheight));
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
|
||||
if (memory[i])
|
||||
gst_memory_unmap (memory[i], map_info + i);
|
||||
|
||||
}
|
||||
|
||||
return boundingBoxes;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
/*
|
||||
* GStreamer gstreamer-objectdetectorutils
|
||||
* Copyright (C) 2023 Collabora Ltd
|
||||
*
|
||||
* gstobjectdetectorutils.h
|
||||
*
|
||||
* This library is free software; you can redistribute it and/or
|
||||
* modify it under the terms of the GNU Library General Public
|
||||
* License as published by the Free Software Foundation; either
|
||||
* version 2 of the License, or (at your option) any later version.
|
||||
*
|
||||
* This library is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* Library General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Library General Public
|
||||
* License along with this library; if not, write to the
|
||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
||||
* Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
|
||||
#define __GST_OBJECT_DETECTOR_UTILS_H__
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gstml.h"
|
||||
#include "tensor/gsttensormeta.h"
|
||||
|
||||
char ** read_labels (const char * labels_file);
|
||||
|
||||
/* Object detection tensor id strings */
|
||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
|
||||
|
||||
|
||||
/**
|
||||
* GstMlBoundingBox:
|
||||
*
|
||||
* @label label
|
||||
* @score detection confidence
|
||||
* @x0 top left hand x coordinate
|
||||
* @y0 top left hand y coordinate
|
||||
* @width width
|
||||
* @height height
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
struct GstMlBoundingBox {
|
||||
GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0,
|
||||
float _width, float _height);
|
||||
GstMlBoundingBox();
|
||||
std::string label;
|
||||
float score;
|
||||
float x0;
|
||||
float y0;
|
||||
float width;
|
||||
float height;
|
||||
};
|
||||
|
||||
namespace GstObjectDetectorUtils {
|
||||
const int GstObjectDetectorMaxNodes = 4;
|
||||
class GstObjectDetectorUtils {
|
||||
public:
|
||||
GstObjectDetectorUtils(void);
|
||||
~GstObjectDetectorUtils(void) = default;
|
||||
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
|
||||
GstTensorMeta *tmeta,
|
||||
char **labels,
|
||||
float scoreThreshold);
|
||||
private:
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
doRun(int32_t w, int32_t h,
|
||||
GstTensorMeta *tmeta, char **labels,
|
||||
float scoreThreshold);
|
||||
};
|
||||
}
|
||||
|
||||
#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */
|
|
@ -44,19 +44,23 @@
|
|||
#include "config.h"
|
||||
#endif
|
||||
|
||||
|
||||
#include "gstssdobjectdetector.h"
|
||||
#include "gstobjectdetectorutils.h"
|
||||
|
||||
#include <gio/gio.h>
|
||||
|
||||
#include <gst/gst.h>
|
||||
#include <gst/video/video.h>
|
||||
|
||||
#include <gst/analytics/analytics.h>
|
||||
#include "tensor/gsttensormeta.h"
|
||||
|
||||
/* Object detection tensor id strings */
|
||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
|
||||
|
||||
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
|
||||
#define GST_CAT_DEFAULT ssd_object_detector_debug
|
||||
#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
|
||||
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
|
||||
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
|
||||
|
||||
|
@ -68,7 +72,7 @@ enum
|
|||
PROP_SCORE_THRESHOLD,
|
||||
};
|
||||
|
||||
#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
|
||||
#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */
|
||||
|
||||
static GstStaticPadTemplate gst_ssd_object_detector_src_template =
|
||||
GST_STATIC_PAD_TEMPLATE ("src",
|
||||
|
@ -97,7 +101,8 @@ static gboolean
|
|||
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||
GstCaps * outcaps);
|
||||
|
||||
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);
|
||||
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector,
|
||||
GST_TYPE_BASE_TRANSFORM);
|
||||
|
||||
static void
|
||||
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
||||
|
@ -135,7 +140,8 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
|||
g_param_spec_float ("score-threshold",
|
||||
"Score threshold",
|
||||
"Threshold for deciding when to remove boxes based on score",
|
||||
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
|
||||
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD,
|
||||
(GParamFlags)
|
||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||
|
||||
gst_element_class_set_static_metadata (element_class, "objectdetector",
|
||||
|
@ -155,7 +161,6 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
|||
static void
|
||||
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
|
||||
{
|
||||
self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
|
||||
}
|
||||
|
||||
static void
|
||||
|
@ -164,12 +169,58 @@ gst_ssd_object_detector_finalize (GObject * object)
|
|||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
||||
|
||||
g_free (self->label_file);
|
||||
g_strfreev (self->labels);
|
||||
delete GST_ODUTILS_MEMBER (self);
|
||||
g_clear_pointer (&self->labels, g_array_unref);
|
||||
|
||||
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
||||
}
|
||||
|
||||
static GArray *
|
||||
read_labels (const char *labels_file)
|
||||
{
|
||||
GArray *array;
|
||||
GFile *file = g_file_new_for_path (labels_file);
|
||||
GFileInputStream *file_stream;
|
||||
GDataInputStream *data_stream;
|
||||
GError *error = NULL;
|
||||
gchar *line;
|
||||
|
||||
file_stream = g_file_read (file, NULL, &error);
|
||||
g_object_unref (file);
|
||||
if (!file_stream) {
|
||||
GST_WARNING ("Could not open file %s: %s\n", labels_file, error->message);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
||||
g_object_unref (file_stream);
|
||||
|
||||
array = g_array_new (FALSE, FALSE, sizeof (GQuark));
|
||||
|
||||
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
||||
&error))) {
|
||||
GQuark label = g_quark_from_string (line);
|
||||
g_array_append_val (array, label);
|
||||
g_free (line);
|
||||
}
|
||||
|
||||
g_object_unref (data_stream);
|
||||
|
||||
if (error) {
|
||||
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
||||
g_array_free (array, TRUE);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (array->len == 0) {
|
||||
g_array_free (array, TRUE);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
static void
|
||||
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
||||
const GValue * value, GParamSpec * pspec)
|
||||
|
@ -179,21 +230,21 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
|||
|
||||
switch (prop_id) {
|
||||
case PROP_LABEL_FILE:
|
||||
{
|
||||
gchar **labels;
|
||||
{
|
||||
GArray *labels;
|
||||
|
||||
filename = g_value_get_string (value);
|
||||
labels = read_labels (filename);
|
||||
filename = g_value_get_string (value);
|
||||
labels = read_labels (filename);
|
||||
|
||||
if (labels) {
|
||||
g_free (self->label_file);
|
||||
self->label_file = g_strdup (filename);
|
||||
g_strfreev (self->labels);
|
||||
self->labels = labels;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||
}
|
||||
if (labels) {
|
||||
g_free (self->label_file);
|
||||
self->label_file = g_strdup (filename);
|
||||
g_clear_pointer (&self->labels, g_array_unref);
|
||||
self->labels = labels;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case PROP_SCORE_THRESHOLD:
|
||||
GST_OBJECT_LOCK (self);
|
||||
|
@ -259,7 +310,8 @@ gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
|
|||
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
|
||||
if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
|
||||
if (boxesIndex == GST_TENSOR_MISSING_ID
|
||||
|| scoresIndex == GST_TENSOR_MISSING_ID
|
||||
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
|
||||
continue;
|
||||
|
||||
|
@ -300,13 +352,175 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
|
|||
return GST_FLOW_OK;
|
||||
}
|
||||
|
||||
#define DEFINE_GET_FUNC(TYPE, MAX) \
|
||||
static gboolean \
|
||||
get_ ## TYPE ## _at_index (GstTensor *tensor, GstMapInfo *map, \
|
||||
guint index, TYPE * out) \
|
||||
{ \
|
||||
switch (tensor->type) { \
|
||||
case GST_TENSOR_TYPE_FLOAT32: { \
|
||||
float *f = (float *) map->data; \
|
||||
if (sizeof(*f) * (index + 1) > map->size) \
|
||||
return FALSE; \
|
||||
*out = f[index]; \
|
||||
break; \
|
||||
} \
|
||||
case GST_TENSOR_TYPE_UINT32: { \
|
||||
guint32 *u = (guint32 *) map->data; \
|
||||
if (sizeof(*u) * (index + 1) > map->size) \
|
||||
return FALSE; \
|
||||
*out = u[index]; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
GST_ERROR ("Only float32 and int32 tensors are understood"); \
|
||||
return FALSE; \
|
||||
} \
|
||||
return TRUE; \
|
||||
}
|
||||
|
||||
DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||
DEFINE_GET_FUNC (float, FLOAT_MAX)
|
||||
#undef DEFINE_GET_FUNC
|
||||
static void
|
||||
extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
|
||||
GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
|
||||
{
|
||||
gint classes_index;
|
||||
gint boxes_index;
|
||||
gint scores_index;
|
||||
gint numdetect_index;
|
||||
|
||||
GstMapInfo boxes_map = GST_MAP_INFO_INIT;
|
||||
GstMapInfo numdetect_map = GST_MAP_INFO_INIT;
|
||||
GstMapInfo scores_map = GST_MAP_INFO_INIT;
|
||||
GstMapInfo classes_map = GST_MAP_INFO_INIT;
|
||||
|
||||
guint num_detections = 0;
|
||||
|
||||
classes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||
numdetect_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
||||
scores_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
||||
boxes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
|
||||
|
||||
if (numdetect_index == GST_TENSOR_MISSING_ID
|
||||
|| scores_index == GST_TENSOR_MISSING_ID
|
||||
|| numdetect_index == GST_TENSOR_MISSING_ID) {
|
||||
GST_WARNING ("Missing tensor data expected for SSD model");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensor[numdetect_index].data, &numdetect_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
numdetect_index);
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensor[boxes_index].data, &boxes_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
boxes_index);
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (!gst_buffer_map (tmeta->tensor[scores_index].data, &scores_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
scores_index);
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
if (classes_index != GST_TENSOR_MISSING_ID &&
|
||||
!gst_buffer_map (tmeta->tensor[classes_index].data, &classes_map,
|
||||
GST_MAP_READ)) {
|
||||
GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||
classes_index);
|
||||
}
|
||||
|
||||
|
||||
if (!get_guint32_at_index (&tmeta->tensor[numdetect_index], &numdetect_map,
|
||||
0, &num_detections)) {
|
||||
GST_ERROR_OBJECT (self, "Failed to get the number of detections");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
|
||||
GST_LOG_OBJECT (self, "Model claims %d detections", num_detections);
|
||||
|
||||
for (int i = 0; i < num_detections; i++) {
|
||||
float score;
|
||||
float x, y, bwidth, bheight;
|
||||
gint x_i, y_i, bwidth_i, bheight_i;
|
||||
guint32 bclass;
|
||||
GQuark label = 0;
|
||||
GstAnalyticsODMtd odmtd;
|
||||
|
||||
if (!get_float_at_index (&tmeta->tensor[numdetect_index], &scores_map,
|
||||
i, &score))
|
||||
continue;
|
||||
|
||||
GST_LOG_OBJECT (self, "Detection %u score is %f", i, score);
|
||||
if (score < self->score_threshold)
|
||||
continue;
|
||||
|
||||
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||
i * 4, &y))
|
||||
continue;
|
||||
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||
i * 4 + 1, &x))
|
||||
continue;
|
||||
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||
i * 4 + 2, &bheight))
|
||||
continue;
|
||||
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||
i * 4 + 3, &bwidth))
|
||||
continue;
|
||||
|
||||
if (self->labels && classes_map.memory &&
|
||||
get_guint32_at_index (&tmeta->tensor[classes_index], &classes_map,
|
||||
i, &bclass)) {
|
||||
if (bclass < self->labels->len)
|
||||
label = g_array_index (self->labels, GQuark, bclass);
|
||||
}
|
||||
|
||||
x_i = x * w;
|
||||
y_i = y * h;
|
||||
bheight_i = (bheight * h) - y_i;
|
||||
bwidth_i = (bwidth * w) - x_i;
|
||||
|
||||
if (gst_analytics_relation_meta_add_od_mtd (rmeta, label,
|
||||
x_i, y_i, bwidth_i, bheight_i, score, &odmtd))
|
||||
GST_DEBUG_OBJECT (self,
|
||||
"Object detected with label : %s, score: %f, bound box: %dx%d at (%d,%d)",
|
||||
g_quark_to_string (label), score, bwidth_i, bheight_i, x_i, y_i);
|
||||
else
|
||||
GST_WARNING_OBJECT (self, "Could not add detection to meta");
|
||||
}
|
||||
|
||||
cleanup:
|
||||
|
||||
if (numdetect_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensor[numdetect_index].data, &numdetect_map);
|
||||
if (classes_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensor[classes_index].data, &classes_map);
|
||||
if (scores_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensor[scores_index].data, &scores_map);
|
||||
if (boxes_map.memory)
|
||||
gst_buffer_unmap (tmeta->tensor[boxes_index].data, &boxes_map);
|
||||
}
|
||||
|
||||
|
||||
static gboolean
|
||||
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||
{
|
||||
GstTensorMeta *tmeta = NULL;
|
||||
GstAnalyticsODMtd odmtd;
|
||||
GstAnalyticsRelationMeta *rmeta;
|
||||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
|
||||
GstTensorMeta *tmeta;
|
||||
GstAnalyticsRelationMeta *rmeta;
|
||||
|
||||
// get all tensor metas
|
||||
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
|
||||
|
@ -315,25 +529,11 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
|||
return TRUE;
|
||||
} else {
|
||||
rmeta = gst_buffer_add_analytics_relation_meta (buf);
|
||||
g_return_val_if_fail (rmeta != NULL, FALSE);
|
||||
g_assert (rmeta);
|
||||
}
|
||||
|
||||
std::vector < GstMlBoundingBox > boxes =
|
||||
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
|
||||
self->video_info.height, tmeta, self->labels,
|
||||
self->score_threshold);
|
||||
|
||||
for (auto & b:boxes) {
|
||||
if (gst_analytics_relation_meta_add_od_mtd (rmeta,
|
||||
g_quark_from_string(b.label.c_str ()), b.x0, b.y0, b.width, b.height,
|
||||
b.score, &odmtd)) {
|
||||
GST_DEBUG_OBJECT (self,
|
||||
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f)",
|
||||
b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
|
||||
} else {
|
||||
GST_ERROR_OBJECT (self, "Failed to add object detection analytics-meta");
|
||||
}
|
||||
}
|
||||
extract_bounding_boxes (self, self->video_info.width,
|
||||
self->video_info.height, rmeta, tmeta);
|
||||
|
||||
return TRUE;
|
||||
}
|
|
@ -37,14 +37,11 @@ G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OB
|
|||
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
|
||||
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
|
||||
|
||||
/**
|
||||
/*
|
||||
* GstSsdObjectDetector:
|
||||
*
|
||||
* @label_file label file
|
||||
* @score_threshold score threshold
|
||||
* @confidence_threshold confidence threshold
|
||||
* @iou_threhsold iou threshold
|
||||
* @od_ptr opaque pointer to GstOd object detection implementation
|
||||
*
|
||||
* Since: 1.20
|
||||
*/
|
||||
|
@ -52,11 +49,8 @@ struct _GstSsdObjectDetector
|
|||
{
|
||||
GstBaseTransform basetransform;
|
||||
gchar *label_file;
|
||||
gchar **labels;
|
||||
GArray *labels;
|
||||
gfloat score_threshold;
|
||||
gfloat confidence_threshold;
|
||||
gfloat iou_threshold;
|
||||
gpointer odutils;
|
||||
GstVideoInfo video_info;
|
||||
};
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
return inputImageFormat;
|
||||
}
|
||||
|
||||
void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
|
||||
void GstOnnxClient::setInputImageDatatype(GstTensorDataType datatype)
|
||||
{
|
||||
inputDatatype = datatype;
|
||||
switch (inputDatatype) {
|
||||
|
@ -144,7 +144,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
return inputTensorScale;
|
||||
}
|
||||
|
||||
GstTensorType GstOnnxClient::getInputImageDatatype(void)
|
||||
GstTensorDataType GstOnnxClient::getInputImageDatatype(void)
|
||||
{
|
||||
return inputDatatype;
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ namespace GstOnnxNamespace {
|
|||
bool hasSession(void);
|
||||
void setInputImageFormat(GstMlInputImageFormat format);
|
||||
GstMlInputImageFormat getInputImageFormat(void);
|
||||
GstTensorType getInputImageDatatype(void);
|
||||
GstTensorDataType getInputImageDatatype(void);
|
||||
void setInputImageOffset (float offset);
|
||||
float getInputImageOffset ();
|
||||
void setInputImageScale (float offset);
|
||||
|
@ -73,7 +73,7 @@ namespace GstOnnxNamespace {
|
|||
private:
|
||||
|
||||
GstElement *debug_parent;
|
||||
void setInputImageDatatype (GstTensorType datatype);
|
||||
void setInputImageDatatype (GstTensorDataType datatype);
|
||||
template < typename T>
|
||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||
|
@ -91,7 +91,7 @@ namespace GstOnnxNamespace {
|
|||
std::vector < Ort::AllocatedStringPtr > outputNames;
|
||||
std::vector < GQuark > outputIds;
|
||||
GstMlInputImageFormat inputImageFormat;
|
||||
GstTensorType inputDatatype;
|
||||
GstTensorDataType inputDatatype;
|
||||
size_t inputDatatypeSize;
|
||||
bool fixedInputImageSize;
|
||||
float inputTensorOffset;
|
||||
|
|
|
@ -15,8 +15,7 @@ endif
|
|||
if onnxrt_dep.found()
|
||||
gstonnx = library('gstonnx',
|
||||
'gstonnx.c',
|
||||
'decoders/gstobjectdetectorutils.cpp',
|
||||
'decoders/gstssdobjectdetector.cpp',
|
||||
'decoders/gstssdobjectdetector.c',
|
||||
'gstonnxinference.cpp',
|
||||
'gstonnxclient.cpp',
|
||||
'tensor/gsttensormeta.c',
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include <gst/gst.h>
|
||||
|
||||
/**
|
||||
* GstTensorType:
|
||||
* GstTensorDataType:
|
||||
*
|
||||
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
|
||||
|
@ -42,9 +42,11 @@
|
|||
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
|
||||
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
|
||||
*
|
||||
* Describe the type of data contain in the tensor.
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
typedef enum _GstTensorType
|
||||
typedef enum _GstTensorDataType
|
||||
{
|
||||
GST_TENSOR_TYPE_INT4,
|
||||
GST_TENSOR_TYPE_INT8,
|
||||
|
@ -60,17 +62,17 @@ typedef enum _GstTensorType
|
|||
GST_TENSOR_TYPE_FLOAT32,
|
||||
GST_TENSOR_TYPE_FLOAT64,
|
||||
GST_TENSOR_TYPE_BFLOAT16,
|
||||
} GstTensorType;
|
||||
} GstTensorDataType;
|
||||
|
||||
|
||||
/**
|
||||
* GstTensor:
|
||||
*
|
||||
* @id unique tensor identifier
|
||||
* @num_dims number of tensor dimensions
|
||||
* @dims tensor dimensions
|
||||
* @type @ref GstTensorType of tensor data
|
||||
* @data @ref GstBuffer holding tensor data
|
||||
* @id: semantically identify the contents of the tensor
|
||||
* @num_dims: number of tensor dimensions
|
||||
* @dims: tensor dimensions
|
||||
* @type: #GstTensorDataType of tensor data
|
||||
* @data: #GstBuffer holding tensor data
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
|
@ -79,7 +81,7 @@ typedef struct _GstTensor
|
|||
GQuark id;
|
||||
gint num_dims;
|
||||
int64_t *dims;
|
||||
GstTensorType type;
|
||||
GstTensorDataType data_type;
|
||||
GstBuffer *data;
|
||||
} GstTensor;
|
||||
|
||||
|
|
Loading…
Reference in a new issue