diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index b83a30de0f..19d49c69d5 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -333,7 +333,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren gst_tensor_meta_get_info (), NULL); tmeta->num_tensors = num_tensors; - tmeta->tensor = (GstTensor *) g_malloc (num_tensors * sizeof (GstTensor)); + tmeta->tensors = g_new (GstTensor *, num_tensors); bool hasIds = outputIds.size () == num_tensors; for (size_t i = 0; i < num_tensors; i++) { Ort::Value outputTensor = std::move (outputs[i]); @@ -341,14 +341,15 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren ONNXTensorElementDataType tensorType = outputTensor.GetTensorTypeAndShapeInfo ().GetElementType (); - GstTensor *tensor = &tmeta->tensor[i]; + auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape (); + GstTensor *tensor = gst_tensor_alloc (tensorShape.size ()); + tmeta->tensors[i] = tensor; + if (hasIds) tensor->id = outputIds[i]; else tensor->id = 0; - auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape (); tensor->num_dims = tensorShape.size (); - tensor->dims = g_new (gsize, tensor->num_dims); tensor->batch_size = 1; for (size_t j = 0; j < tensorShape.size (); ++j) diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c new file mode 100644 index 0000000000..ac69016a82 --- /dev/null +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c @@ -0,0 +1,92 @@ +/* GStreamer + * Copyright (C) 2023 Collabora Ltd + * + * gstanalyticsmeta.c + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Library General Public + * License as published by the Free Software Foundation; either + * version 2 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Library General Public License for more details. + * + * You should have received a copy of the GNU Library General Public + * License along with this library; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301, USA. + */ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "gsttensor.h" + +#define GST_TENSOR_SIZE(num_dims) \ + (sizeof (GstTensor) + (sizeof (gsize) * num_dims)) + +G_DEFINE_BOXED_TYPE (GstTensor, gst_tensor, + (GBoxedCopyFunc) gst_tensor_copy, (GBoxedFreeFunc) gst_tensor_free); + +/** + * gst_tensor_alloc: (constructor) + * @num_dims: Number of dimension of the tensors + * + * Allocate a tensor with @num_dims dimensions. + * + * Returns: (transfer full) (not nullable): tensor allocated + * + * Since: 1.26 + */ +GstTensor * +gst_tensor_alloc (gsize num_dims) +{ + GstTensor *tensor = g_malloc0 (GST_TENSOR_SIZE (num_dims)); + + tensor->num_dims = num_dims; + + return tensor; +} +} + +/** + * gst_tensor_free: + * @tensor: (in) (transfer full): pointer to tensor to free + * + * Free tensor + * + * Since: 1.26 + */ +void +gst_tensor_free (GstTensor * tensor) +{ + if (tensor->data != NULL) { + gst_buffer_unref (tensor->data); + } + g_free (tensor); +} + +/** + * gst_tensor_copy: + * @tensor: (transfer none) (nullable): a #GstTensor to be copied + * + * Create a copy of @tensor. + * + * Returns: (transfer full) (nullable): a new #GstTensor + * + * Since: 1.26 + */ +GstTensor * +gst_tensor_copy (const GstTensor * tensor) +{ + GstTensor *copy = NULL; + if (tensor) { + copy = (GstTensor *) g_memdup2 (tensor, GST_TENSOR_SIZE (tensor->num_dims)); + if (copy->data) + gst_buffer_ref (copy->data); + } + + return copy; +} diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h index f1d5466a0c..cfbee177a9 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h @@ -115,15 +115,33 @@ typedef enum _GstTensorLayout typedef struct _GstTensor { GQuark id; - gsize num_dims; - gsize *dims; GstTensorDimOrder dims_order; GstTensorLayout layout; GstTensorDataType data_type; gsize batch_size; GstBuffer *data; + gsize num_dims; + gsize dims[]; } GstTensor; +G_BEGIN_DECLS + +#define GST_TYPE_TENSOR (gst_tensor_get_type()) + +GST_ANALYTICS_META_API +GstTensor * gst_tensor_alloc (gsize num_dims); + +GST_ANALYTICS_META_API +void gst_tensor_free (GstTensor * tensor); + +GST_ANALYTICS_META_API +GstTensor * gst_tensor_copy (const GstTensor * tensor); + +GST_ANALYTICS_META_API +GType gst_tensor_get_type (void); + #define GST_TENSOR_MISSING_ID -1 +G_END_DECLS + #endif /* __GST_TENSOR_H__ */ diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c index c350bc3ac9..3fa853c736 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c @@ -28,7 +28,7 @@ gst_tensor_meta_init (GstMeta * meta, gpointer params, GstBuffer * buffer) GstTensorMeta *tmeta = (GstTensorMeta *) meta; tmeta->num_tensors = 0; - tmeta->tensor = NULL; + tmeta->tensors = NULL; return TRUE; } @@ -39,10 +39,9 @@ gst_tensor_meta_free (GstMeta * meta, GstBuffer * buffer) GstTensorMeta *tmeta = (GstTensorMeta *) meta; for (int i = 0; i < tmeta->num_tensors; i++) { - g_free (tmeta->tensor[i].dims); - gst_buffer_unref (tmeta->tensor[i].data); + gst_tensor_free (tmeta->tensors[i]); } - g_free (tmeta->tensor); + g_free (tmeta->tensors); } GType @@ -81,7 +80,7 @@ gint gst_tensor_meta_get_index_from_id (GstTensorMeta * meta, GQuark id) { for (int i = 0; i < meta->num_tensors; ++i) { - if ((meta->tensor + i)->id == id) + if (meta->tensors[i]->id == id) return i; } diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.h b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.h index 24298b1547..667c12660a 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.h @@ -45,7 +45,7 @@ typedef struct _GstTensorMeta GstMeta meta; gsize num_tensors; - GstTensor *tensor; + GstTensor **tensors; } GstTensorMeta; G_BEGIN_DECLS diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/meson.build b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/meson.build index af1a777c5a..6a00c24add 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/meson.build +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/meson.build @@ -3,7 +3,8 @@ analytics_sources = files( 'gstanalyticsmeta.c', 'gstanalyticsobjectdetectionmtd.c', 'gstanalyticsobjecttrackingmtd.c', 'gstanalyticssegmentationmtd.c', - 'gsttensormeta.c') + 'gsttensormeta.c', + 'gsttensor.c') analytics_headers = files( 'analytics.h', 'gstanalyticsmeta.h', diff --git a/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c b/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c index e1ce336ed1..8901ea643c 100644 --- a/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c +++ b/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c @@ -442,21 +442,21 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) return; } - if (!gst_buffer_map (tmeta->tensor[numdetect_index].data, &numdetect_map, + if (!gst_buffer_map (tmeta->tensors[numdetect_index]->data, &numdetect_map, GST_MAP_READ)) { GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d", numdetect_index); goto cleanup; } - if (!gst_buffer_map (tmeta->tensor[boxes_index].data, &boxes_map, + if (!gst_buffer_map (tmeta->tensors[boxes_index]->data, &boxes_map, GST_MAP_READ)) { GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d", boxes_index); goto cleanup; } - if (!gst_buffer_map (tmeta->tensor[scores_index].data, &scores_map, + if (!gst_buffer_map (tmeta->tensors[scores_index]->data, &scores_map, GST_MAP_READ)) { GST_ERROR_OBJECT (self, "Failed to map tensor memory for index %d", scores_index); @@ -464,14 +464,14 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) } if (classes_index != GST_TENSOR_MISSING_ID && - !gst_buffer_map (tmeta->tensor[classes_index].data, &classes_map, + !gst_buffer_map (tmeta->tensors[classes_index]->data, &classes_map, GST_MAP_READ)) { GST_DEBUG_OBJECT (self, "Failed to map tensor memory for index %d", classes_index); } - if (!get_guint32_at_index (&tmeta->tensor[numdetect_index], &numdetect_map, + if (!get_guint32_at_index (tmeta->tensors[numdetect_index], &numdetect_map, 0, &num_detections)) { GST_ERROR_OBJECT (self, "Failed to get the number of detections"); goto cleanup; @@ -488,7 +488,7 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) GQuark label = 0; GstAnalyticsODMtd odmtd; - if (!get_float_at_index (&tmeta->tensor[numdetect_index], &scores_map, + if (!get_float_at_index (tmeta->tensors[numdetect_index], &scores_map, i, &score)) continue; @@ -496,16 +496,16 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) if (score < self->score_threshold) continue; - if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map, + if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map, i * 4, &y)) continue; - if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map, + if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map, i * 4 + 1, &x)) continue; - if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map, + if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map, i * 4 + 2, &bheight)) continue; - if (!get_float_at_index (&tmeta->tensor[boxes_index], &boxes_map, + if (!get_float_at_index (tmeta->tensors[boxes_index], &boxes_map, i * 4 + 3, &bwidth)) continue; @@ -517,7 +517,7 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) } if (self->labels && classes_map.memory && - get_guint32_at_index (&tmeta->tensor[classes_index], &classes_map, + get_guint32_at_index (tmeta->tensors[classes_index], &classes_map, i, &bclass)) { if (bclass < self->labels->len) label = g_array_index (self->labels, GQuark, bclass); @@ -540,13 +540,13 @@ DEFINE_GET_FUNC (guint32, UINT32_MAX) cleanup: if (numdetect_map.memory) - gst_buffer_unmap (tmeta->tensor[numdetect_index].data, &numdetect_map); + gst_buffer_unmap (tmeta->tensors[numdetect_index]->data, &numdetect_map); if (classes_map.memory) - gst_buffer_unmap (tmeta->tensor[classes_index].data, &classes_map); + gst_buffer_unmap (tmeta->tensors[classes_index]->data, &classes_map); if (scores_map.memory) - gst_buffer_unmap (tmeta->tensor[scores_index].data, &scores_map); + gst_buffer_unmap (tmeta->tensors[scores_index]->data, &scores_map); if (boxes_map.memory) - gst_buffer_unmap (tmeta->tensor[boxes_index].data, &boxes_map); + gst_buffer_unmap (tmeta->tensors[boxes_index]->data, &boxes_map); }