mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2024-11-22 01:31:03 +00:00
tensor: Add APIs to create and access GstTensor contents
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6000>
This commit is contained in:
parent
4295386804
commit
5e73e8e1b3
3 changed files with 139 additions and 0 deletions
|
@ -49,6 +49,108 @@ gst_tensor_alloc (gsize num_dims)
|
|||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
static gsize
|
||||
size_for_elements (GstTensorDataType data_type, gsize elements)
|
||||
{
|
||||
switch (data_type) {
|
||||
case GST_TENSOR_DATA_TYPE_INT4:
|
||||
case GST_TENSOR_DATA_TYPE_UINT4:
|
||||
return (elements / 2) + (elements % 2);
|
||||
|
||||
case GST_TENSOR_DATA_TYPE_INT8:
|
||||
case GST_TENSOR_DATA_TYPE_UINT8:
|
||||
return elements;
|
||||
|
||||
case GST_TENSOR_DATA_TYPE_INT16:
|
||||
case GST_TENSOR_DATA_TYPE_UINT16:
|
||||
case GST_TENSOR_DATA_TYPE_FLOAT16:
|
||||
case GST_TENSOR_DATA_TYPE_BFLOAT16:
|
||||
return elements * 2;
|
||||
|
||||
case GST_TENSOR_DATA_TYPE_INT32:
|
||||
case GST_TENSOR_DATA_TYPE_UINT32:
|
||||
case GST_TENSOR_DATA_TYPE_FLOAT32:
|
||||
return elements * 4;
|
||||
|
||||
case GST_TENSOR_DATA_TYPE_INT64:
|
||||
case GST_TENSOR_DATA_TYPE_UINT64:
|
||||
case GST_TENSOR_DATA_TYPE_FLOAT64:
|
||||
return elements * 8;
|
||||
|
||||
default:
|
||||
g_assert_not_reached ();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* gst_tensor_new_simple:
|
||||
* @id: semantically identify the contents of the tensor
|
||||
* @data_type: #GstTensorDataType of tensor data
|
||||
* @batch_size: Model batch size
|
||||
* @data: (transfer full): #GstBuffer holding tensor data
|
||||
* @dims_order: Indicate tensor dimension indexing order
|
||||
* @num_dims: number of tensor dimensions
|
||||
* @dims: (array length=num_dims): tensor dimensions
|
||||
*
|
||||
* Allocates a new #GstTensor of @dims_order ROW_MAJOR or COLUMN_MAJOR and
|
||||
* with an interleaved layout
|
||||
*
|
||||
* Returns: A newly allocated #GstTensor
|
||||
*
|
||||
* Since: 1.26
|
||||
*/
|
||||
GstTensor *
|
||||
gst_tensor_new_simple (GQuark id, GstTensorDataType data_type,
|
||||
gsize batch_size, GstBuffer * data,
|
||||
GstTensorDimOrder dims_order, gsize num_dims, gsize * dims)
|
||||
{
|
||||
GstTensor *tensor;
|
||||
gsize num_elements = 1;
|
||||
gsize i;
|
||||
|
||||
/* Update this if adding more to GstTensorDataType */
|
||||
g_return_val_if_fail (data_type <= GST_TENSOR_DATA_TYPE_BFLOAT16, NULL);
|
||||
|
||||
g_return_val_if_fail (batch_size > 0, NULL);
|
||||
g_return_val_if_fail (GST_IS_BUFFER (data), NULL);
|
||||
g_return_val_if_fail (dims_order == GST_TENSOR_DIM_ORDER_ROW_MAJOR ||
|
||||
dims_order == GST_TENSOR_DIM_ORDER_COL_MAJOR, NULL);
|
||||
g_return_val_if_fail (num_dims > 0, NULL);
|
||||
|
||||
|
||||
for (i = 0; i < num_dims; i++) {
|
||||
g_return_val_if_fail (dims[i] > 0, NULL);
|
||||
num_elements *= dims[i];
|
||||
}
|
||||
num_elements *= batch_size;
|
||||
|
||||
if (gst_buffer_get_size (data) != size_for_elements (data_type, num_elements)) {
|
||||
g_critical ("Expected buffer of size %zu (%zu elements),"
|
||||
" but buffer has size %zu",
|
||||
size_for_elements (data_type, num_elements), num_elements,
|
||||
gst_buffer_get_size (data));
|
||||
return NULL;
|
||||
}
|
||||
|
||||
tensor = gst_tensor_alloc (num_dims);
|
||||
tensor->id = id;
|
||||
tensor->layout = GST_TENSOR_LAYOUT_STRIDED;
|
||||
tensor->data_type = data_type;
|
||||
tensor->batch_size = batch_size;
|
||||
tensor->data = data;
|
||||
tensor->dims_order = dims_order;
|
||||
tensor->num_dims = num_dims;
|
||||
for (i = 0; i < num_dims; i++) {
|
||||
tensor->dims[i].size = dims[i];
|
||||
if (dims_order == GST_TENSOR_DIM_ORDER_COL_MAJOR)
|
||||
tensor->dims[i].order_index = i;
|
||||
else if (dims_order == GST_TENSOR_DIM_ORDER_ROW_MAJOR)
|
||||
tensor->dims[i].order_index = num_dims - i - 1;
|
||||
}
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -90,3 +192,23 @@ gst_tensor_copy (const GstTensor * tensor)
|
|||
|
||||
return copy;
|
||||
}
|
||||
|
||||
/**
|
||||
* gst_tensor_get_dims:
|
||||
* @tensor: a #GstTensor
|
||||
* @num_dims: (out): The number of dimensions
|
||||
*
|
||||
* Gets the dimensions of the tensor.
|
||||
*
|
||||
* Returns: (array length=num_dims) (transfer none): The dims array form the tensor
|
||||
*
|
||||
* Since: 1.26
|
||||
*/
|
||||
|
||||
GstTensorDim *
|
||||
gst_tensor_get_dims (GstTensor * tensor, gsize * num_dims)
|
||||
{
|
||||
if (num_dims)
|
||||
*num_dims = tensor->num_dims;
|
||||
return tensor->dims;
|
||||
}
|
||||
|
|
|
@ -153,12 +153,24 @@ G_BEGIN_DECLS
|
|||
GST_ANALYTICS_META_API
|
||||
GstTensor * gst_tensor_alloc (gsize num_dims);
|
||||
|
||||
GST_ANALYTICS_META_API
|
||||
GstTensor * gst_tensor_new_simple (GQuark id,
|
||||
GstTensorDataType data_type,
|
||||
gsize batch_size,
|
||||
GstBuffer * data,
|
||||
GstTensorDimOrder dims_order,
|
||||
gsize num_dims,
|
||||
gsize * dims);
|
||||
|
||||
GST_ANALYTICS_META_API
|
||||
void gst_tensor_free (GstTensor * tensor);
|
||||
|
||||
GST_ANALYTICS_META_API
|
||||
GstTensor * gst_tensor_copy (const GstTensor * tensor);
|
||||
|
||||
GST_ANALYTICS_META_API
|
||||
GstTensorDim * gst_tensor_get_dims (GstTensor * tensor, gsize * num_dims);
|
||||
|
||||
GST_ANALYTICS_META_API
|
||||
GType gst_tensor_get_type (void);
|
||||
|
||||
|
|
|
@ -58,6 +58,11 @@ gst_tensor_meta_api_get_type (void)
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* gst_tensor_meta_get_info: (skip)
|
||||
*
|
||||
* Since: 1.26
|
||||
*/
|
||||
const GstMetaInfo *
|
||||
gst_tensor_meta_get_info (void)
|
||||
{
|
||||
|
|
Loading…
Reference in a new issue