analytics: Move batch to GstTensor

- batch_size is required to interpret the tensor. Depending on the tensor
  format, the batches are not necessarily memory planes, so it is preferable
  to keep batch_size inside GstTensor.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6000>
commit 7c925eae61 (parent 43c7e524ce)
Author: Daniel Morin, 2024-07-17 14:39:42 -04:00
Committed by: GStreamer Marge Bot
4 changed files with 38 additions and 5 deletions
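
For context, a minimal consumer-side sketch of why the move matters: with
batch_size stored on each GstTensor, code that maps a tensor buffer can take
the batch into account without consulting the meta. The include path and the
logging below are assumptions for illustration, not code from this change.

#include <gst/gst.h>
#include <gst/analytics/analytics.h>    /* assumed include path */

static void
inspect_tensor (const GstTensor * tensor)
{
  GstMapInfo map;
  gsize n_elems = 1;

  for (gsize i = 0; i < tensor->num_dims; i++)
    n_elems *= tensor->dims[i];

  if (!gst_buffer_map (tensor->data, &map, GST_MAP_READ))
    return;

  /* batch_size now travels with the tensor itself, so this works whether or
   * not the batches form separate memory planes for a given tensor format. */
  GST_INFO ("tensor %s: %" G_GSIZE_FORMAT " dims, batch_size %" G_GSIZE_FORMAT
      ", %" G_GSIZE_FORMAT " elements in %" G_GSIZE_FORMAT " bytes",
      g_quark_to_string (tensor->id), tensor->num_dims, tensor->batch_size,
      n_elems, map.size);

  gst_buffer_unmap (tensor->data, &map);
}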

@@ -348,7 +348,8 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent)
       tensor->id = 0;
       auto tensorShape = outputTensor.GetTensorTypeAndShapeInfo ().GetShape ();
       tensor->num_dims = tensorShape.size ();
-      tensor->dims = g_new (int64_t, tensor->num_dims);
+      tensor->dims = g_new (gsize, tensor->num_dims);
+      tensor->batch_size = 1;
       for (size_t j = 0; j < tensorShape.size (); ++j)
         tensor->dims[j] = tensorShape[j];

@@ -587,7 +587,6 @@ gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
     if (!meta)
       return FALSE;
     GST_TRACE_OBJECT (trans, "Num tensors:%zu", meta->num_tensors);
-    meta->batch_size = 1;
   }
   catch (Ort::Exception & ortex) {
     GST_ERROR_OBJECT (self, "%s", ortex.what ());
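
Mirroring the two ONNX hunks above, a producer filling a GstTensor now sets
batch_size on the tensor itself rather than once on the meta. A hypothetical
helper sketch (not GStreamer API); only the field assignments reflect this
change:

static void
fill_tensor (GstTensor * tensor, GQuark id, const gsize * shape,
    gsize n_dims, GstTensorDataType data_type, GstBuffer * data)
{
  tensor->id = id;
  tensor->num_dims = n_dims;
  tensor->dims = g_new (gsize, n_dims);
  for (gsize i = 0; i < n_dims; i++)
    tensor->dims[i] = shape[i];

  tensor->data_type = data_type;
  tensor->batch_size = 1;       /* per tensor now, not meta->batch_size */
  tensor->data = data;
}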

@@ -69,12 +69,43 @@ typedef enum _GstTensorDataType
   GST_TENSOR_TYPE_BFLOAT16,
 } GstTensorDataType;
 
+/**
+ * GstTensorDimOrder:
+ * @GST_TENSOR_DIM_ORDER_ROW_MAJOR: elements along a row are consecutive in memory
+ * @GST_TENSOR_DIM_ORDER_COL_MAJOR: elements along a column are consecutive in memory
+ *
+ * Indicate to read tensor from memory in row-major or column-major.
+ *
+ * Since: 1.26
+ */
+typedef enum _GstTensorDimOrder
+{
+  GST_TENSOR_DIM_ORDER_ROW_MAJOR,
+  GST_TENSOR_DIM_ORDER_COL_MAJOR
+} GstTensorDimOrder;
+
+/**
+ * GstTensorLayout:
+ * @GST_TENSOR_LAYOUT_STRIDED: indicate the tensor is stored in a dense format in memory
+ *
+ * Indicate tensor storage in memory.
+ *
+ * Since: 1.26
+ */
+typedef enum _GstTensorLayout
+{
+  GST_TENSOR_LAYOUT_STRIDED
+} GstTensorLayout;
+
 /**
  * GstTensor:
  * @id: semantically identify the contents of the tensor
  * @num_dims: number of tensor dimensions
  * @dims: tensor dimensions
+ * @dims_order: Indicate tensor elements layout in memory.
+ * @layout: Indicate tensor layout
  * @type: #GstTensorDataType of tensor data
+ * @batch_size: Model batch size
  * @data: #GstBuffer holding tensor data
  *
  * Hold tensor data
@@ -84,9 +115,12 @@ typedef enum _GstTensorDataType
 typedef struct _GstTensor
 {
   GQuark id;
-  gint num_dims;
-  int64_t *dims;
+  gsize num_dims;
+  gsize *dims;
+  GstTensorDimOrder dims_order;
+  GstTensorLayout layout;
   GstTensorDataType data_type;
+  gsize batch_size;
   GstBuffer *data;
 } GstTensor;
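
The new dims_order field determines how a flat offset relates to a coordinate
vector. A minimal sketch, assuming only the row-major/column-major semantics
documented in the enum above (the helper name is illustrative):

static gsize
tensor_offset (const GstTensor * tensor, const gsize * coord)
{
  gsize offset = 0;

  if (tensor->dims_order == GST_TENSOR_DIM_ORDER_ROW_MAJOR) {
    /* Last dimension is contiguous in memory. */
    for (gsize i = 0; i < tensor->num_dims; i++)
      offset = offset * tensor->dims[i] + coord[i];
  } else {
    /* GST_TENSOR_DIM_ORDER_COL_MAJOR: first dimension is contiguous. */
    for (gsize i = tensor->num_dims; i > 0; i--)
      offset = offset * tensor->dims[i - 1] + coord[i - 1];
  }
  return offset;
}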

@@ -46,7 +46,6 @@ typedef struct _GstTensorMeta
   gsize num_tensors;
   GstTensor *tensor;
-  gsize batch_size;
 } GstTensorMeta;
 
 G_BEGIN_DECLS
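
With batch_size gone from GstTensorMeta, consumers read it per tensor, and in
principle tensors attached to the same meta can carry different batch sizes.
A hedged iteration sketch, assuming the usual meta lookup boilerplate
(gst_tensor_meta_api_get_type) is available as in other GStreamer metas:

static void
dump_batch_sizes (GstBuffer * buf)
{
  GstTensorMeta *meta = (GstTensorMeta *)
      gst_buffer_get_meta (buf, gst_tensor_meta_api_get_type ());

  if (meta == NULL)
    return;

  for (gsize i = 0; i < meta->num_tensors; i++) {
    const GstTensor *tensor = &meta->tensor[i];

    /* Previously a single meta->batch_size; now read from each tensor. */
    GST_INFO ("tensor %" G_GSIZE_FORMAT " (%s): batch_size %" G_GSIZE_FORMAT,
        i, g_quark_to_string (tensor->id), tensor->batch_size);
  }
}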