mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-01-17 12:55:53 +00:00
onnx: Only read labels file one and use GIO
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
parent
13de5160be
commit
5e1291fd86
5 changed files with 78 additions and 35 deletions
|
@ -22,7 +22,56 @@
|
|||
|
||||
#include "gstobjectdetectorutils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <gio/gio.h>
|
||||
|
||||
|
||||
char **
|
||||
read_labels (const char * labels_file)
|
||||
{
|
||||
GPtrArray *array;
|
||||
GFile *file = g_file_new_for_path (labels_file);
|
||||
GFileInputStream *file_stream;
|
||||
GDataInputStream *data_stream;
|
||||
GError *error = NULL;
|
||||
gchar *line;
|
||||
|
||||
file_stream = g_file_read (file, NULL, &error);
|
||||
g_object_unref (file);
|
||||
if (!file_stream) {
|
||||
GST_WARNING ("Could not open file %s: %s\n", labels_file,
|
||||
error->message);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
|
||||
g_object_unref (file_stream);
|
||||
|
||||
array = g_ptr_array_new();
|
||||
|
||||
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
|
||||
&error)))
|
||||
g_ptr_array_add (array, line);
|
||||
|
||||
g_object_unref (data_stream);
|
||||
|
||||
if (error) {
|
||||
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
|
||||
g_ptr_array_free (array, TRUE);
|
||||
g_clear_error (&error);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (array->len == 0) {
|
||||
g_ptr_array_free (array, TRUE);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
g_ptr_array_add (array, NULL);
|
||||
|
||||
return (char **) g_ptr_array_free (array, FALSE);
|
||||
}
|
||||
|
||||
|
||||
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
|
||||
float _y0, float _width, float _height):
|
||||
|
@ -47,20 +96,9 @@ namespace GstObjectDetectorUtils
|
|||
{
|
||||
}
|
||||
|
||||
std::vector < std::string >
|
||||
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
|
||||
{
|
||||
std::vector < std::string > labels;
|
||||
std::string line;
|
||||
std::ifstream fp (labelsFile);
|
||||
while (std::getline (fp, line))
|
||||
labels.push_back (line);
|
||||
|
||||
return labels;
|
||||
}
|
||||
|
||||
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
|
||||
int32_t h, GstTensorMeta * tmeta, std::string labelPath,
|
||||
int32_t h, GstTensorMeta * tmeta, gchar **labels,
|
||||
float scoreThreshold)
|
||||
{
|
||||
|
||||
|
@ -72,18 +110,17 @@ namespace GstObjectDetectorUtils
|
|||
}
|
||||
auto type = tmeta->tensor[classIndex].type;
|
||||
return (type == GST_TENSOR_TYPE_FLOAT32) ?
|
||||
doRun < float >(w, h, tmeta, labelPath, scoreThreshold)
|
||||
: doRun < int >(w, h, tmeta, labelPath, scoreThreshold);
|
||||
doRun < float >(w, h, tmeta, labels, scoreThreshold)
|
||||
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
|
||||
}
|
||||
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
|
||||
GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold)
|
||||
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
|
||||
{
|
||||
std::vector < GstMlBoundingBox > boundingBoxes;
|
||||
GstMapInfo map_info[GstObjectDetectorMaxNodes];
|
||||
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
|
||||
std::vector < std::string > labels;
|
||||
gint index;
|
||||
T *numDetections = nullptr, *bboxes = nullptr, *scores =
|
||||
nullptr, *labelIndex = nullptr;
|
||||
|
@ -162,15 +199,12 @@ namespace GstObjectDetectorUtils
|
|||
labelIndex = (T *) map_info[index].data;
|
||||
}
|
||||
|
||||
if (!labelPath.empty ())
|
||||
labels = ReadLabels (labelPath);
|
||||
|
||||
for (int i = 0; i < numDetections[0]; ++i) {
|
||||
if (scores[i] > scoreThreshold) {
|
||||
std::string label = "";
|
||||
|
||||
if (labelIndex && !labels.empty ())
|
||||
label = labels[labelIndex[i] - 1];
|
||||
if (labels && labelIndex)
|
||||
label = labels[(int)labelIndex[i] - 1];
|
||||
auto score = scores[i];
|
||||
auto y0 = bboxes[i * 4] * h;
|
||||
auto x0 = bboxes[i * 4 + 1] * w;
|
||||
|
|
|
@ -29,6 +29,8 @@
|
|||
#include "gstml.h"
|
||||
#include "tensor/gsttensormeta.h"
|
||||
|
||||
char ** read_labels (const char * labels_file);
|
||||
|
||||
/* Object detection tensor id strings */
|
||||
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
|
||||
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
|
||||
|
@ -68,14 +70,13 @@ namespace GstObjectDetectorUtils {
|
|||
~GstObjectDetectorUtils(void) = default;
|
||||
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
|
||||
GstTensorMeta *tmeta,
|
||||
std::string labelPath,
|
||||
char **labels,
|
||||
float scoreThreshold);
|
||||
private:
|
||||
template < typename T > std::vector < GstMlBoundingBox >
|
||||
doRun(int32_t w, int32_t h,
|
||||
GstTensorMeta *tmeta, std::string labelPath,
|
||||
GstTensorMeta *tmeta, char **labels,
|
||||
float scoreThreshold);
|
||||
std::vector < std::string > ReadLabels(const std::string & labelsFile);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -164,6 +164,7 @@ gst_ssd_object_detector_finalize (GObject * object)
|
|||
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
|
||||
|
||||
g_free (self->label_file);
|
||||
g_strfreev (self->labels);
|
||||
delete GST_ODUTILS_MEMBER (self);
|
||||
|
||||
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
|
||||
|
@ -178,14 +179,20 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
|
|||
|
||||
switch (prop_id) {
|
||||
case PROP_LABEL_FILE:
|
||||
filename = g_value_get_string (value);
|
||||
if (filename
|
||||
&& g_file_test (filename,
|
||||
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
|
||||
g_free (self->label_file);
|
||||
self->label_file = g_strdup (filename);
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||
{
|
||||
gchar **labels;
|
||||
|
||||
filename = g_value_get_string (value);
|
||||
labels = read_labels (filename);
|
||||
|
||||
if (labels) {
|
||||
g_free (self->label_file);
|
||||
self->label_file = g_strdup (filename);
|
||||
g_strfreev (self->labels);
|
||||
self->labels = labels;
|
||||
} else {
|
||||
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case PROP_SCORE_THRESHOLD:
|
||||
|
@ -313,7 +320,7 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
|
|||
|
||||
std::vector < GstMlBoundingBox > boxes =
|
||||
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
|
||||
self->video_info.height, tmeta, self->label_file ? self->label_file : "",
|
||||
self->video_info.height, tmeta, self->labels,
|
||||
self->score_threshold);
|
||||
|
||||
for (auto & b:boxes) {
|
||||
|
|
|
@ -52,6 +52,7 @@ struct _GstSsdObjectDetector
|
|||
{
|
||||
GstBaseTransform basetransform;
|
||||
gchar *label_file;
|
||||
gchar **labels;
|
||||
gfloat score_threshold;
|
||||
gfloat confidence_threshold;
|
||||
gfloat iou_threshold;
|
||||
|
|
|
@ -25,7 +25,7 @@ if onnxrt_dep.found()
|
|||
link_args : noseh_link_args,
|
||||
include_directories : [configinc, libsinc, cuda_stubinc],
|
||||
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep,
|
||||
libm] + extra_deps,
|
||||
libm, gio_dep] + extra_deps,
|
||||
install : true,
|
||||
install_dir : plugins_install_dir,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue