onnx: Only read labels file one and use GIO

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
Olivier Crête 2024-01-24 22:31:21 -05:00
parent 13de5160be
commit 5e1291fd86
5 changed files with 78 additions and 35 deletions

View file

@ -22,7 +22,56 @@
#include "gstobjectdetectorutils.h"
#include <fstream>
#include <gio/gio.h>
char **
read_labels (const char * labels_file)
{
GPtrArray *array;
GFile *file = g_file_new_for_path (labels_file);
GFileInputStream *file_stream;
GDataInputStream *data_stream;
GError *error = NULL;
gchar *line;
file_stream = g_file_read (file, NULL, &error);
g_object_unref (file);
if (!file_stream) {
GST_WARNING ("Could not open file %s: %s\n", labels_file,
error->message);
g_clear_error (&error);
return NULL;
}
data_stream = g_data_input_stream_new (G_INPUT_STREAM (file_stream));
g_object_unref (file_stream);
array = g_ptr_array_new();
while ((line = g_data_input_stream_read_line (data_stream, NULL, NULL,
&error)))
g_ptr_array_add (array, line);
g_object_unref (data_stream);
if (error) {
GST_WARNING ("Could not open file %s: %s", labels_file, error->message);
g_ptr_array_free (array, TRUE);
g_clear_error (&error);
return NULL;
}
if (array->len == 0) {
g_ptr_array_free (array, TRUE);
return NULL;
}
g_ptr_array_add (array, NULL);
return (char **) g_ptr_array_free (array, FALSE);
}
GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
float _y0, float _width, float _height):
@ -47,20 +96,9 @@ namespace GstObjectDetectorUtils
{
}
std::vector < std::string >
GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
{
std::vector < std::string > labels;
std::string line;
std::ifstream fp (labelsFile);
while (std::getline (fp, line))
labels.push_back (line);
return labels;
}
std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
int32_t h, GstTensorMeta * tmeta, std::string labelPath,
int32_t h, GstTensorMeta * tmeta, gchar **labels,
float scoreThreshold)
{
@ -72,18 +110,17 @@ namespace GstObjectDetectorUtils
}
auto type = tmeta->tensor[classIndex].type;
return (type == GST_TENSOR_TYPE_FLOAT32) ?
doRun < float >(w, h, tmeta, labelPath, scoreThreshold)
: doRun < int >(w, h, tmeta, labelPath, scoreThreshold);
doRun < float >(w, h, tmeta, labels, scoreThreshold)
: doRun < int >(w, h, tmeta, labels, scoreThreshold);
}
template < typename T > std::vector < GstMlBoundingBox >
GstObjectDetectorUtils::doRun (int32_t w, int32_t h,
GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold)
GstTensorMeta * tmeta, char **labels, float scoreThreshold)
{
std::vector < GstMlBoundingBox > boundingBoxes;
GstMapInfo map_info[GstObjectDetectorMaxNodes];
GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL };
std::vector < std::string > labels;
gint index;
T *numDetections = nullptr, *bboxes = nullptr, *scores =
nullptr, *labelIndex = nullptr;
@ -162,15 +199,12 @@ namespace GstObjectDetectorUtils
labelIndex = (T *) map_info[index].data;
}
if (!labelPath.empty ())
labels = ReadLabels (labelPath);
for (int i = 0; i < numDetections[0]; ++i) {
if (scores[i] > scoreThreshold) {
std::string label = "";
if (labelIndex && !labels.empty ())
label = labels[labelIndex[i] - 1];
if (labels && labelIndex)
label = labels[(int)labelIndex[i] - 1];
auto score = scores[i];
auto y0 = bboxes[i * 4] * h;
auto x0 = bboxes[i * 4 + 1] * w;

View file

@ -29,6 +29,8 @@
#include "gstml.h"
#include "tensor/gsttensormeta.h"
char ** read_labels (const char * labels_file);
/* Object detection tensor id strings */
#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes"
#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores"
@ -68,14 +70,13 @@ namespace GstObjectDetectorUtils {
~GstObjectDetectorUtils(void) = default;
std::vector < GstMlBoundingBox > run(int32_t w, int32_t h,
GstTensorMeta *tmeta,
std::string labelPath,
char **labels,
float scoreThreshold);
private:
template < typename T > std::vector < GstMlBoundingBox >
doRun(int32_t w, int32_t h,
GstTensorMeta *tmeta, std::string labelPath,
GstTensorMeta *tmeta, char **labels,
float scoreThreshold);
std::vector < std::string > ReadLabels(const std::string & labelsFile);
};
}

View file

@ -164,6 +164,7 @@ gst_ssd_object_detector_finalize (GObject * object)
GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object);
g_free (self->label_file);
g_strfreev (self->labels);
delete GST_ODUTILS_MEMBER (self);
G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object);
@ -178,14 +179,20 @@ gst_ssd_object_detector_set_property (GObject * object, guint prop_id,
switch (prop_id) {
case PROP_LABEL_FILE:
filename = g_value_get_string (value);
if (filename
&& g_file_test (filename,
(GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) {
g_free (self->label_file);
self->label_file = g_strdup (filename);
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
{
gchar **labels;
filename = g_value_get_string (value);
labels = read_labels (filename);
if (labels) {
g_free (self->label_file);
self->label_file = g_strdup (filename);
g_strfreev (self->labels);
self->labels = labels;
} else {
GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename);
}
}
break;
case PROP_SCORE_THRESHOLD:
@ -313,7 +320,7 @@ gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
std::vector < GstMlBoundingBox > boxes =
GST_ODUTILS_MEMBER (self)->run (self->video_info.width,
self->video_info.height, tmeta, self->label_file ? self->label_file : "",
self->video_info.height, tmeta, self->labels,
self->score_threshold);
for (auto & b:boxes) {

View file

@ -52,6 +52,7 @@ struct _GstSsdObjectDetector
{
GstBaseTransform basetransform;
gchar *label_file;
gchar **labels;
gfloat score_threshold;
gfloat confidence_threshold;
gfloat iou_threshold;

View file

@ -25,7 +25,7 @@ if onnxrt_dep.found()
link_args : noseh_link_args,
include_directories : [configinc, libsinc, cuda_stubinc],
dependencies : [gstbase_dep, gstvideo_dep, gstanalytics_dep, onnxrt_dep,
libm] + extra_deps,
libm, gio_dep] + extra_deps,
install : true,
install_dir : plugins_install_dir,
)