mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-02-17 03:35:21 +00:00
onnx: Port SSD detector to C
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
parent
5e1291fd86
commit
3325a10f57
8 changed files with 262 additions and 378 deletions
|
@ -1,228 +0,0 @@
|
||||||
/*
|
|
||||||
* GStreamer gstreamer-objectdetectorutils
|
|
||||||
* Copyright (C) 2023 Collabora Ltd
|
|
||||||
*
|
|
||||||
* gstobjectdetectorutils.cpp
|
|
||||||
*
|
|
||||||
* This library is free software; you can redistribute it and/or
|
|
||||||
* modify it under the terms of the GNU Library General Public
|
|
||||||
* License as published by the Free Software Foundation; either
|
|
||||||
* version 2 of the License, or (at your option) any later version.
|
|
||||||
*
|
|
||||||
* This library is distributed in the hope that it will be useful,
|
|
||||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
||||||
* Library General Public License for more details.
|
|
||||||
*
|
|
||||||
* You should have received a copy of the GNU Library General Public
|
|
||||||
* License along with this library; if not, write to the
|
|
||||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
|
||||||
* Boston, MA 02110-1301, USA.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "gstobjectdetectorutils.h"
|
|
||||||
|
|
||||||
#include <gio/gio.h>
|
|
||||||
|
|
||||||
|
|
||||||
char **
|
|
||||||
read_labels (const char * labels_file)
|
|
||||||
{
|
|
||||||
GPtrArray *array;
|
|
||||||
GFile *file = g_file_new_for_path (labels_file);
|
|
||||||
GFileInputStream *file_stream;
|
|
||||||
GDataInputStream *data_stream;
|
|
||||||
GError *error = NULL;
|
|
||||||
gchar *line;
|
|
||||||
|
|
||||||
file_stream = g_file_read (file, NULL, &error);
|
|
||||||
g_object_unref (file);
|
|
||||||
if (!file_stream) {
|
|
||||||
GST_WARNING ("Could not open file %s: %s\n", labels_file,
|
|
||||||
error->message);
|
|
||||||
g_clear_error (&error);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
|
||||||
g_object_unref (file_stream);
|
|
||||||
|
|
||||||
array = g_ptr_array_new();
|
|
||||||
|
|
||||||
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
|
||||||
&error)))
|
|
||||||
g_ptr_array_add (array, line);
|
|
||||||
|
|
||||||
g_object_unref (data_stream);
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
|
||||||
g_ptr_array_free (array, TRUE);
|
|
||||||
g_clear_error (&error);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (array->len == 0) {
|
|
||||||
g_ptr_array_free (array, TRUE);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
g_ptr_array_add (array, NULL);
|
|
||||||
|
|
||||||
return (char **) g_ptr_array_free (array, FALSE);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* Builds a bounding box from an explicit label, score and rectangle.
 * Every argument is stored unchanged in the member of the same meaning. */
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
    float _y0, float _width, float _height)
  : label (lbl),
    score (score),
    x0 (_x0), y0 (_y0),
    width (_width), height (_height)
{
}

/* Builds an unlabelled box: empty label, zero score, zero-sized rectangle at
 * the origin. Delegates to the full constructor. */
