diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp
new file mode 100644
index 0000000000..d38fe6bbe7
--- /dev/null
+++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.cpp
@@ -0,0 +1,195 @@
+/*
+ * GStreamer gstreamer-objectdetectorutils
+ * Copyright (C) 2023 Collabora Ltd
+ *
+ * gstobjectdetectorutils.cpp
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Library General Public
+ * License as published by the Free Software Foundation; either
+ * version 2 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Library General Public License for more details.
+ *
+ * You should have received a copy of the GNU Library General Public
+ * License along with this library; if not, write to the
+ * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
+ * Boston, MA 02110-1301, USA.
+ */
+
+#include "gstobjectdetectorutils.h"
+
+#include <fstream>
+#include "tensor/gsttensorid.h"
+
+GstMlBoundingBox::GstMlBoundingBox (std::string lbl, float score, float _x0,
+    float _y0, float _width, float _height):
+label (lbl),
+score (score),
+x0 (_x0),
+y0 (_y0),
+width (_width),
+height (_height)
+{
+}
+
+GstMlBoundingBox::GstMlBoundingBox ():
+GstMlBoundingBox ("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f)
+{
+}
+
+namespace GstObjectDetectorUtils
+{
+
+  GstObjectDetectorUtils::GstObjectDetectorUtils ()
+  {
+  }
+
+  std::vector < std::string >
+      GstObjectDetectorUtils::ReadLabels (const std::string & labelsFile)
+  {
+    std::vector < std::string > labels;
+    std::string line;
+    std::ifstream fp (labelsFile);
+    while (std::getline (fp, line))
+      labels.push_back (line);
+
+    return labels;
+  }
+
+  std::vector < GstMlBoundingBox > GstObjectDetectorUtils::run (int32_t w,
+      int32_t h, GstTensorMeta * tmeta, std::string labelPath,
+      float scoreThreshold)
+  {
+
+    auto classIndex = gst_tensor_meta_get_index_from_id (tmeta,
+        gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES));
+    if (classIndex == GST_TENSOR_MISSING_ID) {
+      GST_ERROR ("Missing class tensor id");
+      return std::vector < GstMlBoundingBox > ();
+    }
+    auto type = tmeta->tensor[classIndex].type;
+    return (type == GST_TENSOR_TYPE_FLOAT32) ?
+ doRun < float >(w, h, tmeta, labelPath, scoreThreshold) + : doRun < int >(w, h, tmeta, labelPath, scoreThreshold); + } + + template < typename T > std::vector < GstMlBoundingBox > + GstObjectDetectorUtils::doRun (int32_t w, int32_t h, + GstTensorMeta * tmeta, std::string labelPath, float scoreThreshold) + { + std::vector < GstMlBoundingBox > boundingBoxes; + GstMapInfo map_info[GstObjectDetectorMaxNodes]; + GstMemory *memory[GstObjectDetectorMaxNodes] = { NULL }; + std::vector < std::string > labels; + gint index; + T *numDetections = nullptr, *bboxes = nullptr, *scores = + nullptr, *labelIndex = nullptr; + + // number of detections + index = gst_tensor_meta_get_index_from_id (tmeta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS)); + if (index == GST_TENSOR_MISSING_ID) { + GST_WARNING ("Missing tensor data for tensor index %d", index); + goto cleanup; + } + memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0); + if (!memory[index]) { + GST_WARNING ("Missing tensor data for tensor index %d", index); + goto cleanup; + } + if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) { + GST_WARNING ("Failed to map tensor memory for index %d", index); + goto cleanup; + } + numDetections = (T *) map_info[index].data; + + // bounding boxes + index = + gst_tensor_meta_get_index_from_id (tmeta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES)); + if (index == GST_TENSOR_MISSING_ID) { + GST_WARNING ("Missing tensor data for tensor index %d", index); + goto cleanup; + } + memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0); + if (!memory[index]) { + GST_WARNING ("Failed to map tensor memory for index %d", index); + goto cleanup; + } + if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) { + GST_ERROR ("Failed to map GstMemory"); + goto cleanup; + } + bboxes = (T *) map_info[index].data; + + // scores + index = + gst_tensor_meta_get_index_from_id (tmeta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES)); + if (index == GST_TENSOR_MISSING_ID) { + GST_ERROR ("Missing scores tensor id"); + goto cleanup; + } + memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0); + if (!memory[index]) { + GST_WARNING ("Missing tensor data for tensor index %d", index); + goto cleanup; + } + if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) { + GST_ERROR ("Failed to map GstMemory"); + goto cleanup; + } + scores = (T *) map_info[index].data; + + // optional label + labelIndex = nullptr; + index = + gst_tensor_meta_get_index_from_id (tmeta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES)); + if (index != GST_TENSOR_MISSING_ID) { + memory[index] = gst_buffer_peek_memory (tmeta->tensor[index].data, 0); + if (!memory[index]) { + GST_WARNING ("Missing tensor data for tensor index %d", index); + goto cleanup; + } + if (!gst_memory_map (memory[index], map_info + index, GST_MAP_READ)) { + GST_ERROR ("Failed to map GstMemory"); + goto cleanup; + } + labelIndex = (T *) map_info[index].data; + } + + if (!labelPath.empty ()) + labels = ReadLabels (labelPath); + + for (int i = 0; i < numDetections[0]; ++i) { + if (scores[i] > scoreThreshold) { + std::string label = ""; + + if (labelIndex && !labels.empty ()) + label = labels[labelIndex[i] - 1]; + auto score = scores[i]; + auto y0 = bboxes[i * 4] * h; + auto x0 = bboxes[i * 4 + 1] * w; + auto bheight = bboxes[i * 4 + 2] * h - y0; + auto bwidth = bboxes[i * 4 + 3] * w - x0; + boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, 
bwidth, + bheight)); + } + } + + cleanup: + for (int i = 0; i < GstObjectDetectorMaxNodes; ++i) { + if (memory[i]) + gst_memory_unmap (memory[i], map_info + i); + + } + + return boundingBoxes; + } + +} diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h new file mode 100644 index 0000000000..5668ec6b23 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstobjectdetectorutils.h @@ -0,0 +1,82 @@ +/* + * GStreamer gstreamer-objectdetectorutils + * Copyright (C) 2023 Collabora Ltd + * + * gstobjectdetectorutils.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ +#ifndef __GST_OBJECT_DETECTOR_UTILS_H__ +#define __GST_OBJECT_DETECTOR_UTILS_H__ + +#include +#include +#include + +#include "gstml.h" +#include "tensor/gsttensormeta.h" + +/* Object detection tensor id strings */ +#define GST_MODEL_OBJECT_DETECTOR_BOXES "Gst.Model.ObjectDetector.Boxes" +#define GST_MODEL_OBJECT_DETECTOR_SCORES "Gst.Model.ObjectDetector.Scores" +#define GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS "Gst.Model.ObjectDetector.NumDetections" +#define GST_MODEL_OBJECT_DETECTOR_CLASSES "Gst.Model.ObjectDetector.Classes" + + +/** + * GstMlBoundingBox: + * + * @label label + * @score detection confidence + * @x0 top left hand x coordinate + * @y0 top left hand y coordinate + * @width width + * @height height + * + * Since: 1.20 + */ +struct GstMlBoundingBox { + GstMlBoundingBox(std::string lbl, float score, float _x0, float _y0, + float _width, float _height); + GstMlBoundingBox(); + std::string label; + float score; + float x0; + float y0; + float width; + float height; +}; + +namespace GstObjectDetectorUtils { + const int GstObjectDetectorMaxNodes = 4; + class GstObjectDetectorUtils { + public: + GstObjectDetectorUtils(void); + ~GstObjectDetectorUtils(void) = default; + std::vector < GstMlBoundingBox > run(int32_t w, int32_t h, + GstTensorMeta *tmeta, + std::string labelPath, + float scoreThreshold); + private: + template < typename T > std::vector < GstMlBoundingBox > + doRun(int32_t w, int32_t h, + GstTensorMeta *tmeta, std::string labelPath, + float scoreThreshold); + std::vector < std::string > ReadLabels(const std::string & labelsFile); + }; +} + +#endif /* __GST_OBJECT_DETECTOR_UTILS_H__ */ diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp new file mode 100644 index 0000000000..68af1fb521 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.cpp @@ -0,0 +1,348 @@ +/* + * GStreamer gstreamer-ssdobjectdetector + * Copyright (C) 2021 Collabora Ltd. 
+ * + * gstssdobjectdetector.cpp + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ + +/** + * SECTION:element-ssdobjectdetector + * @short_description: Detect objects in video buffers using SSD neural network + * + * This element can parse per-buffer inference tensor meta data generated by an upstream + * inference element + * + * + * ## Example launch command: + * + * note: image resolution may need to be adapted to the model, if the model expects + * a certain input resolution. The `videoscale` element in the pipeline below will scale + * the image, using padding if required, to 640x383 resolution required by model + * + * + * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \ + * location=bus.jpg ! jpegdec ! videoconvert ! \ + * videoscale ! 'video/x-raw,width=640,height=383' ! \ + * onnxinference execution-provider=cpu model-file=model.onnx \ + * ssdobjectdetector label-file=COCO_classes.txt ! \ + * videoconvert ! autovideosink + * + */ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + + +#include "gstssdobjectdetector.h" +#include "gstobjectdetectorutils.h" + +#include +#include +#include +#include "tensor/gsttensormeta.h" +#include "tensor/gsttensorid.h" + +GST_DEBUG_CATEGORY_STATIC (ssd_object_detector_debug); +#define GST_CAT_DEFAULT ssd_object_detector_debug +#define GST_ODUTILS_MEMBER( self ) ((GstObjectDetectorUtils::GstObjectDetectorUtils *) (self->odutils)) +GST_ELEMENT_REGISTER_DEFINE (ssd_object_detector, "ssdobjectdetector", + GST_RANK_PRIMARY, GST_TYPE_SSD_OBJECT_DETECTOR); + +/* GstSsdObjectDetector properties */ +enum +{ + PROP_0, + PROP_LABEL_FILE, + PROP_SCORE_THRESHOLD, +}; + +#define GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */ + +static GstStaticPadTemplate gst_ssd_object_detector_src_template = +GST_STATIC_PAD_TEMPLATE ("src", + GST_PAD_SRC, + GST_PAD_ALWAYS, + GST_STATIC_CAPS ("video/x-raw") + ); + +static GstStaticPadTemplate gst_ssd_object_detector_sink_template = +GST_STATIC_PAD_TEMPLATE ("sink", + GST_PAD_SINK, + GST_PAD_ALWAYS, + GST_STATIC_CAPS ("video/x-raw") + ); + +static void gst_ssd_object_detector_set_property (GObject * object, + guint prop_id, const GValue * value, GParamSpec * pspec); +static void gst_ssd_object_detector_get_property (GObject * object, + guint prop_id, GValue * value, GParamSpec * pspec); +static void gst_ssd_object_detector_finalize (GObject * object); +static GstFlowReturn gst_ssd_object_detector_transform_ip (GstBaseTransform * + trans, GstBuffer * buf); +static gboolean gst_ssd_object_detector_process (GstBaseTransform * trans, + GstBuffer * buf); +static gboolean +gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps, + GstCaps * outcaps); + +G_DEFINE_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST_TYPE_BASE_TRANSFORM); + +static void 
+gst_ssd_object_detector_class_init (GstSsdObjectDetectorClass * klass) +{ + GObjectClass *gobject_class = (GObjectClass *) klass; + GstElementClass *element_class = (GstElementClass *) klass; + GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass; + + GST_DEBUG_CATEGORY_INIT (ssd_object_detector_debug, "ssdobjectdetector", + 0, "ssdobjectdetector"); + gobject_class->set_property = gst_ssd_object_detector_set_property; + gobject_class->get_property = gst_ssd_object_detector_get_property; + gobject_class->finalize = gst_ssd_object_detector_finalize; + + /** + * GstSsdObjectDetector:label-file + * + * Label file + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE, + g_param_spec_string ("label-file", + "Label file", "Label file", NULL, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + /** + * GstSsdObjectDetector:score-threshold + * + * Threshold for deciding when to remove boxes based on score + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD, + g_param_spec_float ("score-threshold", + "Score threshold", + "Threshold for deciding when to remove boxes based on score", + 0.0, 1.0, GST_SSD_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + gst_element_class_set_static_metadata (element_class, "objectdetector", + "Filter/Effect/Video", + "Apply tensor output from inference to detect objects in video frames", + "Aaron Boxer , Marcus Edel "); + gst_element_class_add_pad_template (element_class, + gst_static_pad_template_get (&gst_ssd_object_detector_sink_template)); + gst_element_class_add_pad_template (element_class, + gst_static_pad_template_get (&gst_ssd_object_detector_src_template)); + basetransform_class->transform_ip = + GST_DEBUG_FUNCPTR (gst_ssd_object_detector_transform_ip); + basetransform_class->set_caps = + GST_DEBUG_FUNCPTR (gst_ssd_object_detector_set_caps); +} + +static void +gst_ssd_object_detector_init (GstSsdObjectDetector * self) +{ + self->odutils = new GstObjectDetectorUtils::GstObjectDetectorUtils (); +} + +static void +gst_ssd_object_detector_finalize (GObject * object) +{ + GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object); + + delete GST_ODUTILS_MEMBER (self); + G_OBJECT_CLASS (gst_ssd_object_detector_parent_class)->finalize (object); +} + +static void +gst_ssd_object_detector_set_property (GObject * object, guint prop_id, + const GValue * value, GParamSpec * pspec) +{ + GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object); + const gchar *filename; + + switch (prop_id) { + case PROP_LABEL_FILE: + filename = g_value_get_string (value); + if (filename + && g_file_test (filename, + (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { + g_free (self->label_file); + self->label_file = g_strdup (filename); + } else { + GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename); + } + break; + case PROP_SCORE_THRESHOLD: + GST_OBJECT_LOCK (self); + self->score_threshold = g_value_get_float (value); + GST_OBJECT_UNLOCK (self); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); + break; + } +} + +static void +gst_ssd_object_detector_get_property (GObject * object, guint prop_id, + GValue * value, GParamSpec * pspec) +{ + GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (object); + + switch (prop_id) { + case PROP_LABEL_FILE: + g_value_set_string (value, self->label_file); + break; + case PROP_SCORE_THRESHOLD: + 
GST_OBJECT_LOCK (self); + g_value_set_float (value, self->score_threshold); + GST_OBJECT_UNLOCK (self); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); + break; + } +} + +static GstTensorMeta * +gst_ssd_object_detector_get_tensor_meta (GstSsdObjectDetector * object_detector, + GstBuffer * buf) +{ + GstTensorMeta *tmeta = NULL; + GList *tensor_metas; + GList *iter; + + // get all tensor metas + tensor_metas = gst_tensor_meta_get_all_from_buffer (buf); + if (!tensor_metas) { + GST_TRACE_OBJECT (object_detector, + "missing tensor meta from buffer %" GST_PTR_FORMAT, buf); + goto cleanup; + } + + // find object detector meta + for (iter = tensor_metas; iter != NULL; iter = g_list_next (iter)) { + GstTensorMeta *tensor_meta = (GstTensorMeta *) iter->data; + gint numTensors = tensor_meta->num_tensors; + /* SSD model must have either 3 or 4 output tensor nodes: 4 if there is a label node, + * and only 3 if there is no label */ + if (numTensors != 3 && numTensors != 4) + continue; + + gint boxesIndex = gst_tensor_meta_get_index_from_id (tensor_meta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_BOXES)); + gint scoresIndex = gst_tensor_meta_get_index_from_id (tensor_meta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_SCORES)); + gint numDetectionsIndex = gst_tensor_meta_get_index_from_id (tensor_meta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_NUM_DETECTIONS)); + gint clasesIndex = gst_tensor_meta_get_index_from_id (tensor_meta, + gst_tensorid_get_quark (GST_MODEL_OBJECT_DETECTOR_CLASSES)); + + if (boxesIndex == GST_TENSOR_MISSING_ID || scoresIndex == GST_TENSOR_MISSING_ID + || numDetectionsIndex == GST_TENSOR_MISSING_ID) + continue; + + if (numTensors == 4 && clasesIndex == GST_TENSOR_MISSING_ID) + continue; + + tmeta = tensor_meta; + break; + } + +cleanup: + g_list_free (tensor_metas); + + return tmeta; +} + +static gboolean +gst_ssd_object_detector_set_caps (GstBaseTransform * trans, GstCaps * incaps, + GstCaps * outcaps) +{ + GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans); + + if (!gst_video_info_from_caps (&self->video_info, incaps)) { + GST_ERROR_OBJECT (self, "Failed to parse caps"); + return FALSE; + } + + return TRUE; +} + +static GstFlowReturn +gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf) +{ + if (!gst_base_transform_is_passthrough (trans)) { + if (!gst_ssd_object_detector_process (trans, buf)) { + GST_ELEMENT_ERROR (trans, STREAM, FAILED, + (NULL), ("ssd object detection failed")); + return GST_FLOW_ERROR; + } + } + + return GST_FLOW_OK; +} + +static gboolean +gst_ssd_object_detector_process (GstBaseTransform * trans, GstBuffer * buf) +{ + GstTensorMeta *tmeta = NULL; + GstSsdObjectDetector *self = GST_SSD_OBJECT_DETECTOR (trans); + + // get all tensor metas + tmeta = gst_ssd_object_detector_get_tensor_meta (self, buf); + if (!tmeta) { + GST_WARNING_OBJECT (trans, "missing tensor meta"); + return TRUE; + } + + std::vector < GstMlBoundingBox > boxes = + GST_ODUTILS_MEMBER (self)->run (self->video_info.width, + self->video_info.height, tmeta, self->label_file ? 
self->label_file : "", + self->score_threshold); + + for (auto & b:boxes) { + auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf, + GST_SSD_OBJECT_DETECTOR_META_NAME, + b.x0, b.y0, + b.width, + b.height); + if (!vroi_meta) { + GST_WARNING_OBJECT (trans, + "Unable to attach GstVideoRegionOfInterestMeta to buffer"); + return FALSE; + } + auto s = gst_structure_new (GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME, + GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL, + G_TYPE_STRING, + b.label.c_str (), + GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE, + G_TYPE_DOUBLE, + b.score, + NULL); + gst_video_region_of_interest_meta_add_param (vroi_meta, s); + GST_DEBUG_OBJECT (self, + "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n", + b.label.c_str (), b.score, b.x0, b.y0, b.x0 + b.width, b.y0 + b.height); + } + + return TRUE; +} diff --git a/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h new file mode 100644 index 0000000000..4549ad4c92 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/decoders/gstssdobjectdetector.h @@ -0,0 +1,78 @@ +/* + * GStreamer gstreamer-ssdobjectdetector + * Copyright (C) 2021 Collabora Ltd + * + * gstssdobjectdetector.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. 
+ */ + +#ifndef __GST_SSD_OBJECT_DETECTOR_H__ +#define __GST_SSD_OBJECT_DETECTOR_H__ + +#include +#include +#include + +G_BEGIN_DECLS + +#define GST_TYPE_SSD_OBJECT_DETECTOR (gst_ssd_object_detector_get_type()) +G_DECLARE_FINAL_TYPE (GstSsdObjectDetector, gst_ssd_object_detector, GST, SSD_OBJECT_DETECTOR, GstBaseTransform) + +#define GST_SSD_OBJECT_DETECTOR_META_NAME "ssd-object-detector" +#define GST_SSD_OBJECT_DETECTOR_META_PARAM_NAME "extra-data" +#define GST_SSD_OBJECT_DETECTOR_META_FIELD_LABEL "label" +#define GST_SSD_OBJECT_DETECTOR_META_FIELD_SCORE "score" + +/** + * GstSsdObjectDetector: + * + * @label_file label file + * @score_threshold score threshold + * @confidence_threshold confidence threshold + * @iou_threhsold iou threshold + * @od_ptr opaque pointer to GstOd object detection implementation + * + * Since: 1.20 + */ +struct _GstSsdObjectDetector +{ + GstBaseTransform basetransform; + gchar *label_file; + gfloat score_threshold; + gfloat confidence_threshold; + gfloat iou_threshold; + gpointer odutils; + GstVideoInfo video_info; +}; + +/** + * GstSsdObjectDetectorClass: + * + * @parent_class base transform base class + * + * Since: 1.20 + */ +struct _GstSsdObjectDetectorClass +{ + GstBaseTransformClass parent_class; +}; + +GST_ELEMENT_REGISTER_DECLARE (ssd_object_detector) + +G_END_DECLS + +#endif /* __GST_SSD_OBJECT_DETECTOR_H__ */ diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstml.h b/subprojects/gst-plugins-bad/ext/onnx/gstml.h new file mode 100644 index 0000000000..a36c37ecb2 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/gstml.h @@ -0,0 +1,41 @@ +/* + * GStreamer gstreamer-ml + * Copyright (C) 2021 Collabora Ltd + * + * gstml.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ +#ifndef __GST_ML_H__ +#define __GST_ML_H__ + + +/** + * GstMlInputImageFormat: + * + * @GST_ML_INPUT_IMAGE_FORMAT_HWC Height Width Channel (a.k.a. interleaved) format + * @GST_ML_INPUT_IMAGE_FORMAT_CHW Channel Height Width (a.k.a. 
planar) format + * + * Since: 1.20 + */ +typedef enum { + GST_ML_INPUT_IMAGE_FORMAT_HWC, + GST_ML_INPUT_IMAGE_FORMAT_CHW, +} GstMlInputImageFormat; + + + +#endif diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnx.c b/subprojects/gst-plugins-bad/ext/onnx/gstonnx.c index 4ee438a1c6..0f567b7ad1 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnx.c +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnx.c @@ -1,3 +1,4 @@ + /* * GStreamer gstreamer-onnx * Copyright (C) 2021 Collabora Ltd @@ -23,14 +24,17 @@ #include "config.h" #endif -#include "gstonnxobjectdetector.h" +#include "decoders/gstssdobjectdetector.h" +#include "gstonnxinference.h" +#include "tensor/gsttensormeta.h" static gboolean plugin_init (GstPlugin * plugin) { - GST_ELEMENT_REGISTER (onnx_object_detector, plugin); + gboolean success = GST_ELEMENT_REGISTER (ssd_object_detector, plugin); + success |= GST_ELEMENT_REGISTER (onnx_inference, plugin); - return TRUE; + return success; } GST_PLUGIN_DEFINE (GST_VERSION_MAJOR, diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index f8df0f3374..c7653e9119 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -21,23 +21,15 @@ */ #include "gstonnxclient.h" +#include #include -#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA -#include -#endif -#include -#include -#include -#include -#include -#include #include namespace GstOnnxNamespace { -template < typename T > - std::ostream & operator<< (std::ostream & os, const std::vector < T > &v) -{ + template < typename T > + std::ostream & operator<< (std::ostream & os, const std::vector < T > &v) + { os << "["; for (size_t i = 0; i < v.size (); ++i) { @@ -50,13 +42,7 @@ template < typename T > os << "]"; return os; -} - -GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index - (GST_ML_NODE_INDEX_DISABLED), - type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) -{ -} + } GstOnnxClient::GstOnnxClient ():session (nullptr), width (0), @@ -64,123 +50,59 @@ GstOnnxClient::GstOnnxClient ():session (nullptr), channels (0), dest (nullptr), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), - inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC), - fixedInputImageSize (true) -{ - for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i) - outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i; -} + inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC), + fixedInputImageSize (false) { + } -GstOnnxClient::~GstOnnxClient () -{ - outputNames.clear(); + GstOnnxClient::~GstOnnxClient () { delete session; delete[]dest; -} + } -Ort::Env & GstOnnxClient::getEnv (void) -{ - static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - "GstOnnxNamespace"); - - return env; -} - -int32_t GstOnnxClient::getWidth (void) -{ + int32_t GstOnnxClient::getWidth (void) + { return width; -} + } -int32_t GstOnnxClient::getHeight (void) -{ + int32_t GstOnnxClient::getHeight (void) + { return height; -} + } -bool GstOnnxClient::isFixedInputImageSize (void) -{ + bool GstOnnxClient::isFixedInputImageSize (void) + { return fixedInputImageSize; -} + } -std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType) -{ - switch (nodeType) { - case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION: - return "detection"; - break; - case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX: - return "bounding box"; - break; - case GST_ML_OUTPUT_NODE_FUNCTION_SCORE: - return "score"; - break; - case GST_ML_OUTPUT_NODE_FUNCTION_CLASS: - 
return "label"; - break; - case GST_ML_OUTPUT_NODE_NUMBER_OF: - g_assert_not_reached(); - GST_WARNING("Invalid parameter"); - break; - }; - - return ""; -} - -void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format) -{ + void GstOnnxClient::setInputImageFormat (GstMlInputImageFormat format) + { inputImageFormat = format; -} + } -GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void) -{ + GstMlInputImageFormat GstOnnxClient::getInputImageFormat (void) + { return inputImageFormat; -} + } -std::vector< const char *> GstOnnxClient::getOutputNodeNames (void) -{ - if (!outputNames.empty() && outputNamesRaw.size() != outputNames.size()) { - outputNamesRaw.resize(outputNames.size()); - for (size_t i = 0; i < outputNamesRaw.size(); i++) { - outputNamesRaw[i] = outputNames[i].get(); - } + std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void) + { + if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) { + outputNamesRaw.resize (outputNames.size ()); + for (size_t i = 0; i < outputNamesRaw.size (); i++) + outputNamesRaw[i] = outputNames[i].get (); } return outputNamesRaw; -} + } -void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node, - gint index) -{ - g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF); - outputNodeInfo[node].index = index; - if (index != GST_ML_NODE_INDEX_DISABLED) - outputNodeIndexToFunction[index] = node; -} - -gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node) -{ - return outputNodeInfo[node].index; -} - -void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node, - ONNXTensorElementDataType type) -{ - outputNodeInfo[node].type = type; -} - -ONNXTensorElementDataType - GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node) -{ - return outputNodeInfo[node].type; -} - -bool GstOnnxClient::hasSession (void) -{ + bool GstOnnxClient::hasSession (void) + { return session != nullptr; -} + } -bool GstOnnxClient::createSession (std::string modelFile, + bool GstOnnxClient::createSession (std::string modelFile, GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider) -{ + { if (session) return true; @@ -205,30 +127,43 @@ bool GstOnnxClient::createSession (std::string modelFile, try { Ort::SessionOptions sessionOptions; + const auto & api = Ort::GetApi (); // for debugging //sessionOptions.SetIntraOpNumThreads (1); sessionOptions.SetGraphOptimizationLevel (onnx_optim); m_provider = provider; switch (m_provider) { case GST_ONNX_EXECUTION_PROVIDER_CUDA: -#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA - Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA - (sessionOptions, 0)); -#else - GST_ERROR ("ONNX CUDA execution provider not supported"); - return false; -#endif + try { + OrtCUDAProviderOptionsV2 *cuda_options = nullptr; + Ort::ThrowOnError (api.CreateCUDAProviderOptions (&cuda_options)); + std::unique_ptr < OrtCUDAProviderOptionsV2, + decltype (api.ReleaseCUDAProviderOptions) > + rel_cuda_options (cuda_options, api.ReleaseCUDAProviderOptions); + Ort::ThrowOnError (api.SessionOptionsAppendExecutionProvider_CUDA_V2 + (static_cast < OrtSessionOptions * >(sessionOptions), + rel_cuda_options.get ())); + } + catch (Ort::Exception & ortex) { + GST_WARNING + ("Failed to create CUDA provider - dropping back to CPU"); + Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU + (sessionOptions, 1)); + } break; default: + Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CPU + (sessionOptions, 1)); break; - }; - session = - new Ort::Session (getEnv (), 
modelFile.c_str (), sessionOptions); + env = + Ort::Env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + "GstOnnxNamespace"); + session = new Ort::Session (env, modelFile.c_str (), sessionOptions); auto inputTypeInfo = session->GetInputTypeInfo (0); std::vector < int64_t > inputDims = inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); - if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { + if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) { height = inputDims[1]; width = inputDims[2]; channels = inputDims[3]; @@ -250,14 +185,37 @@ bool GstOnnxClient::createSession (std::string modelFile, auto output_name = session->GetOutputNameAllocated (i, allocator); GST_DEBUG ("Output name %lu:%s", i, output_name.get ()); outputNames.push_back (std::move (output_name)); - auto type_info = session->GetOutputTypeInfo (i); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); + } + genOutputNamesRaw (); - if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) { - auto function = outputNodeIndexToFunction[i]; - outputNodeInfo[function].type = tensor_info.GetElementType (); + // look up tensor ids + auto metaData = session->GetModelMetadata (); + OrtAllocator *ortAllocator; + auto status = + Ort::GetApi ().GetAllocatorWithDefaultOptions (&ortAllocator); + if (status) { + // Handle the error case + const char *errorString = Ort::GetApi ().GetErrorMessage (status); + GST_WARNING ("Failed to get allocator: %s", errorString); + + // Clean up the error status + Ort::GetApi ().ReleaseStatus (status); + + return false; + } else { + for (auto & name:outputNamesRaw) { + Ort::AllocatedStringPtr res = + metaData.LookupCustomMetadataMapAllocated (name, ortAllocator); + if (res) { + GQuark quark = gst_tensorid_get_quark (res.get ()); + outputIds.push_back (quark); + } else { + GST_ERROR ("Failed to look up id for key %s", name); + + return false; } } + } } catch (Ort::Exception & ortex) { GST_ERROR ("%s", ortex.what ()); @@ -265,40 +223,110 @@ bool GstOnnxClient::createSession (std::string modelFile, } return true; -} + } -std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data, - GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold) -{ - auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS); - return (type == - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ? - doRun < float >(img_data, vmeta, labelPath, scoreThreshold) - : doRun < int >(img_data, vmeta, labelPath, scoreThreshold); -} - -void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta) -{ - int32_t newWidth = fixedInputImageSize ? width : vmeta->width; - int32_t newHeight = fixedInputImageSize ? height : vmeta->height; + void GstOnnxClient::parseDimensions (GstVideoInfo vinfo) + { + int32_t newWidth = fixedInputImageSize ? width : vinfo.width; + int32_t newHeight = fixedInputImageSize ? 
height : vinfo.height; if (!dest || width * height < newWidth * newHeight) { - delete[] dest; + delete[]dest; dest = new uint8_t[newWidth * newHeight * channels]; } width = newWidth; height = newHeight; -} + } -template < typename T > std::vector < GstMlBoundingBox > - GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta, - std::string labelPath, float scoreThreshold) -{ - std::vector < GstMlBoundingBox > boundingBoxes; +// copy tensor data to a GstTensorMeta + GstTensorMeta *GstOnnxClient::copy_tensors_to_meta (std::vector < Ort::Value > + &outputs, GstBuffer * buffer) + { + size_t num_tensors = outputNamesRaw.size (); + GstTensorMeta *tmeta = (GstTensorMeta *) gst_buffer_add_meta (buffer, + gst_tensor_meta_get_info (), + NULL); + tmeta->num_tensors = num_tensors; + tmeta->tensor = (GstTensor *) g_malloc (num_tensors * sizeof (GstTensor)); + bool hasIds = outputIds.size () == num_tensors; + for (size_t i = 0; i < num_tensors; i++) { + Ort::Value outputTensor = std::move (outputs[i]); + + ONNXTensorElementDataType tensorType = + outputTensor.GetTensorTypeAndShapeInfo ().GetElementType (); + + GstTensor *tensor = &tmeta->tensor[i]; + if (hasIds) + tensor->id = outputIds[i]; + tensor->data = gst_buffer_new (); + auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape (); + tensor->num_dims = tensorShape.size (); + tensor->dims = g_new (int64_t, tensor->num_dims); + + for (size_t j = 0; j < tensorShape.size (); ++j) { + tensor->dims[j] = tensorShape[j]; + } + + size_t numElements = + outputTensor.GetTensorTypeAndShapeInfo ().GetElementCount (); + + size_t buffer_size = 0; + guint8 *buffer_data = NULL; + if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + buffer_size = numElements * sizeof (float); + + // Allocate memory for the buffer data + buffer_data = (guint8 *) malloc (buffer_size); + if (buffer_data == NULL) { + GST_ERROR ("Failed to allocate memory"); + return NULL; + } + // Copy the data from the source buffer to the allocated memory + memcpy (buffer_data, outputTensor.GetTensorData < float >(), + buffer_size); + tensor->type = GST_TENSOR_TYPE_FLOAT32; + } else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { + buffer_size = numElements * sizeof (int); + + // Allocate memory for the buffer data + guint8 *buffer_data = (guint8 *) malloc (buffer_size); + if (buffer_data == NULL) { + GST_ERROR ("Failed to allocate memory"); + return NULL; + } + // Copy the data from the source buffer to the allocated memory + memcpy (buffer_data, outputTensor.GetTensorData < int >(), + buffer_size); + tensor->type = GST_TENSOR_TYPE_INT32; + } + if (buffer_data) { + + // Create a GstMemory object from the allocated memory + GstMemory *memory = gst_memory_new_wrapped ((GstMemoryFlags) 0, + buffer_data, buffer_size, 0, buffer_size, NULL, NULL); + + // Append the GstMemory object to the GstBuffer + gst_buffer_append_memory (tmeta->tensor[i].data, memory); + } + } + + return tmeta; + + } + + std::vector < Ort::Value > GstOnnxClient::run (uint8_t * img_data, + GstVideoInfo vinfo) { + std::vector < Ort::Value > modelOutput; + doRun (img_data, vinfo, modelOutput); + + return modelOutput; + } + + bool GstOnnxClient::doRun (uint8_t * img_data, GstVideoInfo vinfo, + std::vector < Ort::Value > &modelOutput) + { if (!img_data) - return boundingBoxes; - - parseDimensions (vmeta); + return false; Ort::AllocatorWithDefaultOptions allocator; auto inputName = session->GetInputNameAllocated (0, allocator); @@ -306,7 +334,7 @@ template < typename T > std::vector < GstMlBoundingBox > 
std::vector < int64_t > inputDims = inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputDims[0] = 1; - if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { + if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) { inputDims[1] = height; inputDims[2] = width; } else { @@ -321,7 +349,7 @@ template < typename T > std::vector < GstMlBoundingBox > // copy video frame uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 }; uint32_t srcSamplesPerPixel = 3; - switch (vmeta->format) { + switch (vinfo.finfo->format) { case GST_VIDEO_FORMAT_RGBA: srcSamplesPerPixel = 4; break; @@ -352,8 +380,8 @@ template < typename T > std::vector < GstMlBoundingBox > break; } size_t destIndex = 0; - uint32_t stride = vmeta->stride[0]; - if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { + uint32_t stride = vinfo.stride[0]; + if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) { for (int32_t j = 0; j < height; ++j) { for (int32_t i = 0; i < width; ++i) { for (int32_t k = 0; k < channels; ++k) { @@ -389,58 +417,17 @@ template < typename T > std::vector < GstMlBoundingBox > std::vector < Ort::Value > inputTensors; inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo, dest, inputTensorSize, inputDims.data (), inputDims.size ())); - std::vector < const char *>inputNames { inputName.get () }; + std::vector < const char *>inputNames + { + inputName.get ()}; - std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr}, + modelOutput = session->Run (Ort::RunOptions { + nullptr}, inputNames.data (), - inputTensors.data (), 1, outputNamesRaw.data (), outputNamesRaw.size ()); + inputTensors.data (), 1, outputNamesRaw.data (), + outputNamesRaw.size ()); - auto numDetections = - modelOutput[getOutputNodeIndex - (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >(); - auto bboxes = - modelOutput[getOutputNodeIndex - (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >(); - auto scores = - modelOutput[getOutputNodeIndex - (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >(); - T *labelIndex = nullptr; - if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) != - GST_ML_NODE_INDEX_DISABLED) { - labelIndex = - modelOutput[getOutputNodeIndex - (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > (); - } - if (labels.empty () && !labelPath.empty ()) - labels = ReadLabels (labelPath); - - for (int i = 0; i < numDetections[0]; ++i) { - if (scores[i] > scoreThreshold) { - std::string label = ""; - - if (labelIndex && !labels.empty ()) - label = labels[labelIndex[i] - 1]; - auto score = scores[i]; - auto y0 = bboxes[i * 4] * height; - auto x0 = bboxes[i * 4 + 1] * width; - auto bheight = bboxes[i * 4 + 2] * height - y0; - auto bwidth = bboxes[i * 4 + 3] * width - x0; - boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth, - bheight)); - } - } - return boundingBoxes; -} - -std::vector < std::string > - GstOnnxClient::ReadLabels (const std::string & labelsFile) -{ - std::vector < std::string > labels; - std::string line; - std::ifstream fp (labelsFile); - while (std::getline (fp, line)) - labels.push_back (line); - - return labels; + return true; } + } diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h index edbc2f4655..18a661a8d1 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h @@ -25,45 +25,11 @@ #include #include 
#include -#include "gstonnxelement.h" -#include -#include +#include "gstml.h" +#include "gstonnxenums.h" +#include "tensor/gsttensormeta.h" namespace GstOnnxNamespace { - enum GstMlOutputNodeFunction { - GST_ML_OUTPUT_NODE_FUNCTION_DETECTION, - GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX, - GST_ML_OUTPUT_NODE_FUNCTION_SCORE, - GST_ML_OUTPUT_NODE_FUNCTION_CLASS, - GST_ML_OUTPUT_NODE_NUMBER_OF, - }; - - const gint GST_ML_NODE_INDEX_DISABLED = -1; - - struct GstMlOutputNodeInfo { - GstMlOutputNodeInfo(void); - gint index; - ONNXTensorElementDataType type; - }; - - struct GstMlBoundingBox { - GstMlBoundingBox(std::string lbl, - float score, - float _x0, - float _y0, - float _width, - float _height):label(lbl), - score(score), x0(_x0), y0(_y0), width(_width), height(_height) { - } - GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) { - } - std::string label; - float score; - float x0; - float y0; - float width; - float height; - }; class GstOnnxClient { public: @@ -72,30 +38,18 @@ namespace GstOnnxNamespace { bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider); bool hasSession(void); - void setInputImageFormat(GstMlModelInputImageFormat format); - GstMlModelInputImageFormat getInputImageFormat(void); - void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index); - gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType); - void setOutputNodeType(GstMlOutputNodeFunction nodeType, - ONNXTensorElementDataType type); - ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type); - std::string getOutputNodeName(GstMlOutputNodeFunction nodeType); - std::vector < GstMlBoundingBox > run(uint8_t * img_data, - GstVideoMeta * vmeta, - std::string labelPath, - float scoreThreshold); - std::vector < GstMlBoundingBox > &getBoundingBoxes(void); - std::vector < const char *>getOutputNodeNames(void); + void setInputImageFormat(GstMlInputImageFormat format); + GstMlInputImageFormat getInputImageFormat(void); + std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo); + std::vector < const char *> genOutputNamesRaw(void); bool isFixedInputImageSize(void); int32_t getWidth(void); int32_t getHeight(void); + GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer); + void parseDimensions(GstVideoInfo vinfo); private: - void parseDimensions(GstVideoMeta * vmeta); - template < typename T > std::vector < GstMlBoundingBox > - doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath, - float scoreThreshold); - std::vector < std::string > ReadLabels(const std::string & labelsFile); - Ort::Env & getEnv(void); + bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput); + Ort::Env env; Ort::Session * session; int32_t width; int32_t height; @@ -104,13 +58,10 @@ namespace GstOnnxNamespace { GstOnnxExecutionProvider m_provider; std::vector < Ort::Value > modelOutput; std::vector < std::string > labels; - // !! indexed by function - GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF]; - // !! 
indexed by array index - size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF]; std::vector < const char *> outputNamesRaw; std::vector < Ort::AllocatedStringPtr > outputNames; - GstMlModelInputImageFormat inputImageFormat; + std::vector < GQuark > outputIds; + GstMlInputImageFormat inputImageFormat; bool fixedInputImageSize; }; } diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.c b/subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.c deleted file mode 100644 index 99e64ff3ac..0000000000 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.c +++ /dev/null @@ -1,104 +0,0 @@ -/* - * GStreamer gstreamer-onnxelement - * Copyright (C) 2021 Collabora Ltd - * - * gstonnxelement.c - * - * This library is free software; you can redistribute it and/or - * modify it under the terms of the GNU Library General Public - * License as published by the Free Software Foundation; either - * version 2 of the License, or (at your option) any later version. - * - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Library General Public License for more details. - * - * You should have received a copy of the GNU Library General Public - * License along with this library; if not, write to the - * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, - * Boston, MA 02110-1301, USA. - */ -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "gstonnxelement.h" - -GType -gst_onnx_optimization_level_get_type (void) -{ - static GType onnx_optimization_type = 0; - - if (g_once_init_enter (&onnx_optimization_type)) { - static GEnumValue optimization_level_types[] = { - {GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization", - "disable-all"}, - {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC, - "Enable basic optimizations (redundant node removals))", - "enable-basic"}, - {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, - "Enable extended optimizations (redundant node removals + node fusions)", - "enable-extended"}, - {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL, - "Enable all possible optimizations", "enable-all"}, - {0, NULL, NULL}, - }; - - GType temp = g_enum_register_static ("GstOnnxOptimizationLevel", - optimization_level_types); - - g_once_init_leave (&onnx_optimization_type, temp); - } - - return onnx_optimization_type; -} - -GType -gst_onnx_execution_provider_get_type (void) -{ - static GType onnx_execution_type = 0; - - if (g_once_init_enter (&onnx_execution_type)) { - static GEnumValue execution_provider_types[] = { - {GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider", - "cpu"}, - {GST_ONNX_EXECUTION_PROVIDER_CUDA, - "CUDA execution provider", - "cuda"}, - {0, NULL, NULL}, - }; - - GType temp = g_enum_register_static ("GstOnnxExecutionProvider", - execution_provider_types); - - g_once_init_leave (&onnx_execution_type, temp); - } - - return onnx_execution_type; -} - -GType -gst_ml_model_input_image_format_get_type (void) -{ - static GType ml_model_input_image_format = 0; - - if (g_once_init_enter (&ml_model_input_image_format)) { - static GEnumValue ml_model_input_image_format_types[] = { - {GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, - "Height Width Channel (HWC) a.k.a. interleaved image data format", - "hwc"}, - {GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW, - "Channel Height Width (CHW) a.k.a. 
planar image data format", - "chw"}, - {0, NULL, NULL}, - }; - - GType temp = g_enum_register_static ("GstMlModelInputImageFormat", - ml_model_input_image_format_types); - - g_once_init_leave (&ml_model_input_image_format, temp); - } - - return ml_model_input_image_format; -} diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxenums.h similarity index 57% rename from subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.h rename to subprojects/gst-plugins-bad/ext/onnx/gstonnxenums.h index 338948359e..6cc8658b91 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxelement.h +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxenums.h @@ -1,8 +1,8 @@ /* - * GStreamer gstreamer-onnxelement + * GStreamer gstreamer-onnxenums * Copyright (C) 2021 Collabora Ltd * - * gstonnxelement.h + * gstonnxenums.h * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Library General Public @@ -20,10 +20,8 @@ * Boston, MA 02110-1301, USA. */ -#ifndef __GST_ONNX_ELEMENT_H__ -#define __GST_ONNX_ELEMENT_H__ - -#include +#ifndef __GST_ONNX_ENUMS_H__ +#define __GST_ONNX_ENUMS_H__ typedef enum { @@ -39,26 +37,5 @@ typedef enum GST_ONNX_EXECUTION_PROVIDER_CUDA, } GstOnnxExecutionProvider; -typedef enum { - /* Height Width Channel (a.k.a. interleaved) format */ - GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, - /* Channel Height Width (a.k.a. planar) format */ - GST_ML_MODEL_INPUT_IMAGE_FORMAT_CHW, -} GstMlModelInputImageFormat; - - -G_BEGIN_DECLS - -GType gst_onnx_optimization_level_get_type (void); -#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ()) - -GType gst_onnx_execution_provider_get_type (void); -#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ()) - -GType gst_ml_model_input_image_format_get_type (void); -#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ()) - -G_END_DECLS - -#endif +#endif /* __GST_ONNX_ENUMS_H__ */ diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp new file mode 100644 index 0000000000..15d68ed7b5 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp @@ -0,0 +1,539 @@ +/* + * GStreamer gstreamer-onnxinference + * Copyright (C) 2023 Collabora Ltd. + * + * gstonnxinference.cpp + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ + +/** + * SECTION:element-onnxinference + * @short_description: Run ONNX inference model on video buffers + * + * This element can apply an ONNX model to video buffers. It attaches + * the tensor output to the buffer as a @ref GstTensorMeta. 
+ * + * To install ONNX on your system, recursively clone this repository + * https://github.com/microsoft/onnxruntime.git + * + * and build and install with cmake: + * + * CPU: + * + * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \ + * $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install + * + * + * CUDA : + * + * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \ + * -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \ + * $SRC_DIR/onnxruntime/cmake && make -j$(nproc) && sudo make install + * + * + * where : + * + * 1. $SRC_DIR and $BUILD_DIR are local source and build directories + * 2. To run with CUDA, both CUDA and cuDNN libraries must be installed. + * $CUDA_PATH is an environment variable set to the CUDA root path. + * On Linux, it would be /usr/local/cuda + * + * + * ## Example launch command: + * + * GST_DEBUG=onnxinference:5 gst-launch-1.0 multifilesrc location=bus.jpg ! \ + * jpegdec ! videoconvert ! \ + * onnxinference execution-provider=cpu model-file=model.onnx \ + * videoconvert ! autovideosink + * + * + * Note: in order for downstream tensor decoders to correctly parse the tensor + * data in the GstTensorMeta, meta data must be attached to the ONNX model + * assigning a unique string id to each output layer. These unique string ids + * and corresponding GQuark ids are currently stored in the ONNX plugin source + * in the file 'gsttensorid.h'. For an output layer with name Foo and with context + * unique string id Gst.Model.Bar, a meta data key/value pair must be added + * to the ONNX model with "Foo" mapped to "Gst.Model.Bar" in order for a downstream + * decoder to make use of this model. If the meta data is absent, the pipeline will + * fail. + * + * As a convenience, there is a python script + * currently stored at + * https://gitlab.collabora.com/gstreamer/onnx-models/-/blob/master/scripts/modify_onnx_metadata.py + * to enable users to easily add and remove meta data from json files. It can also dump + * the names of all output layers, which can then be used to craft the json meta data file. 
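+ *
+ * For illustration only (the output layer names below are hypothetical and
+ * depend on the particular model; the Gst.Model.* ids are the ones defined
+ * in gstobjectdetectorutils.h), the custom metadata map of an SSD model
+ * usable with ssdobjectdetector might contain entries such as:
+ *
+ *   "num_detections"    -> "Gst.Model.ObjectDetector.NumDetections"
+ *   "detection_boxes"   -> "Gst.Model.ObjectDetector.Boxes"
+ *   "detection_scores"  -> "Gst.Model.ObjectDetector.Scores"
+ *   "detection_classes" -> "Gst.Model.ObjectDetector.Classes"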
+ * + * + */ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include +#include "gstonnxinference.h" +#include "gstonnxclient.h" + +GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug); +#define GST_CAT_DEFAULT onnx_inference_debug +#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client)) +GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference", + GST_RANK_PRIMARY, GST_TYPE_ONNX_INFERENCE); + +/* GstOnnxInference properties */ +enum +{ + PROP_0, + PROP_MODEL_FILE, + PROP_INPUT_IMAGE_FORMAT, + PROP_OPTIMIZATION_LEVEL, + PROP_EXECUTION_PROVIDER +}; + +#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU +#define GST_ONNX_INFERENCE_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED + +static GstStaticPadTemplate gst_onnx_inference_src_template = +GST_STATIC_PAD_TEMPLATE ("src", + GST_PAD_SRC, + GST_PAD_ALWAYS, + GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }")) + ); + +static GstStaticPadTemplate gst_onnx_inference_sink_template = +GST_STATIC_PAD_TEMPLATE ("sink", + GST_PAD_SINK, + GST_PAD_ALWAYS, + GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }")) + ); + +static void gst_onnx_inference_set_property (GObject * object, + guint prop_id, const GValue * value, GParamSpec * pspec); +static void gst_onnx_inference_get_property (GObject * object, + guint prop_id, GValue * value, GParamSpec * pspec); +static void gst_onnx_inference_finalize (GObject * object); +static GstFlowReturn gst_onnx_inference_transform_ip (GstBaseTransform * + trans, GstBuffer * buf); +static gboolean gst_onnx_inference_process (GstBaseTransform * trans, + GstBuffer * buf); +static gboolean gst_onnx_inference_create_session (GstBaseTransform * trans); +static GstCaps *gst_onnx_inference_transform_caps (GstBaseTransform * + trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps); +static gboolean +gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps, + GstCaps * outcaps); + +G_DEFINE_TYPE (GstOnnxInference, gst_onnx_inference, GST_TYPE_BASE_TRANSFORM); + +GType gst_onnx_optimization_level_get_type (void); +#define GST_TYPE_ONNX_OPTIMIZATION_LEVEL (gst_onnx_optimization_level_get_type ()) + +GType gst_onnx_execution_provider_get_type (void); +#define GST_TYPE_ONNX_EXECUTION_PROVIDER (gst_onnx_execution_provider_get_type ()) + +GType gst_ml_model_input_image_format_get_type (void); +#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ()) + +GType +gst_onnx_optimization_level_get_type (void) +{ + static GType onnx_optimization_type = 0; + + if (g_once_init_enter (&onnx_optimization_type)) { + static GEnumValue optimization_level_types[] = { + {GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL, "Disable all optimization", + "disable-all"}, + {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC, + "Enable basic optimizations (redundant node removals))", + "enable-basic"}, + {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, + "Enable extended optimizations (redundant node removals + node fusions)", + "enable-extended"}, + {GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL, + "Enable all possible optimizations", "enable-all"}, + {0, NULL, NULL}, + }; + + GType temp = g_enum_register_static ("GstOnnxOptimizationLevel", + optimization_level_types); + + g_once_init_leave (&onnx_optimization_type, temp); + } + + return onnx_optimization_type; +} + +GType +gst_onnx_execution_provider_get_type (void) +{ + static GType onnx_execution_type = 0; + + if (g_once_init_enter 
(&onnx_execution_type)) { + static GEnumValue execution_provider_types[] = { + {GST_ONNX_EXECUTION_PROVIDER_CPU, "CPU execution provider", + "cpu"}, + {GST_ONNX_EXECUTION_PROVIDER_CUDA, + "CUDA execution provider", + "cuda"}, + {0, NULL, NULL}, + }; + + GType temp = g_enum_register_static ("GstOnnxExecutionProvider", + execution_provider_types); + + g_once_init_leave (&onnx_execution_type, temp); + } + + return onnx_execution_type; +} + +GType +gst_ml_model_input_image_format_get_type (void) +{ + static GType ml_model_input_image_format = 0; + + if (g_once_init_enter (&ml_model_input_image_format)) { + static GEnumValue ml_model_input_image_format_types[] = { + {GST_ML_INPUT_IMAGE_FORMAT_HWC, + "Height Width Channel (HWC) a.k.a. interleaved image data format", + "hwc"}, + {GST_ML_INPUT_IMAGE_FORMAT_CHW, + "Channel Height Width (CHW) a.k.a. planar image data format", + "chw"}, + {0, NULL, NULL}, + }; + + GType temp = g_enum_register_static ("GstMlInputImageFormat", + ml_model_input_image_format_types); + + g_once_init_leave (&ml_model_input_image_format, temp); + } + + return ml_model_input_image_format; +} + +static void +gst_onnx_inference_class_init (GstOnnxInferenceClass * klass) +{ + GObjectClass *gobject_class = (GObjectClass *) klass; + GstElementClass *element_class = (GstElementClass *) klass; + GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass; + + GST_DEBUG_CATEGORY_INIT (onnx_inference_debug, "onnxinference", + 0, "onnx_inference"); + gobject_class->set_property = gst_onnx_inference_set_property; + gobject_class->get_property = gst_onnx_inference_get_property; + gobject_class->finalize = gst_onnx_inference_finalize; + + /** + * GstOnnxInference:model-file + * + * ONNX model file + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE, + g_param_spec_string ("model-file", + "ONNX model file", "ONNX model file", NULL, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + /** + * GstOnnxInference:input-image-format + * + * Model input image format + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), + PROP_INPUT_IMAGE_FORMAT, + g_param_spec_enum ("input-image-format", + "Input image format", + "Input image format", + GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT, + GST_ML_INPUT_IMAGE_FORMAT_HWC, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + /** + * GstOnnxInference:optimization-level + * + * ONNX optimization level + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), + PROP_OPTIMIZATION_LEVEL, + g_param_spec_enum ("optimization-level", + "Optimization level", + "ONNX optimization level", + GST_TYPE_ONNX_OPTIMIZATION_LEVEL, + GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + /** + * GstOnnxInference:execution-provider + * + * ONNX execution provider + * + * Since: 1.24 + */ + g_object_class_install_property (G_OBJECT_CLASS (klass), + PROP_EXECUTION_PROVIDER, + g_param_spec_enum ("execution-provider", + "Execution provider", + "ONNX execution provider", + GST_TYPE_ONNX_EXECUTION_PROVIDER, + GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags) + (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); + + gst_element_class_set_static_metadata (element_class, "onnxinference", + "Filter/Effect/Video", + "Apply neural network to video frames and create tensor output", + "Aaron Boxer "); + gst_element_class_add_pad_template (element_class, + gst_static_pad_template_get 
(&gst_onnx_inference_sink_template)); + gst_element_class_add_pad_template (element_class, + gst_static_pad_template_get (&gst_onnx_inference_src_template)); + basetransform_class->transform_ip = + GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_ip); + basetransform_class->transform_caps = + GST_DEBUG_FUNCPTR (gst_onnx_inference_transform_caps); + basetransform_class->set_caps = + GST_DEBUG_FUNCPTR (gst_onnx_inference_set_caps); +} + +static void +gst_onnx_inference_init (GstOnnxInference * self) +{ + self->onnx_client = new GstOnnxNamespace::GstOnnxClient (); + self->onnx_disabled = TRUE; +} + +static void +gst_onnx_inference_finalize (GObject * object) +{ + GstOnnxInference *self = GST_ONNX_INFERENCE (object); + + g_free (self->model_file); + delete GST_ONNX_CLIENT_MEMBER (self); + G_OBJECT_CLASS (gst_onnx_inference_parent_class)->finalize (object); +} + +static void +gst_onnx_inference_set_property (GObject * object, guint prop_id, + const GValue * value, GParamSpec * pspec) +{ + GstOnnxInference *self = GST_ONNX_INFERENCE (object); + const gchar *filename; + auto onnxClient = GST_ONNX_CLIENT_MEMBER (self); + + switch (prop_id) { + case PROP_MODEL_FILE: + filename = g_value_get_string (value); + if (filename + && g_file_test (filename, + (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { + if (self->model_file) + g_free (self->model_file); + self->model_file = g_strdup (filename); + self->onnx_disabled = FALSE; + } else { + GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename); + } + break; + case PROP_OPTIMIZATION_LEVEL: + self->optimization_level = + (GstOnnxOptimizationLevel) g_value_get_enum (value); + break; + case PROP_EXECUTION_PROVIDER: + self->execution_provider = + (GstOnnxExecutionProvider) g_value_get_enum (value); + break; + case PROP_INPUT_IMAGE_FORMAT: + onnxClient->setInputImageFormat ((GstMlInputImageFormat) + g_value_get_enum (value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); + break; + } +} + +static void +gst_onnx_inference_get_property (GObject * object, guint prop_id, + GValue * value, GParamSpec * pspec) +{ + GstOnnxInference *self = GST_ONNX_INFERENCE (object); + auto onnxClient = GST_ONNX_CLIENT_MEMBER (self); + + switch (prop_id) { + case PROP_MODEL_FILE: + g_value_set_string (value, self->model_file); + break; + case PROP_OPTIMIZATION_LEVEL: + g_value_set_enum (value, self->optimization_level); + break; + case PROP_EXECUTION_PROVIDER: + g_value_set_enum (value, self->execution_provider); + break; + case PROP_INPUT_IMAGE_FORMAT: + g_value_set_enum (value, onnxClient->getInputImageFormat ()); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); + break; + } +} + +static gboolean +gst_onnx_inference_create_session (GstBaseTransform * trans) +{ + GstOnnxInference *self = GST_ONNX_INFERENCE (trans); + auto onnxClient = GST_ONNX_CLIENT_MEMBER (self); + + GST_OBJECT_LOCK (self); + if (self->onnx_disabled) { + GST_OBJECT_UNLOCK (self); + + return FALSE; + } + if (onnxClient->hasSession ()) { + GST_OBJECT_UNLOCK (self); + + return TRUE; + } + if (self->model_file) { + gboolean ret = + GST_ONNX_CLIENT_MEMBER (self)->createSession (self->model_file, + self->optimization_level, + self->execution_provider); + if (!ret) { + GST_ERROR_OBJECT (self, + "Unable to create ONNX session. 
Model is disabled.");
+      self->onnx_disabled = TRUE;
+    }
+  } else {
+    self->onnx_disabled = TRUE;
+    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL), ("Model file not found"));
+  }
+  GST_OBJECT_UNLOCK (self);
+  if (self->onnx_disabled) {
+    gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE);
+  }
+
+  return TRUE;
+}
+
+static GstCaps *
+gst_onnx_inference_transform_caps (GstBaseTransform *
+    trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps)
+{
+  GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
+  auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
+  GstCaps *other_caps;
+  guint i;
+
+  if (!gst_onnx_inference_create_session (trans))
+    return NULL;
+  GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
+
+  if (gst_base_transform_is_passthrough (trans)
+      || (!onnxClient->isFixedInputImageSize ()))
+    return gst_caps_ref (caps);
+
+  other_caps = gst_caps_new_empty ();
+  for (i = 0; i < gst_caps_get_size (caps); ++i) {
+    GstStructure *structure, *new_structure;
+
+    structure = gst_caps_get_structure (caps, i);
+    new_structure = gst_structure_copy (structure);
+    gst_structure_set (new_structure, "width", G_TYPE_INT,
+        onnxClient->getWidth (), "height", G_TYPE_INT,
+        onnxClient->getHeight (), NULL);
+    GST_LOG_OBJECT (self,
+        "transformed structure %2d: %" GST_PTR_FORMAT " => %"
+        GST_PTR_FORMAT, i, structure, new_structure);
+    gst_caps_append_structure (other_caps, new_structure);
+  }
+
+  if (!gst_caps_is_empty (other_caps) && filter_caps) {
+    GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
+        GST_CAPS_INTERSECT_FIRST);
+    gst_caps_replace (&other_caps, tmp);
+    gst_caps_unref (tmp);
+  }
+
+  return other_caps;
+}
+
+static gboolean
+gst_onnx_inference_set_caps (GstBaseTransform * trans, GstCaps * incaps,
+    GstCaps * outcaps)
+{
+  GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
+  auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
+
+  if (!gst_video_info_from_caps (&self->video_info, incaps)) {
+    GST_ERROR_OBJECT (self, "Failed to parse caps");
+    return FALSE;
+  }
+
+  onnxClient->parseDimensions (self->video_info);
+  return TRUE;
+}
+
+static GstFlowReturn
+gst_onnx_inference_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
+{
+  if (!gst_base_transform_is_passthrough (trans)
+      && !gst_onnx_inference_process (trans, buf)) {
+    GST_ELEMENT_ERROR (trans, STREAM, FAILED,
+        (NULL), ("ONNX inference failed"));
+    return GST_FLOW_ERROR;
+  }
+
+  return GST_FLOW_OK;
+}
+
+static gboolean
+gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
+{
+  GstMapInfo info;
+  GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf);
+
+  if (!vmeta) {
+    GST_WARNING_OBJECT (trans, "missing video meta");
+    return FALSE;
+  }
+  if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
+    GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
+    try {
+      auto client = GST_ONNX_CLIENT_MEMBER (self);
+      auto outputs = client->run (info.data, self->video_info);
+      auto meta = client->copy_tensors_to_meta (outputs, buf);
+      if (!meta) {
+        /* do not leak the mapping on the error path */
+        gst_buffer_unmap (buf, &info);
+        return FALSE;
+      }
+      meta->batch_size = 1;
+    }
+    catch (Ort::Exception & ortex) {
+      GST_ERROR_OBJECT (self, "%s", ortex.what ());
+      gst_buffer_unmap (buf, &info);
+      return FALSE;
+    }
+
+    gst_buffer_unmap (buf, &info);
+  }
+
+  return TRUE;
+}
diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.h
new file mode 100644
index 0000000000..1f12941882
--- /dev/null
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.h
@@ -0,0 +1,64 @@
+/*
+ * GStreamer gstreamer-onnxinference
+ * Copyright
(C) 2023 Collabora Ltd + * + * gstonnxinference.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ + +#ifndef __GST_ONNX_INFERENCE_H__ +#define __GST_ONNX_INFERENCE_H__ + +#include +#include +#include +#include "gstonnxenums.h" + +G_BEGIN_DECLS + +#define GST_TYPE_ONNX_INFERENCE (gst_onnx_inference_get_type()) +G_DECLARE_FINAL_TYPE (GstOnnxInference, gst_onnx_inference, GST, + ONNX_INFERENCE, GstBaseTransform) + +/** + * GstOnnxInference: + * + * @model_file model file + * @optimization_level: ONNX session optimization level + * @execution_provider: ONNX execution provider + * @onnx_client opaque pointer to ONNX client + * @onnx_disabled true if inference is disabled + * @video_info @ref GstVideoInfo of sink caps + * + * Since: 1.24 + */ + struct _GstOnnxInference + { + GstBaseTransform basetransform; + gchar *model_file; + GstOnnxOptimizationLevel optimization_level; + GstOnnxExecutionProvider execution_provider; + gpointer onnx_client; + gboolean onnx_disabled; + GstVideoInfo video_info; + }; + +GST_ELEMENT_REGISTER_DECLARE (onnx_inference) + +G_END_DECLS + +#endif /* __GST_ONNX_INFERENCE_H__ */ diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp deleted file mode 100644 index c86bd40205..0000000000 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.cpp +++ /dev/null @@ -1,684 +0,0 @@ -/* - * GStreamer gstreamer-onnxobjectdetector - * Copyright (C) 2021 Collabora Ltd. - * - * gstonnxobjectdetector.c - * - * This library is free software; you can redistribute it and/or - * modify it under the terms of the GNU Library General Public - * License as published by the Free Software Foundation; either - * version 2 of the License, or (at your option) any later version. - * - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Library General Public License for more details. - * - * You should have received a copy of the GNU Library General Public - * License along with this library; if not, write to the - * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, - * Boston, MA 02110-1301, USA. - */ - -/** - * SECTION:element-onnxobjectdetector - * @short_description: Detect objects in video frame - * - * This element can apply a generic ONNX object detection model such as YOLO or SSD - * to each video frame. 
- * - * To install ONNX on your system, recursively clone this repository - * https://github.com/microsoft/onnxruntime.git - * - * and build and install with cmake: - * - * CPU: - * - * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF \ - * $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install - * - * - * GPU : - * - * cmake -Donnxruntime_BUILD_SHARED_LIB:ON -DBUILD_TESTING:OFF -Donnxruntime_USE_CUDA:ON \ - * -Donnxruntime_CUDA_HOME=$CUDA_PATH -Donnxruntime_CUDNN_HOME=$CUDA_PATH \ - * $SRC_DIR/onnxruntime/cmake && make -j8 && sudo make install - * - * - * where : - * - * 1. $SRC_DIR and $BUILD_DIR are local source and build directories - * 2. To run with CUDA, both CUDA and cuDNN libraries must be installed. - * $CUDA_PATH is an environment variable set to the CUDA root path. - * On Linux, it would be /usr/local/cuda-XX.X where XX.X is the installed version of CUDA. - * - * - * ## Example launch command: - * - * (note: an object detection model has 3 or 4 output nodes, but there is no - * naming convention to indicate which node outputs the bounding box, which - * node outputs the label, etc. So, the `onnxobjectdetector` element has - * properties to map each node's functionality to its respective node index in - * the specified model. Image resolution also need to be adapted to the model. - * The videoscale in the pipeline below will scale the image, using padding if - * required, to 640x383 resolution required by the model.) - * - * model.onnx can be found here: - * https://github.com/zoq/onnx-runtime-examples/raw/main/data/models/model.onnx - * - * ``` - * GST_DEBUG=objectdetector:5 gst-launch-1.0 multifilesrc \ - * location=000000088462.jpg caps=image/jpeg,framerate=\(fraction\)30/1 ! jpegdec ! \ - * videoconvert ! \ - * videoscale ! \ - * 'video/x-raw,width=640,height=383' ! \ - * onnxobjectdetector \ - * box-node-index=0 \ - * class-node-index=1 \ - * score-node-index=2 \ - * detection-node-index=3 \ - * execution-provider=cpu \ - * model-file=model.onnx \ - * label-file=COCO_classes.txt ! \ - * videoconvert ! 
\ - * autovideosink - * ``` - */ - -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - -#include "gstonnxobjectdetector.h" -#include "gstonnxclient.h" - -#include -#include -#include -#include -#include -#include - -GST_DEBUG_CATEGORY_STATIC (onnx_object_detector_debug); -#define GST_CAT_DEFAULT onnx_object_detector_debug -#define GST_ONNX_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_ptr)) -GST_ELEMENT_REGISTER_DEFINE (onnx_object_detector, "onnxobjectdetector", - GST_RANK_PRIMARY, GST_TYPE_ONNX_OBJECT_DETECTOR); - -/* GstOnnxObjectDetector properties */ -enum -{ - PROP_0, - PROP_MODEL_FILE, - PROP_LABEL_FILE, - PROP_SCORE_THRESHOLD, - PROP_DETECTION_NODE_INDEX, - PROP_BOUNDING_BOX_NODE_INDEX, - PROP_SCORE_NODE_INDEX, - PROP_CLASS_NODE_INDEX, - PROP_INPUT_IMAGE_FORMAT, - PROP_OPTIMIZATION_LEVEL, - PROP_EXECUTION_PROVIDER -}; - - -#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU -#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_OPTIMIZATION_LEVEL GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED -#define GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD 0.3f /* 0 to 1 */ - -static GstStaticPadTemplate gst_onnx_object_detector_src_template = -GST_STATIC_PAD_TEMPLATE ("src", - GST_PAD_SRC, - GST_PAD_ALWAYS, - GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }")) - ); - -static GstStaticPadTemplate gst_onnx_object_detector_sink_template = -GST_STATIC_PAD_TEMPLATE ("sink", - GST_PAD_SINK, - GST_PAD_ALWAYS, - GST_STATIC_CAPS (GST_VIDEO_CAPS_MAKE ("{ RGB,RGBA,BGR,BGRA }")) - ); - -static void gst_onnx_object_detector_set_property (GObject * object, - guint prop_id, const GValue * value, GParamSpec * pspec); -static void gst_onnx_object_detector_get_property (GObject * object, - guint prop_id, GValue * value, GParamSpec * pspec); -static void gst_onnx_object_detector_finalize (GObject * object); -static GstFlowReturn gst_onnx_object_detector_transform_ip (GstBaseTransform * - trans, GstBuffer * buf); -static gboolean gst_onnx_object_detector_process (GstBaseTransform * trans, - GstBuffer * buf); -static gboolean gst_onnx_object_detector_create_session (GstBaseTransform * trans); -static GstCaps *gst_onnx_object_detector_transform_caps (GstBaseTransform * - trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps); - -G_DEFINE_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector, - GST_TYPE_BASE_TRANSFORM); - -static void -gst_onnx_object_detector_class_init (GstOnnxObjectDetectorClass * klass) -{ - GObjectClass *gobject_class = (GObjectClass *) klass; - GstElementClass *element_class = (GstElementClass *) klass; - GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass; - - GST_DEBUG_CATEGORY_INIT (onnx_object_detector_debug, "onnxobjectdetector", - 0, "onnx_objectdetector"); - gobject_class->set_property = gst_onnx_object_detector_set_property; - gobject_class->get_property = gst_onnx_object_detector_get_property; - gobject_class->finalize = gst_onnx_object_detector_finalize; - - /** - * GstOnnxObjectDetector:model-file - * - * ONNX model file - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_MODEL_FILE, - g_param_spec_string ("model-file", - "ONNX model file", "ONNX model file", NULL, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:label-file - * - * Label file for ONNX model - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_LABEL_FILE, - g_param_spec_string ("label-file", - 
"Label file", "Label file associated with model", NULL, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - - /** - * GstOnnxObjectDetector:detection-node-index - * - * Index of model detection node - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_DETECTION_NODE_INDEX, - g_param_spec_int ("detection-node-index", - "Detection node index", - "Index of neural network output node corresponding to number of detected objects", - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, - GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1, - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - - /** - * GstOnnxObjectDetector:bounding-box-node-index - * - * Index of model bounding box node - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_BOUNDING_BOX_NODE_INDEX, - g_param_spec_int ("box-node-index", - "Bounding box node index", - "Index of neural network output node corresponding to bounding box", - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, - GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1, - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:score-node-index - * - * Index of model score node - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_SCORE_NODE_INDEX, - g_param_spec_int ("score-node-index", - "Score node index", - "Index of neural network output node corresponding to score", - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, - GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1, - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:class-node-index - * - * Index of model class (label) node - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_CLASS_NODE_INDEX, - g_param_spec_int ("class-node-index", - "Class node index", - "Index of neural network output node corresponding to class (label)", - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, - GstOnnxNamespace::GST_ML_OUTPUT_NODE_NUMBER_OF-1, - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - - /** - * GstOnnxObjectDetector:score-threshold - * - * Threshold for deciding when to remove boxes based on score - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_SCORE_THRESHOLD, - g_param_spec_float ("score-threshold", - "Score threshold", - "Threshold for deciding when to remove boxes based on score", - 0.0, 1.0, - GST_ONNX_OBJECT_DETECTOR_DEFAULT_SCORE_THRESHOLD, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:input-image-format - * - * Model input image format - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_INPUT_IMAGE_FORMAT, - g_param_spec_enum ("input-image-format", - "Input image format", - "Input image format", - GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT, - GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:optimization-level - * - * ONNX optimization level - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_OPTIMIZATION_LEVEL, - g_param_spec_enum ("optimization-level", - "Optimization level", - "ONNX optimization level", - 
GST_TYPE_ONNX_OPTIMIZATION_LEVEL, - GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - /** - * GstOnnxObjectDetector:execution-provider - * - * ONNX execution provider - * - * Since: 1.20 - */ - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_EXECUTION_PROVIDER, - g_param_spec_enum ("execution-provider", - "Execution provider", - "ONNX execution provider", - GST_TYPE_ONNX_EXECUTION_PROVIDER, - GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags) - (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - - gst_element_class_set_static_metadata (element_class, "onnxobjectdetector", - "Filter/Effect/Video", - "Apply neural network to detect objects in video frames", - "Aaron Boxer , Marcus Edel "); - gst_element_class_add_pad_template (element_class, - gst_static_pad_template_get (&gst_onnx_object_detector_sink_template)); - gst_element_class_add_pad_template (element_class, - gst_static_pad_template_get (&gst_onnx_object_detector_src_template)); - basetransform_class->transform_ip = - GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_ip); - basetransform_class->transform_caps = - GST_DEBUG_FUNCPTR (gst_onnx_object_detector_transform_caps); -} - -static void -gst_onnx_object_detector_init (GstOnnxObjectDetector * self) -{ - self->onnx_ptr = new GstOnnxNamespace::GstOnnxClient (); - self->onnx_disabled = false; -} - -static void -gst_onnx_object_detector_finalize (GObject * object) -{ - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object); - - g_free (self->model_file); - delete GST_ONNX_MEMBER (self); - G_OBJECT_CLASS (gst_onnx_object_detector_parent_class)->finalize (object); -} - -static void -gst_onnx_object_detector_set_property (GObject * object, guint prop_id, - const GValue * value, GParamSpec * pspec) -{ - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object); - const gchar *filename; - auto onnxClient = GST_ONNX_MEMBER (self); - - switch (prop_id) { - case PROP_MODEL_FILE: - filename = g_value_get_string (value); - if (filename - && g_file_test (filename, - (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { - if (self->model_file) - g_free (self->model_file); - self->model_file = g_strdup (filename); - } else { - GST_WARNING_OBJECT (self, "Model file '%s' not found!", filename); - gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE); - } - break; - case PROP_LABEL_FILE: - filename = g_value_get_string (value); - if (filename - && g_file_test (filename, - (GFileTest) (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR))) { - if (self->label_file) - g_free (self->label_file); - self->label_file = g_strdup (filename); - } else { - GST_WARNING_OBJECT (self, "Label file '%s' not found!", filename); - } - break; - case PROP_SCORE_THRESHOLD: - GST_OBJECT_LOCK (self); - self->score_threshold = g_value_get_float (value); - GST_OBJECT_UNLOCK (self); - break; - case PROP_OPTIMIZATION_LEVEL: - self->optimization_level = - (GstOnnxOptimizationLevel) g_value_get_enum (value); - break; - case PROP_EXECUTION_PROVIDER: - self->execution_provider = - (GstOnnxExecutionProvider) g_value_get_enum (value); - break; - case PROP_DETECTION_NODE_INDEX: - onnxClient->setOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION, - g_value_get_int (value)); - break; - case PROP_BOUNDING_BOX_NODE_INDEX: - onnxClient->setOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX, - g_value_get_int (value)); - break; - break; - case PROP_SCORE_NODE_INDEX: - 
onnxClient->setOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE, - g_value_get_int (value)); - break; - break; - case PROP_CLASS_NODE_INDEX: - onnxClient->setOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS, - g_value_get_int (value)); - break; - case PROP_INPUT_IMAGE_FORMAT: - onnxClient->setInputImageFormat ((GstMlModelInputImageFormat) - g_value_get_enum (value)); - break; - default: - G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); - break; - } -} - -static void -gst_onnx_object_detector_get_property (GObject * object, guint prop_id, - GValue * value, GParamSpec * pspec) -{ - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (object); - auto onnxClient = GST_ONNX_MEMBER (self); - - switch (prop_id) { - case PROP_MODEL_FILE: - g_value_set_string (value, self->model_file); - break; - case PROP_LABEL_FILE: - g_value_set_string (value, self->label_file); - break; - case PROP_SCORE_THRESHOLD: - GST_OBJECT_LOCK (self); - g_value_set_float (value, self->score_threshold); - GST_OBJECT_UNLOCK (self); - break; - case PROP_OPTIMIZATION_LEVEL: - g_value_set_enum (value, self->optimization_level); - break; - case PROP_EXECUTION_PROVIDER: - g_value_set_enum (value, self->execution_provider); - break; - case PROP_DETECTION_NODE_INDEX: - g_value_set_int (value, - onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)); - break; - case PROP_BOUNDING_BOX_NODE_INDEX: - g_value_set_int (value, - onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)); - break; - break; - case PROP_SCORE_NODE_INDEX: - g_value_set_int (value, - onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE)); - break; - break; - case PROP_CLASS_NODE_INDEX: - g_value_set_int (value, - onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS)); - break; - case PROP_INPUT_IMAGE_FORMAT: - g_value_set_enum (value, onnxClient->getInputImageFormat ()); - break; - default: - G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); - break; - } -} - -static gboolean -gst_onnx_object_detector_create_session (GstBaseTransform * trans) -{ - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans); - auto onnxClient = GST_ONNX_MEMBER (self); - - GST_OBJECT_LOCK (self); - if (self->onnx_disabled || onnxClient->hasSession ()) { - GST_OBJECT_UNLOCK (self); - - return TRUE; - } - if (self->model_file) { - gboolean ret = GST_ONNX_MEMBER (self)->createSession (self->model_file, - self->optimization_level, - self->execution_provider); - if (!ret) { - GST_ERROR_OBJECT (self, - "Unable to create ONNX session. Detection disabled."); - } else { - auto outputNames = onnxClient->getOutputNodeNames (); - - for (size_t i = 0; i < outputNames.size (); ++i) - GST_INFO_OBJECT (self, "Output node index: %d for node: %s", (gint) i, - outputNames[i]); - if (outputNames.size () < 3) { - GST_ERROR_OBJECT (self, - "Number of output tensor nodes %d does not match the 3 or 4 nodes " - "required for an object detection model. Detection is disabled.", - (gint) outputNames.size ()); - self->onnx_disabled = TRUE; - } - // sanity check on output node indices - if (onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_DETECTION) == - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) { - GST_ERROR_OBJECT (self, - "Output detection node index not set. 
Detection disabled."); - self->onnx_disabled = TRUE; - } - if (onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX) == - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) { - GST_ERROR_OBJECT (self, - "Output bounding box node index not set. Detection disabled."); - self->onnx_disabled = TRUE; - } - if (onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_SCORE) == - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) { - GST_ERROR_OBJECT (self, - "Output score node index not set. Detection disabled."); - self->onnx_disabled = TRUE; - } - if (outputNames.size () == 4 && onnxClient->getOutputNodeIndex - (GstOnnxNamespace::GST_ML_OUTPUT_NODE_FUNCTION_CLASS) == - GstOnnxNamespace::GST_ML_NODE_INDEX_DISABLED) { - GST_ERROR_OBJECT (self, - "Output class node index not set. Detection disabled."); - self->onnx_disabled = TRUE; - } - // model is not usable, so fail - if (self->onnx_disabled) { - GST_ELEMENT_WARNING (self, RESOURCE, FAILED, - ("ONNX model cannot be used for object detection"), (NULL)); - - return FALSE; - } - } - } else { - self->onnx_disabled = TRUE; - } - GST_OBJECT_UNLOCK (self); - if (self->onnx_disabled){ - gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), TRUE); - } - - return TRUE; -} - - -static GstCaps * -gst_onnx_object_detector_transform_caps (GstBaseTransform * - trans, GstPadDirection direction, GstCaps * caps, GstCaps * filter_caps) -{ - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans); - auto onnxClient = GST_ONNX_MEMBER (self); - GstCaps *other_caps; - guint i; - - if ( !gst_onnx_object_detector_create_session (trans) ) - return NULL; - GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps); - - if (gst_base_transform_is_passthrough (trans) - || (!onnxClient->isFixedInputImageSize ())) - return gst_caps_ref (caps); - - other_caps = gst_caps_new_empty (); - for (i = 0; i < gst_caps_get_size (caps); ++i) { - GstStructure *structure, *new_structure; - - structure = gst_caps_get_structure (caps, i); - new_structure = gst_structure_copy (structure); - gst_structure_set (new_structure, "width", G_TYPE_INT, - onnxClient->getWidth (), "height", G_TYPE_INT, - onnxClient->getHeight (), NULL); - GST_LOG_OBJECT (self, - "transformed structure %2d: %" GST_PTR_FORMAT " => %" - GST_PTR_FORMAT, i, structure, new_structure); - gst_caps_append_structure (other_caps, new_structure); - } - - if (!gst_caps_is_empty (other_caps) && filter_caps) { - GstCaps *tmp = gst_caps_intersect_full (other_caps,filter_caps, - GST_CAPS_INTERSECT_FIRST); - gst_caps_replace (&other_caps, tmp); - gst_caps_unref (tmp); - } - - return other_caps; -} - - -static GstFlowReturn -gst_onnx_object_detector_transform_ip (GstBaseTransform * trans, - GstBuffer * buf) -{ - if (!gst_base_transform_is_passthrough (trans) - && !gst_onnx_object_detector_process (trans, buf)){ - GST_ELEMENT_WARNING (trans, STREAM, FAILED, - ("ONNX object detection failed"), (NULL)); - return GST_FLOW_ERROR; - } - - return GST_FLOW_OK; -} - -static gboolean -gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf) -{ - GstMapInfo info; - GstVideoMeta *vmeta = gst_buffer_get_video_meta (buf); - - if (!vmeta) { - GST_WARNING_OBJECT (trans, "missing video meta"); - return FALSE; - } - if (gst_buffer_map (buf, &info, GST_MAP_READ)) { - GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans); - std::vector < GstOnnxNamespace::GstMlBoundingBox > boxes; - try { - boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta, - 
self->label_file ? self->label_file : "", self->score_threshold); - } - catch (Ort::Exception & ortex) { - GST_ERROR_OBJECT (self, "%s", ortex.what ()); - return FALSE; - } - for (auto & b:boxes) { - auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf, - GST_ONNX_OBJECT_DETECTOR_META_NAME, - b.x0, b.y0, - b.width, - b.height); - if (!vroi_meta) { - GST_WARNING_OBJECT (trans, - "Unable to attach GstVideoRegionOfInterestMeta to buffer"); - return FALSE; - } - auto s = gst_structure_new (GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME, - GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL, - G_TYPE_STRING, - b.label.c_str (), - GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE, - G_TYPE_DOUBLE, - b.score, - NULL); - gst_video_region_of_interest_meta_add_param (vroi_meta, s); - GST_DEBUG_OBJECT (self, - "Object detected with label : %s, score: %f, bound box: (%f,%f,%f,%f) \n", - b.label.c_str (), b.score, b.x0, b.y0, - b.x0 + b.width, b.y0 + b.height); - } - gst_buffer_unmap (buf, &info); - } - - return TRUE; -} diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.h deleted file mode 100644 index e031d12e96..0000000000 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxobjectdetector.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * GStreamer gstreamer-onnxobjectdetector - * Copyright (C) 2021 Collabora Ltd - * - * gstonnxobjectdetector.h - * - * This library is free software; you can redistribute it and/or - * modify it under the terms of the GNU Library General Public - * License as published by the Free Software Foundation; either - * version 2 of the License, or (at your option) any later version. - * - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Library General Public License for more details. - * - * You should have received a copy of the GNU Library General Public - * License along with this library; if not, write to the - * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, - * Boston, MA 02110-1301, USA. 
- */ - -#ifndef __GST_ONNX_OBJECT_DETECTOR_H__ -#define __GST_ONNX_OBJECT_DETECTOR_H__ - -#include -#include -#include -#include "gstonnxelement.h" - -G_BEGIN_DECLS - -#define GST_TYPE_ONNX_OBJECT_DETECTOR (gst_onnx_object_detector_get_type()) -G_DECLARE_FINAL_TYPE (GstOnnxObjectDetector, gst_onnx_object_detector, GST, ONNX_OBJECT_DETECTOR, GstBaseTransform) -#define GST_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_CAST((obj),GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetector)) -#define GST_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_CAST((klass), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass)) -#define GST_ONNX_OBJECT_DETECTOR_GET_CLASS(obj) (G_TYPE_INSTANCE_GET_CLASS((obj), GST_TYPE_ONNX_OBJECT_DETECTOR,GstOnnxObjectDetectorClass)) -#define GST_IS_ONNX_OBJECT_DETECTOR(obj) (G_TYPE_CHECK_INSTANCE_TYPE((obj),GST_TYPE_ONNX_OBJECT_DETECTOR)) -#define GST_IS_ONNX_OBJECT_DETECTOR_CLASS(klass) (G_TYPE_CHECK_CLASS_TYPE((klass), GST_TYPE_ONNX_OBJECT_DETECTOR)) - -#define GST_ONNX_OBJECT_DETECTOR_META_NAME "onnx-object_detector" -#define GST_ONNX_OBJECT_DETECTOR_META_PARAM_NAME "extra-data" -#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_LABEL "label" -#define GST_ONNX_OBJECT_DETECTOR_META_FIELD_SCORE "score" - -/** - * GstOnnxObjectDetector: - * - * @model_file model file - * @label_file label file - * @score_threshold score threshold - * @confidence_threshold confidence threshold - * @iou_threhsold iou threshold - * @optimization_level ONNX optimization level - * @execution_provider: ONNX execution provider - * @onnx_ptr opaque pointer to ONNX implementation - * - * Since: 1.20 - */ -struct _GstOnnxObjectDetector -{ - GstBaseTransform basetransform; - gchar *model_file; - gchar *label_file; - gfloat score_threshold; - gfloat confidence_threshold; - gfloat iou_threshold; - GstOnnxOptimizationLevel optimization_level; - GstOnnxExecutionProvider execution_provider; - gpointer onnx_ptr; - gboolean onnx_disabled; - - void (*process) (GstOnnxObjectDetector * onnx_object_detector, - GstVideoFrame * inframe, GstVideoFrame * outframe); -}; - -/** - * GstOnnxObjectDetectorClass: - * - * @parent_class base transform base class - * - * Since: 1.20 - */ -struct _GstOnnxObjectDetectorClass -{ - GstBaseTransformClass parent_class; -}; - -GST_ELEMENT_REGISTER_DECLARE (onnx_object_detector) - -G_END_DECLS - -#endif /* __GST_ONNX_OBJECT_DETECTOR_H__ */ diff --git a/subprojects/gst-plugins-bad/ext/onnx/meson.build b/subprojects/gst-plugins-bad/ext/onnx/meson.build index e66d649e03..f591388866 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/meson.build +++ b/subprojects/gst-plugins-bad/ext/onnx/meson.build @@ -3,27 +3,31 @@ if get_option('onnx').disabled() endif -onnxrt_dep = dependency('libonnxruntime', version : '>= 1.13.1', required : get_option('onnx')) +onnxrt_dep = dependency('libonnxruntime', version : '>= 1.15.1', required : get_option('onnx')) + +extra_args = [] +extra_deps = [] +if gstcuda_dep.found() + extra_args += ['-DHAVE_CUDA'] + extra_deps += [gstcuda_dep] +endif if onnxrt_dep.found() onnxrt_include_root = onnxrt_dep.get_variable('includedir') onnxrt_includes = [onnxrt_include_root / 'core/session', onnxrt_include_root / 'core'] - onnxrt_dep_args = [] - - compiler = meson.get_compiler('cpp') - if compiler.has_header(onnxrt_include_root / 'core/providers/cuda/cuda_provider_factory.h') - onnxrt_dep_args = ['-DGST_ML_ONNX_RUNTIME_HAVE_CUDA'] - endif gstonnx = library('gstonnx', 'gstonnx.c', - 'gstonnxelement.c', - 'gstonnxobjectdetector.cpp', + 
'decoders/gstobjectdetectorutils.cpp', + 'decoders/gstssdobjectdetector.cpp', + 'gstonnxinference.cpp', 'gstonnxclient.cpp', - c_args : gst_plugins_bad_args, - cpp_args: onnxrt_dep_args, + 'tensor/gsttensorid.cpp', + 'tensor/gsttensormeta.c', + c_args : gst_plugins_bad_args + extra_args, + cpp_args : gst_plugins_bad_args + extra_args, link_args : noseh_link_args, - include_directories : [configinc, libsinc, onnxrt_includes], - dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm], + include_directories : [configinc, libsinc, onnxrt_includes, cuda_stubinc], + dependencies : [gstbase_dep, gstvideo_dep, onnxrt_dep, libm] + extra_deps, install : true, install_dir : plugins_install_dir, ) diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensor.h b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensor.h new file mode 100644 index 0000000000..bfc63af010 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensor.h @@ -0,0 +1,69 @@ +/* + * GStreamer gstreamer-tensor + * Copyright (C) 2023 Collabora Ltd + * + * gsttensor.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ +#ifndef __GST_TENSOR_H__ +#define __GST_TENSOR_H__ + + +/** + * GstTensorType: + * + * @GST_TENSOR_TYPE_INT8 8 bit integer tensor data + * @GST_TENSOR_TYPE_INT16 16 bit integer tensor data + * @GST_TENSOR_TYPE_INT32 32 bit integer tensor data + * @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data + * @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data + * + * Since: 1.24 + */ +typedef enum _GstTensorType +{ + GST_TENSOR_TYPE_INT8, + GST_TENSOR_TYPE_INT16, + GST_TENSOR_TYPE_INT32, + GST_TENSOR_TYPE_FLOAT16, + GST_TENSOR_TYPE_FLOAT32 +} GstTensorType; + + +/** + * GstTensor: + * + * @id unique tensor identifier + * @num_dims number of tensor dimensions + * @dims tensor dimensions + * @type @ref GstTensorType of tensor data + * @data @ref GstBuffer holding tensor data + * + * Since: 1.24 + */ +typedef struct _GstTensor +{ + GQuark id; + gint num_dims; + int64_t *dims; + GstTensorType type; + GstBuffer *data; +} GstTensor; + +#define GST_TENSOR_MISSING_ID -1 + +#endif diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.cpp b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.cpp new file mode 100644 index 0000000000..0fd8455022 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.cpp @@ -0,0 +1,82 @@ +/* + * GStreamer gstreamer-tensorid + * Copyright (C) 2023 Collabora Ltd + * + * gsttensorid.c + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. 
+ * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ + +#include +#include "gsttensorid.h" + +/* Structure to encapsulate a string and its associated GQuark */ +struct TensorQuark +{ + const char *string; + GQuark quark_id; +}; + +class TensorId +{ +public: + TensorId (void):tensor_quarks_array (g_array_new (FALSE, FALSE, + sizeof (TensorQuark))) + { + } + ~TensorId (void) + { + if (tensor_quarks_array) { + for (guint i = 0; i < tensor_quarks_array->len; i++) { + TensorQuark *quark = + &g_array_index (tensor_quarks_array, TensorQuark, i); + g_free ((gpointer) quark->string); // free the duplicated string + } + g_array_free (tensor_quarks_array, TRUE); + } + } + GQuark get_quark (const char *str) + { + for (guint i = 0; i < tensor_quarks_array->len; i++) { + TensorQuark *quark = &g_array_index (tensor_quarks_array, TensorQuark, i); + if (g_strcmp0 (quark->string, str) == 0) { + return quark->quark_id; // already registered + } + } + + // Register the new quark and append to the GArray + TensorQuark new_quark; + new_quark.string = g_strdup (str); // create a copy of the string + new_quark.quark_id = g_quark_from_string (new_quark.string); + g_array_append_val (tensor_quarks_array, new_quark); + + return new_quark.quark_id; + } +private: + GArray * tensor_quarks_array; +}; + +static TensorId tensorId; + +G_BEGIN_DECLS + +GQuark +gst_tensorid_get_quark (const char *tensor_id) +{ + return tensorId.get_quark (tensor_id); +} + +G_END_DECLS diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.h b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.h new file mode 100644 index 0000000000..7df666b70d --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensorid.h @@ -0,0 +1,34 @@ +/* + * GStreamer gstreamer-tensorid + * Copyright (C) 2023 Collabora Ltd + * + * gsttensorid.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. 
+ */ +#ifndef __GST_TENSOR_ID_H__ +#define __GST_TENSOR_ID_H__ + +G_BEGIN_DECLS +/** + * gst_tensorid_get_quark get tensor id + * + * @param tensor_id unique string id for tensor node + */ + GQuark gst_tensorid_get_quark (const char *tensor_id); + +G_END_DECLS +#endif diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.c b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.c new file mode 100644 index 0000000000..a9da3cc083 --- /dev/null +++ b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.c @@ -0,0 +1,107 @@ +/* + * GStreamer gstreamer-tensormeta + * Copyright (C) 2023 Collabora Ltd + * + * gsttensormeta.c + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ + +#include "gsttensormeta.h" + +#include "gsttensor.h" + +static gboolean +gst_tensor_meta_init (GstMeta * meta, gpointer params, GstBuffer * buffer) +{ + GstTensorMeta *tmeta = (GstTensorMeta *) meta; + + tmeta->num_tensors = 0; + tmeta->tensor = NULL; + + return TRUE; +} + +static void +gst_tensor_meta_free (GstMeta * meta, GstBuffer * buffer) +{ + GstTensorMeta *tmeta = (GstTensorMeta *) meta; + + for (int i = 0; i < tmeta->num_tensors; i++) { + g_free (tmeta->tensor[i].dims); + gst_buffer_unref (tmeta->tensor[i].data); + } + g_free (tmeta->tensor); +} + +GType +gst_tensor_meta_api_get_type (void) +{ + static GType type = 0; + static const gchar *tags[] = { NULL }; + + if (g_once_init_enter (&type)) { + type = gst_meta_api_type_register ("GstTensorMetaAPI", tags); + } + return type; +} + + +const GstMetaInfo * +gst_tensor_meta_get_info (void) +{ + static const GstMetaInfo *tmeta_info = NULL; + + if (g_once_init_enter (&tmeta_info)) { + const GstMetaInfo *meta = + gst_meta_register (gst_tensor_meta_api_get_type (), + "GstTensorMeta", + sizeof (GstTensorMeta), + gst_tensor_meta_init, + gst_tensor_meta_free, + NULL); /* tensor_meta_transform not implemented */ + g_once_init_leave (&tmeta_info, meta); + } + return tmeta_info; +} + +GList * +gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer) +{ + GType tensor_meta_api_type = gst_tensor_meta_api_get_type (); + GList *tensor_metas = NULL; + gpointer state = NULL; + GstMeta *meta; + + while ((meta = gst_buffer_iterate_meta (buffer, &state))) { + if (meta->info->api == tensor_meta_api_type) { + tensor_metas = g_list_append (tensor_metas, meta); + } + } + + return tensor_metas; +} + +gint +gst_tensor_meta_get_index_from_id (GstTensorMeta * meta, GQuark id) +{ + for (int i = 0; i < meta->num_tensors; ++i) { + if ((meta->tensor + i)->id == id) + return i; + } + + return GST_TENSOR_MISSING_ID; +} diff --git a/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h new file mode 100644 index 0000000000..a862e76f88 --- /dev/null +++ 
b/subprojects/gst-plugins-bad/ext/onnx/tensor/gsttensormeta.h @@ -0,0 +1,56 @@ +/* + * GStreamer gstreamer-tensormeta + * Copyright (C) 2023 Collabora Ltd + * + * gsttensormeta.h + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ +#ifndef __GST_TENSOR_META_H__ +#define __GST_TENSOR_META_H__ + +#include +#include "gsttensor.h" + +/** + * GstTensorMeta: + * + * @meta base GstMeta + * @num_tensors number of tensors + * @tensor @ref GstTensor for each tensor + * @batch_size model batch size + * + * Since: 1.24 + */ +typedef struct _GstTensorMeta +{ + GstMeta meta; + + gint num_tensors; + GstTensor *tensor; + int batch_size; +} GstTensorMeta; + +G_BEGIN_DECLS + +GType gst_tensor_meta_api_get_type (void); +const GstMetaInfo *gst_tensor_meta_get_info (void); +GList *gst_tensor_meta_get_all_from_buffer (GstBuffer * buffer); +gint gst_tensor_meta_get_index_from_id(GstTensorMeta *meta, GQuark id); + +G_END_DECLS + +#endif
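+
+/* Hypothetical usage sketch (not part of this header): an application or
+ * downstream element can list every GstTensorMeta attached to a buffer and
+ * inspect the tensors each one carries, e.g.
+ *
+ *   GList *metas = gst_tensor_meta_get_all_from_buffer (buf);
+ *   GList *l;
+ *   for (l = metas; l != NULL; l = l->next) {
+ *     GstTensorMeta *tmeta = (GstTensorMeta *) l->data;
+ *     GST_LOG ("tensor meta with %d tensors (batch size %d)",
+ *         tmeta->num_tensors, tmeta->batch_size);
+ *   }
+ *   g_list_free (metas);
+ */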