From 5e73e8e1b38803ba91db97e7b04854d170178527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Cr=C3=AAte?= Date: Mon, 4 Nov 2024 19:29:43 -0500 Subject: [PATCH] tensor: Add APIs to create and access GstTensor contents Part-of: --- .../gst-libs/gst/analytics/gsttensor.c | 122 ++++++++++++++++++ .../gst-libs/gst/analytics/gsttensor.h | 12 ++ .../gst-libs/gst/analytics/gsttensormeta.c | 5 + 3 files changed, 139 insertions(+) diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c index 9101390633..c54e9b2fab 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.c @@ -49,6 +49,108 @@ gst_tensor_alloc (gsize num_dims) return tensor; } + +static gsize +size_for_elements (GstTensorDataType data_type, gsize elements) +{ + switch (data_type) { + case GST_TENSOR_DATA_TYPE_INT4: + case GST_TENSOR_DATA_TYPE_UINT4: + return (elements / 2) + (elements % 2); + + case GST_TENSOR_DATA_TYPE_INT8: + case GST_TENSOR_DATA_TYPE_UINT8: + return elements; + + case GST_TENSOR_DATA_TYPE_INT16: + case GST_TENSOR_DATA_TYPE_UINT16: + case GST_TENSOR_DATA_TYPE_FLOAT16: + case GST_TENSOR_DATA_TYPE_BFLOAT16: + return elements * 2; + + case GST_TENSOR_DATA_TYPE_INT32: + case GST_TENSOR_DATA_TYPE_UINT32: + case GST_TENSOR_DATA_TYPE_FLOAT32: + return elements * 4; + + case GST_TENSOR_DATA_TYPE_INT64: + case GST_TENSOR_DATA_TYPE_UINT64: + case GST_TENSOR_DATA_TYPE_FLOAT64: + return elements * 8; + + default: + g_assert_not_reached (); + return 0; + } +} + +/** + * gst_tensor_new_simple: + * @id: semantically identify the contents of the tensor + * @data_type: #GstTensorDataType of tensor data + * @batch_size: Model batch size + * @data: (transfer full): #GstBuffer holding tensor data + * @dims_order: Indicate tensor dimension indexing order + * @num_dims: number of tensor dimensions + * @dims: (array length=num_dims): tensor dimensions + * + * Allocates a new #GstTensor of @dims_order ROW_MAJOR or COLUMN_MAJOR and + * with an interleaved layout + * + * Returns: A newly allocated #GstTensor + * + * Since: 1.26 + */ +GstTensor * +gst_tensor_new_simple (GQuark id, GstTensorDataType data_type, + gsize batch_size, GstBuffer * data, + GstTensorDimOrder dims_order, gsize num_dims, gsize * dims) +{ + GstTensor *tensor; + gsize num_elements = 1; + gsize i; + + /* Update this if adding more to GstTensorDataType */ + g_return_val_if_fail (data_type <= GST_TENSOR_DATA_TYPE_BFLOAT16, NULL); + + g_return_val_if_fail (batch_size > 0, NULL); + g_return_val_if_fail (GST_IS_BUFFER (data), NULL); + g_return_val_if_fail (dims_order == GST_TENSOR_DIM_ORDER_ROW_MAJOR || + dims_order == GST_TENSOR_DIM_ORDER_COL_MAJOR, NULL); + g_return_val_if_fail (num_dims > 0, NULL); + + + for (i = 0; i < num_dims; i++) { + g_return_val_if_fail (dims[i] > 0, NULL); + num_elements *= dims[i]; + } + num_elements *= batch_size; + + if (gst_buffer_get_size (data) != size_for_elements (data_type, num_elements)) { + g_critical ("Expected buffer of size %zu (%zu elements)," + " but buffer has size %zu", + size_for_elements (data_type, num_elements), num_elements, + gst_buffer_get_size (data)); + return NULL; + } + + tensor = gst_tensor_alloc (num_dims); + tensor->id = id; + tensor->layout = GST_TENSOR_LAYOUT_STRIDED; + tensor->data_type = data_type; + tensor->batch_size = batch_size; + tensor->data = data; + tensor->dims_order = dims_order; + tensor->num_dims = num_dims; + for (i = 0; i < num_dims; i++) { + tensor->dims[i].size = dims[i]; + if (dims_order == GST_TENSOR_DIM_ORDER_COL_MAJOR) + tensor->dims[i].order_index = i; + else if (dims_order == GST_TENSOR_DIM_ORDER_ROW_MAJOR) + tensor->dims[i].order_index = num_dims - i - 1; + } + + return tensor; } /** @@ -90,3 +192,23 @@ gst_tensor_copy (const GstTensor * tensor) return copy; } + +/** + * gst_tensor_get_dims: + * @tensor: a #GstTensor + * @num_dims: (out): The number of dimensions + * + * Gets the dimensions of the tensor. + * + * Returns: (array length=num_dims) (transfer none): The dims array form the tensor + * + * Since: 1.26 + */ + +GstTensorDim * +gst_tensor_get_dims (GstTensor * tensor, gsize * num_dims) +{ + if (num_dims) + *num_dims = tensor->num_dims; + return tensor->dims; +} diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h index 40819a2119..2585fbcd8c 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h @@ -153,12 +153,24 @@ G_BEGIN_DECLS GST_ANALYTICS_META_API GstTensor * gst_tensor_alloc (gsize num_dims); +GST_ANALYTICS_META_API +GstTensor * gst_tensor_new_simple (GQuark id, + GstTensorDataType data_type, + gsize batch_size, + GstBuffer * data, + GstTensorDimOrder dims_order, + gsize num_dims, + gsize * dims); + GST_ANALYTICS_META_API void gst_tensor_free (GstTensor * tensor); GST_ANALYTICS_META_API GstTensor * gst_tensor_copy (const GstTensor * tensor); +GST_ANALYTICS_META_API +GstTensorDim * gst_tensor_get_dims (GstTensor * tensor, gsize * num_dims); + GST_ANALYTICS_META_API GType gst_tensor_get_type (void); diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c index a9871a218c..26243a31ab 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensormeta.c @@ -58,6 +58,11 @@ gst_tensor_meta_api_get_type (void) } +/** + * gst_tensor_meta_get_info: (skip) + * + * Since: 1.26 + */ const GstMetaInfo * gst_tensor_meta_get_info (void) {