GstMlBoundingBox::GstMlBoundingBox ()
  : GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
{
}
|
|
||||||
|
|
||||||
namespace GstObjectDetectorUtils
|
|
||||||
{
|
|
||||||
|
|
||||||
GstObjectDetectorUtils::GstObjectDetectorUtils ()
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
|
|
||||||
int32_t h, GstTensorMeta * tmeta, gchar **labels,
|
|
||||||
float scoreThreshold)
|
|
||||||
{
|
|
||||||
|
|
||||||
auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
|
|
||||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
|
||||||
if (classIndex == GST_TENSOR_MISSING_ID) {
|
|
||||||
GST_ERROR ("Missing class tensor id");
|
|
||||||
return std::vector < GstMlBoundingBox > ();
|
|
||||||
}
|
|
||||||
auto type = tmeta->tensor[classIndex].type;
|
|
||||||
return (type == GST_TENSOR_TYPE_FLOAT32) ?
|
|
||||||
doRun < float >(w, h, tmeta, labels, scoreThreshold)
|
|
||||||
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
|
|
||||||
}
|
|
||||||
|
|
||||||
template < typename T > std::vector < GstMlBoundingBox >
|
|
||||||
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
|
|
||||||
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
|
|
||||||
{
|
|
||||||
std::vector < GstMlBoundingBox > boundingBoxes;
|
|
||||||
GstMapInfo map_info[GstObjectDetectorMaxNodes];
|
|
||||||
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
|
|
||||||
gint index;
|
|
||||||
T *numDetections = nullptr, *bboxes = nullptr, *scores =
|
|
||||||
nullptr, *labelIndex = nullptr;
|
|
||||||
|
|
||||||
// number of detections
|
|
||||||
index = gst_tensor_meta_get_index_from_id (tmeta,
|
|
||||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
|
||||||
if (index == GST_TENSOR_MISSING_ID) {
|
|
||||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
|
||||||
if (!memory[index]) {
|
|
||||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
|
||||||
GST_WARNING ("Failed to map tensor memory for index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
numDetections = (T *) map_info[index].data;
|
|
||||||
|
|
||||||
// bounding boxes
|
|
||||||
index =
|
|
||||||
gst_tensor_meta_get_index_from_id (tmeta,
|
|
||||||
g_quark_from_static_string(GST_MODEL_OBJECT_DETECTOR_BOXES));
|
|
||||||
if (index == GST_TENSOR_MISSING_ID) {
|
|
||||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
|
||||||
if (!memory[index]) {
|
|
||||||
GST_WARNING ("Failed to map tensor memory for index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
|
||||||
GST_ERROR ("Failed to map GstMemory");
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
bboxes = (T *) map_info[index].data;
|
|
||||||
|
|
||||||
// scores
|
|
||||||
index =
|
|
||||||
gst_tensor_meta_get_index_from_id (tmeta,
|
|
||||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
|
||||||
if (index == GST_TENSOR_MISSING_ID) {
|
|
||||||
GST_ERROR ("Missing scores tensor id");
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
|
||||||
if (!memory[index]) {
|
|
||||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
|
||||||
GST_ERROR ("Failed to map GstMemory");
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
scores = (T *) map_info[index].data;
|
|
||||||
|
|
||||||
// optional label
|
|
||||||
labelIndex = nullptr;
|
|
||||||
index =
|
|
||||||
gst_tensor_meta_get_index_from_id (tmeta,
|
|
||||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
|
||||||
if (index != GST_TENSOR_MISSING_ID) {
|
|
||||||
memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0);
|
|
||||||
if (!memory[index]) {
|
|
||||||
GST_WARNING ("Missing tensor data for tensor index %d", index);
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) {
|
|
||||||
GST_ERROR ("Failed to map GstMemory");
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
labelIndex = (T *) map_info[index].data;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < numDetections[0]; ++i) {
|
|
||||||
if (scores[i] > scoreThreshold) {
|
|
||||||
std::string label = "";
|
|
||||||
|
|
||||||
if (labels && labelIndex)
|
|
||||||
label = labels[(int)labelIndex[i] - 1];
|
|
||||||
auto score = scores[i];
|
|
||||||
auto y0 = bboxes[i * 4] * h;
|
|
||||||
auto x0 = bboxes[i * 4 + 1] * w;
|
|
||||||
auto bheight = bboxes[i * 4 + 2] * h - y0;
|
|
||||||
auto bwidth = bboxes[i * 4 + 3] * w - x0;
|
|
||||||
boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
|
|
||||||
bheight));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanup:
|
|
||||||
for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) {
|
|
||||||
if (memory[i])
|
|
||||||
gst_memory_unmap (memory[i], map_info + i);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return boundingBoxes;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,83 +0,0 @@
|
||||||
/*
|
|
||||||
* GStreamer gstreamer-objectdetectorutils
|
|
||||||
* Copyright (C) 2023 Collabora Ltd
|
|
||||||
*
|
|
||||||
* gstobjectdetectorutils.h
|
|
||||||
*
|
|
||||||
* This library is free software; you can redistribute it and/or
|
|
||||||
* modify it under the terms of the GNU Library General Public
|
|
||||||
* License as published by the Free Software Foundation; either
|
|
||||||
* version 2 of the License, or (at your option) any later version.
|
|
||||||
*
|
|
||||||
* This library is distributed in the hope that it will be useful,
|
|
||||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
||||||
* Library General Public License for more details.
|
|
||||||
*
|
|
||||||
* You should have received a copy of the GNU Library General Public
|
|
||||||
* License along with this library; if not, write to the
|
|
||||||
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
|
|
||||||
* Boston, MA 02110-1301, USA.
|
|
||||||
*/
|
|
||||||
#ifndef __GST_OBJECT_DETECTOR_UTILS_H__
|
|
||||||
#define __GST_OBJECT_DETECTOR_UTILS_H__
|
|
||||||
|
|
||||||
#include <gst/gst.h>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "gstml.h"
|
|
||||||
#include "tensor/gsttensormeta.h"
|
|
||||||
|
|
||||||
char ** read_labels (const char * labels_file);
|
|
||||||
|
|
||||||
/* Object detection tensor id strings */
|
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
|
|
||||||
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * GstMlBoundingBox:
 * @label: label
 * @score: detection confidence
 * @x0: top left hand x coordinate
 * @y0: top left hand y coordinate
 * @width: width
 * @height: height
 *
 * Value type describing one detection: a label, a confidence and a rectangle.
 *
 * Since: 1.20
 */
struct GstMlBoundingBox {
  /* Stores every argument in the member of the same meaning. */
  GstMlBoundingBox (std::string lbl,
                    float score,
                    float _x0,
                    float _y0,
                    float _width,
                    float _height);

  /* Empty label, zero score and a zero-sized rectangle at the origin. */
  GstMlBoundingBox ();

  std::string label;
  float score;
  float x0;
  float y0;
  float width;
  float height;
};
|
|
||||||
|
|
||||||
namespace GstObjectDetectorUtils {
|
|
||||||
const int GstObjectDetectorMaxNodes = 4;
|
|
||||||
class GstObjectDetectorUtils {
|
|
||||||
public:
|
|
||||||
GstObjectDetectorUtils(void);
|
|
||||||
~GstObjectDetectorUtils(void) = default;
|
|
||||||
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
|
|
||||||
GstTensorMeta *tmeta,
|
|
||||||
char **labels,
|
|
||||||
float scoreThreshold);
|
|
||||||
private:
|
|
||||||
template < typename T > std::vector < GstMlBoundingBox >
|
|
||||||
doRun(int32_t w, int32_t h,
|
|
||||||
GstTensorMeta *tmeta, char **labels,
|
|
||||||
float scoreThreshold);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */
|
|
|
@ -44,19 +44,23 @@
|
||||||
#include "config.h"
|
#include "config.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#include "gstssdobjectdetector.h"
|
#include "gstssdobjectdetector.h"
|
||||||
#include "gstobjectdetectorutils.h"
|
|
||||||
|
#include <gio/gio.h>
|
||||||
|
|
||||||
#include <gst/gst.h>
|
#include <gst/gst.h>
|
||||||
#include <gst/video/video.h>
|
#include <gst/video/video.h>
|
||||||
|
|
||||||
#include <gst/analytics/analytics.h>
|
#include <gst/analytics/analytics.h>
|
||||||
#include "tensor/gsttensormeta.h"
|
#include "tensor/gsttensormeta.h"
|
||||||
|
|
||||||
|
/* Object detection tensor id strings */
|
||||||
|
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||||
|
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||||
|
#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections"
|
||||||
|
#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes"
|
||||||
|
|
||||||
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
|
GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug);
|
||||||
#define GST_CAT_DEFAULT ssd_object_detector_debug
|
#define GST_CAT_DEFAULT ssd_object_detector_debug
|
||||||
#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils))
|
|
||||||
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
|
GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector",
|
||||||
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
|
GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR);
|
||||||
|
|
||||||
|
@ -97,7 +101,8 @@ static gboolean
|
||||||
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps,
|
||||||
GstCaps * outcaps);
|
GstCaps * outcaps);
|
||||||
|
|
||||||
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM);
|
G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector,
|
||||||
|
GST_TYPE_BASE_TRANSFORM);
|
||||||
|
|
||||||
static void
|
static void
|
||||||
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
||||||
|
@ -135,7 +140,8 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
||||||
g_param_spec_float ("score-threshold",
|
g_param_spec_float ("score-threshold",
|
||||||
"Score threshold",
|
"Score threshold",
|
||||||
"Threshold for deciding when to remove boxes based on score",
|
"Threshold for deciding when to remove boxes based on score",
|
||||||
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags)
|
0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD,
|
||||||
|
(GParamFlags)
|
||||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
gst_element_class_set_static_metadata (element_class, "objectdetector",
|
gst_element_class_set_static_metadata (element_class, "objectdetector",
|
||||||
|
@ -155,7 +161,6 @@ gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass)
|
||||||
static void
|
static void
|
||||||
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
|
gst_ssd_object_detector_init (GstSsdObjectDetector * self)
|
||||||
{
|
{
|
||||||
self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils ();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
|
@ -164,12 +169,58 @@ gst_ssd_object_detector_finalize (GObject * object)
|
||||||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
||||||
|
|
||||||
g_free (self->label_file);
|
g_free (self->label_file);
|
||||||
g_strfreev (self->labels);
|
g_clear_pointer (&self->labels, g_array_unref);
|
||||||
delete GST_ODUTILS_MEMBER (self);
|
|
||||||
|
|
||||||
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static GArray *
|
||||||
|
read_labels (const char *labels_file)
|
||||||
|
{
|
||||||
|
GArray *array;
|
||||||
|
GFile *file = g_file_new_for_path (labels_file);
|
||||||
|
GFileInputStream *file_stream;
|
||||||
|
GDataInputStream *data_stream;
|
||||||
|
GError *error = NULL;
|
||||||
|
gchar *line;
|
||||||
|
|
||||||
|
file_stream = g_file_read (file, NULL, &error);
|
||||||
|
g_object_unref (file);
|
||||||
|
if (!file_stream) {
|
||||||
|
GST_WARNING ("Could not open file %s: %s\n", labels_file, error->message);
|
||||||
|
g_clear_error (&error);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
||||||
|
g_object_unref (file_stream);
|
||||||
|
|
||||||
|
array = g_array_new (FALSE, FALSE, sizeof (GQuark));
|
||||||
|
|
||||||
|
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
||||||
|
&error))) {
|
||||||
|
GQuark label = g_quark_from_string (line);
|
||||||
|
g_array_append_val (array, label);
|
||||||
|
g_free (line);
|
||||||
|
}
|
||||||
|
|
||||||
|
g_object_unref (data_stream);
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
||||||
|
g_array_free (array, TRUE);
|
||||||
|
g_clear_error (&error);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (array->len == 0) {
|
||||||
|
g_array_free (array, TRUE);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
||||||
const GValue * value, GParamSpec * pspec)
|
const GValue * value, GParamSpec * pspec)
|
||||||
|
@ -180,7 +231,7 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
||||||
switch (prop_id) {
|
switch (prop_id) {
|
||||||
case PROP_LABEL_FILE:
|
case PROP_LABEL_FILE:
|
||||||
{
|
{
|
||||||
gchar **labels;
|
GArray *labels;
|
||||||
|
|
||||||
filename = g_value_get_string (value);
|
filename = g_value_get_string (value);
|
||||||
labels = read_labels (filename);
|
labels = read_labels (filename);
|
||||||
|
@ -188,7 +239,7 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
||||||
if (labels) {
|
if (labels) {
|
||||||
g_free (self->label_file);
|
g_free (self->label_file);
|
||||||
self->label_file = g_strdup (filename);
|
self->label_file = g_strdup (filename);
|
||||||
g_strfreev (self->labels);
|
g_clear_pointer (&self->labels, g_array_unref);
|
||||||
self->labels = labels;
|
self->labels = labels;
|
||||||
} else {
|
} else {
|
||||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||||
|
@ -259,7 +310,8 @@ gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector,
|
||||||
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta,
|
||||||
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||||
|
|
||||||
if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID
|
if (boxesIndex == GST_TENSOR_MISSING_ID
|
||||||
|
|| scoresIndex == GST_TENSOR_MISSING_ID
|
||||||
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
|
|| numDetectionsIndex == GST_TENSOR_MISSING_ID)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
@ -300,13 +352,175 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
|
||||||
return GST_FLOW_OK;
|
return GST_FLOW_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DEFINE_GET_FUNC(TYPE, MAX) \
|
||||||
|
static gboolean \
|
||||||
|
get_ ## TYPE ## _at_index (GstTensor *tensor, GstMapInfo *map, \
|
||||||
|
guint index, TYPE * out) \
|
||||||
|
{ \
|
||||||
|
switch (tensor->type) { \
|
||||||
|
case GST_TENSOR_TYPE_FLOAT32: { \
|
||||||
|
float *f = (float *) map->data; \
|
||||||
|
if (sizeof(*f) * (index + 1) > map->size) \
|
||||||
|
return FALSE; \
|
||||||
|
*out = f[index]; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case GST_TENSOR_TYPE_UINT32: { \
|
||||||
|
guint32 *u = (guint32 *) map->data; \
|
||||||
|
if (sizeof(*u) * (index + 1) > map->size) \
|
||||||
|
return FALSE; \
|
||||||
|
*out = u[index]; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
GST_ERROR ("Only float32 and int32 tensors are understood"); \
|
||||||
|
return FALSE; \
|
||||||
|
} \
|
||||||
|
return TRUE; \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_GET_FUNC (guint32, UINT32_MAX)
|
||||||
|
DEFINE_GET_FUNC (float, FLOAT_MAX)
|
||||||
|
#undef DEFINE_GET_FUNC
|
||||||
|
static void
|
||||||
|
extract_bounding_boxes (GstSsdObjectDetector * self, gsize w, gsize h,
|
||||||
|
GstAnalyticsRelationMeta * rmeta, GstTensorMeta * tmeta)
|
||||||
|
{
|
||||||
|
gint classes_index;
|
||||||
|
gint boxes_index;
|
||||||
|
gint scores_index;
|
||||||
|
gint numdetect_index;
|
||||||
|
|
||||||
|
GstMapInfo boxes_map = GST_MAP_INFO_INIT;
|
||||||
|
GstMapInfo numdetect_map = GST_MAP_INFO_INIT;
|
||||||
|
GstMapInfo scores_map = GST_MAP_INFO_INIT;
|
||||||
|
GstMapInfo classes_map = GST_MAP_INFO_INIT;
|
||||||
|
|
||||||
|
guint num_detections = 0;
|
||||||
|
|
||||||
|
classes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_CLASSES));
|
||||||
|
numdetect_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS));
|
||||||
|
scores_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_SCORES));
|
||||||
|
boxes_index = gst_tensor_meta_get_index_from_id (tmeta,
|
||||||
|
g_quark_from_static_string (GST_MODEL_OBJECT_DETECTOR_BOXES));
|
||||||
|
|
||||||
|
if (numdetect_index == GST_TENSOR_MISSING_ID
|
||||||
|
|| scores_index == GST_TENSOR_MISSING_ID
|
||||||
|
|| numdetect_index == GST_TENSOR_MISSING_ID) {
|
||||||
|
GST_WARNING ("Missing tensor data expected for SSD model");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!gst_buffer_map (tmeta->tensor[numdetect_index].data, &numdetect_map,
|
||||||
|
GST_MAP_READ)) {
|
||||||
|
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||||
|
numdetect_index);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!gst_buffer_map (tmeta->tensor[boxes_index].data, &boxes_map,
|
||||||
|
GST_MAP_READ)) {
|
||||||
|
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||||
|
boxes_index);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!gst_buffer_map (tmeta->tensor[scores_index].data, &scores_map,
|
||||||
|
GST_MAP_READ)) {
|
||||||
|
GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||||
|
scores_index);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (classes_index != GST_TENSOR_MISSING_ID &&
|
||||||
|
!gst_buffer_map (tmeta->tensor[classes_index].data, &classes_map,
|
||||||
|
GST_MAP_READ)) {
|
||||||
|
GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d",
|
||||||
|
classes_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (!get_guint32_at_index (&tmeta->tensor[numdetect_index], &numdetect_map,
|
||||||
|
0, &num_detections)) {
|
||||||
|
GST_ERROR_OBJECT (self, "Failed to get the number of detections");
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
GST_LOG_OBJECT (self, "Model claims %d detections", num_detections);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_detections; i++) {
|
||||||
|
float score;
|
||||||
|
float x, y, bwidth, bheight;
|
||||||
|
gint x_i, y_i, bwidth_i, bheight_i;
|
||||||
|
guint32 bclass;
|
||||||
|
GQuark label = 0;
|
||||||
|
GstAnalyticsODMtd odmtd;
|
||||||
|
|
||||||
|
if (!get_float_at_index (&tmeta->tensor[numdetect_index], &scores_map,
|
||||||
|
i, &score))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
GST_LOG_OBJECT (self, "Detection %u score is %f", i, score);
|
||||||
|
if (score < self->score_threshold)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||||
|
i * 4, &y))
|
||||||
|
continue;
|
||||||
|
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||||
|
i * 4 + 1, &x))
|
||||||
|
continue;
|
||||||
|
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||||
|
i * 4 + 2, &bheight))
|
||||||
|
continue;
|
||||||
|
if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map,
|
||||||
|
i * 4 + 3, &bwidth))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (self->labels && classes_map.memory &&
|
||||||
|
get_guint32_at_index (&tmeta->tensor[classes_index], &classes_map,
|
||||||
|
i, &bclass)) {
|
||||||
|
if (bclass < self->labels->len)
|
||||||
|
label = g_array_index (self->labels, GQuark, bclass);
|
||||||
|
}
|
||||||
|
|
||||||
|
x_i = x * w;
|
||||||
|
y_i = y * h;
|
||||||
|
bheight_i = (bheight * h) - y_i;
|
||||||
|
bwidth_i = (bwidth * w) - x_i;
|
||||||
|
|
||||||
|
if (gst_analytics_relation_meta_add_od_mtd (rmeta, label,
|
||||||
|
x_i, y_i, bwidth_i, bheight_i, score, &odmtd))
|
||||||
|
GST_DEBUG_OBJECT (self,
|
||||||
|
"Object detected with label : %s, score: %f, bound box: %dx%d at (%d,%d)",
|
||||||
|
g_quark_to_string (label), score, bwidth_i, bheight_i, x_i, y_i);
|
||||||
|
else
|
||||||
|
GST_WARNING_OBJECT (self, "Could not add detection to meta");
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
|
||||||
|
if (numdetect_map.memory)
|
||||||
|
gst_buffer_unmap (tmeta->tensor[numdetect_index].data, &numdetect_map);
|
||||||
|
if (classes_map.memory)
|
||||||
|
gst_buffer_unmap (tmeta->tensor[classes_index].data, &classes_map);
|
||||||
|
if (scores_map.memory)
|
||||||
|
gst_buffer_unmap (tmeta->tensor[scores_index].data, &scores_map);
|
||||||
|
if (boxes_map.memory)
|
||||||
|
gst_buffer_unmap (tmeta->tensor[boxes_index].data, &boxes_map);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static gboolean
|
static gboolean
|
||||||
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||||
{
|
{
|
||||||
GstTensorMeta *tmeta = NULL;
|
|
||||||
GstAnalyticsODMtd odmtd;
|
|
||||||
GstAnalyticsRelationMeta *rmeta;
|
|
||||||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
|
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans);
|
||||||
|
GstTensorMeta *tmeta;
|
||||||
|
GstAnalyticsRelationMeta *rmeta;
|
||||||
|
|
||||||
// get all tensor metas
|
// get all tensor metas
|
||||||
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
|
tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf);
|
||||||
|
@ -315,25 +529,11 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||||
return TRUE;
|
return TRUE;
|
||||||
} else {
|
} else {
|
||||||
rmeta = gst_buffer_add_analytics_relation_meta (buf);
|
rmeta = gst_buffer_add_analytics_relation_meta (buf);
|
||||||
g_return_val_if_fail (rmeta != NULL, FALSE);
|
g_assert (rmeta);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector < GstMlBoundingBox > boxes =
|
extract_bounding_boxes (self, self->video_info.width,
|
||||||
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
|
self->video_info.height, rmeta, tmeta);
|
||||||
self->video_info.height, tmeta, self->labels,
|
|
||||||
self->score_threshold);
|
|
||||||
|
|
||||||
for (auto & b:boxes) {
|
|
||||||
if (gst_analytics_relation_meta_add_od_mtd (rmeta,
|
|
||||||
g_quark_from_string(b.label.c_str ()), b.x0, b.y0, b.width, b.height,
|
|
||||||
b.score, &odmtd)) {
|
|
||||||
GST_DEBUG_OBJECT (self,
|
|
||||||
"Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f)",
|
|
||||||
b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height);
|
|
||||||
} else {
|
|
||||||
GST_ERROR_OBJECT (self, "Failed to add object detection analytics-meta");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
|
@ -37,14 +37,11 @@ G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OB
|
||||||
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
|
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label"
|
||||||
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
|
#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score"
|
||||||
|
|
||||||
/**
|
/*
|
||||||
* GstSsdObjectDetector:
|
* GstSsdObjectDetector:
|
||||||
*
|
*
|
||||||
* @label_file label file
|
* @label_file label file
|
||||||
* @score_threshold score threshold
|
* @score_threshold score threshold
|
||||||
* @confidence_threshold confidence threshold
|
|
||||||
* @iou_threhsold iou threshold
|
|
||||||
* @od_ptr opaque pointer to GstOd object detection implementation
|
|
||||||
*
|
*
|
||||||
* Since: 1.20
|
* Since: 1.20
|
||||||
*/
|
*/
|
||||||
|
@ -52,11 +49,8 @@ struct _GstSsdObjectDetector
|
||||||
{
|
{
|
||||||
GstBaseTransform basetransform;
|
GstBaseTransform basetransform;
|
||||||
gchar *label_file;
|
gchar *label_file;
|
||||||
gchar **labels;
|
GArray *labels;
|
||||||
gfloat score_threshold;
|
gfloat score_threshold;
|
||||||
gfloat confidence_threshold;
|
|
||||||
gfloat iou_threshold;
|
|
||||||
gpointer odutils;
|
|
||||||
GstVideoInfo video_info;
|
GstVideoInfo video_info;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||||
return inputImageFormat;
|
return inputImageFormat;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
|
void GstOnnxClient::setInputImageDatatype(GstTensorDataType datatype)
|
||||||
{
|
{
|
||||||
inputDatatype = datatype;
|
inputDatatype = datatype;
|
||||||
switch (inputDatatype) {
|
switch (inputDatatype) {
|
||||||
|
@ -144,7 +144,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
||||||
return inputTensorScale;
|
return inputTensorScale;
|
||||||
}
|
}
|
||||||
|
|
||||||
GstTensorType GstOnnxClient::getInputImageDatatype(void)
|
GstTensorDataType GstOnnxClient::getInputImageDatatype(void)
|
||||||
{
|
{
|
||||||
return inputDatatype;
|
return inputDatatype;
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ namespace GstOnnxNamespace {
|
||||||
bool hasSession(void);
|
bool hasSession(void);
|
||||||
void setInputImageFormat(GstMlInputImageFormat format);
|
void setInputImageFormat(GstMlInputImageFormat format);
|
||||||
GstMlInputImageFormat getInputImageFormat(void);
|
GstMlInputImageFormat getInputImageFormat(void);
|
||||||
GstTensorType getInputImageDatatype(void);
|
GstTensorDataType getInputImageDatatype(void);
|
||||||
void setInputImageOffset (float offset);
|
void setInputImageOffset (float offset);
|
||||||
float getInputImageOffset ();
|
float getInputImageOffset ();
|
||||||
void setInputImageScale (float offset);
|
void setInputImageScale (float offset);
|
||||||
|
@ -73,7 +73,7 @@ namespace GstOnnxNamespace {
|
||||||
private:
|
private:
|
||||||
|
|
||||||
GstElement *debug_parent;
|
GstElement *debug_parent;
|
||||||
void setInputImageDatatype (GstTensorType datatype);
|
void setInputImageDatatype (GstTensorDataType datatype);
|
||||||
template < typename T>
|
template < typename T>
|
||||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||||
|
@ -91,7 +91,7 @@ namespace GstOnnxNamespace {
|
||||||
std::vector < Ort::AllocatedStringPtr > outputNames;
|
std::vector < Ort::AllocatedStringPtr > outputNames;
|
||||||
std::vector < GQuark > outputIds;
|
std::vector < GQuark > outputIds;
|
||||||
GstMlInputImageFormat inputImageFormat;
|
GstMlInputImageFormat inputImageFormat;
|
||||||
GstTensorType inputDatatype;
|
GstTensorDataType inputDatatype;
|
||||||
size_t inputDatatypeSize;
|
size_t inputDatatypeSize;
|
||||||
bool fixedInputImageSize;
|
bool fixedInputImageSize;
|
||||||
float inputTensorOffset;
|
float inputTensorOffset;
|
||||||
|
|
|
@ -15,8 +15,7 @@ endif
|
||||||
if onnxrt_dep.found()
|
if onnxrt_dep.found()
|
||||||
gstonnx = library('gstonnx',
|
gstonnx = library('gstonnx',
|
||||||
'gstonnx.c',
|
'gstonnx.c',
|
||||||
'decoders/gstobjectdetectorutils.cpp',
|
'decoders/gstssdobjectdetector.c',
|
||||||
'decoders/gstssdobjectdetector.cpp',
|
|
||||||
'gstonnxinference.cpp',
|
'gstonnxinference.cpp',
|
||||||
'gstonnxclient.cpp',
|
'gstonnxclient.cpp',
|
||||||
'tensor/gsttensormeta.c',
|
'tensor/gsttensormeta.c',
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include <gst/gst.h>
|
#include <gst/gst.h>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* GstTensorType:
|
* GstTensorDataType:
|
||||||
*
|
*
|
||||||
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
|
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
|
||||||
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
|
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
|
||||||
|
@ -42,9 +42,11 @@
|
||||||
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
|
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
|
||||||
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
|
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
|
||||||
*
|
*
|
||||||
|
* Describes the type of data contained in the tensor.
|
||||||
|
*
|
||||||
* Since: 1.24
|
* Since: 1.24
|
||||||
*/
|
*/
|
||||||
typedef enum _GstTensorType
|
typedef enum _GstTensorDataType
|
||||||
{
|
{
|
||||||
GST_TENSOR_TYPE_INT4,
|
GST_TENSOR_TYPE_INT4,
|
||||||
GST_TENSOR_TYPE_INT8,
|
GST_TENSOR_TYPE_INT8,
|
||||||
|
@ -60,17 +62,17 @@ typedef enum _GstTensorType
|
||||||
GST_TENSOR_TYPE_FLOAT32,
|
GST_TENSOR_TYPE_FLOAT32,
|
||||||
GST_TENSOR_TYPE_FLOAT64,
|
GST_TENSOR_TYPE_FLOAT64,
|
||||||
GST_TENSOR_TYPE_BFLOAT16,
|
GST_TENSOR_TYPE_BFLOAT16,
|
||||||
} GstTensorType;
|
} GstTensorDataType;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* GstTensor:
|
* GstTensor:
|
||||||
*
|
*
|
||||||
* @id unique tensor identifier
|
* @id: semantically identify the contents of the tensor
|
||||||
* @num_dims number of tensor dimensions
|
* @num_dims: number of tensor dimensions
|
||||||
* @dims tensor dimensions
|
* @dims: tensor dimensions
|
||||||
* @type @ref GstTensorType of tensor data
|
* @type: #GstTensorDataType of tensor data
|
||||||
* @data @ref GstBuffer holding tensor data
|
* @data: #GstBuffer holding tensor data
|
||||||
*
|
*
|
||||||
* Since: 1.24
|
* Since: 1.24
|
||||||
*/
|
*/
|
||||||
|
@ -79,7 +81,7 @@ typedef struct _GstTensor
|
||||||
GQuark id;
|
GQuark id;
|
||||||
gint num_dims;
|
gint num_dims;
|
||||||
int64_t *dims;
|
int64_t *dims;
|
||||||
GstTensorType type;
|
GstTensorDataType data_type;
|
||||||
GstBuffer *data;
|
GstBuffer *data;
|
||||||
} GstTensor;
|
} GstTensor;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue