tensors: Use full GstTensorDataType type name in type members

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6000>
This commit is contained in:
Olivier Crête 2024-11-04 17:04:28 -05:00 committed by GStreamer Marge Bot
parent e01a3b1d79
commit 4295386804
4 changed files with 44 additions and 44 deletions

View file

@@ -53,7 +53,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
dest (nullptr), dest (nullptr),
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC), inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
inputDatatype (GST_TENSOR_TYPE_UINT8), inputDatatype (GST_TENSOR_DATA_TYPE_UINT8),
inputDatatypeSize (sizeof (uint8_t)), inputDatatypeSize (sizeof (uint8_t)),
fixedInputImageSize (false), fixedInputImageSize (false),
inputTensorOffset (0.0), inputTensorOffset (0.0),
@@ -100,22 +100,22 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
{ {
inputDatatype = datatype; inputDatatype = datatype;
switch (inputDatatype) { switch (inputDatatype) {
case GST_TENSOR_TYPE_UINT8: case GST_TENSOR_DATA_TYPE_UINT8:
inputDatatypeSize = sizeof (uint8_t); inputDatatypeSize = sizeof (uint8_t);
break; break;
case GST_TENSOR_TYPE_UINT16: case GST_TENSOR_DATA_TYPE_UINT16:
inputDatatypeSize = sizeof (uint16_t); inputDatatypeSize = sizeof (uint16_t);
break; break;
case GST_TENSOR_TYPE_UINT32: case GST_TENSOR_DATA_TYPE_UINT32:
inputDatatypeSize = sizeof (uint32_t); inputDatatypeSize = sizeof (uint32_t);
break; break;
case GST_TENSOR_TYPE_INT32: case GST_TENSOR_DATA_TYPE_INT32:
inputDatatypeSize = sizeof (int32_t); inputDatatypeSize = sizeof (int32_t);
break; break;
case GST_TENSOR_TYPE_FLOAT16: case GST_TENSOR_DATA_TYPE_FLOAT16:
inputDatatypeSize = 2; inputDatatypeSize = 2;
break; break;
case GST_TENSOR_TYPE_FLOAT32: case GST_TENSOR_DATA_TYPE_FLOAT32:
inputDatatypeSize = sizeof (float); inputDatatypeSize = sizeof (float);
break; break;
default: default:
@@ -247,10 +247,10 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
switch (elementType) { switch (elementType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
setInputImageDatatype(GST_TENSOR_TYPE_UINT8); setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8);
break; break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32); setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32);
break; break;
default: default:
GST_ERROR_OBJECT (debug_parent, GST_ERROR_OBJECT (debug_parent,
@@ -363,7 +363,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL); tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL);
gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(), gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(),
buffer_size); buffer_size);
tensor->data_type = GST_TENSOR_TYPE_FLOAT32; tensor->data_type = GST_TENSOR_DATA_TYPE_FLOAT32;
} else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { } else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
size_t buffer_size = 0; size_t buffer_size = 0;
@@ -371,7 +371,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL); tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL);
gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(), gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(),
buffer_size); buffer_size);
tensor->data_type = GST_TENSOR_TYPE_INT32; tensor->data_type = GST_TENSOR_DATA_TYPE_INT32;
} else { } else {
GST_ERROR_OBJECT (debug_parent, GST_ERROR_OBJECT (debug_parent,
"Output tensor is not FLOAT32 or INT32, not supported"); "Output tensor is not FLOAT32 or INT32, not supported");
@@ -459,7 +459,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
std::vector < Ort::Value > inputTensors; std::vector < Ort::Value > inputTensors;
switch (inputDatatype) { switch (inputDatatype) {
case GST_TENSOR_TYPE_UINT8: case GST_TENSOR_DATA_TYPE_UINT8:
uint8_t *src_data; uint8_t *src_data;
if (inputTensorOffset == 00 && inputTensorScale == 1.0) { if (inputTensorOffset == 00 && inputTensorScale == 1.0) {
src_data = img_data; src_data = img_data;
@@ -474,7 +474,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
memoryInfo, src_data, inputTensorSize, inputDims.data (), memoryInfo, src_data, inputTensorSize, inputDims.data (),
inputDims.size ())); inputDims.size ()));
break; break;
case GST_TENSOR_TYPE_FLOAT32: { case GST_TENSOR_DATA_TYPE_FLOAT32: {
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr, convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
srcSamplesPerPixel, stride, (float)inputTensorOffset, (float) srcSamplesPerPixel, stride, (float)inputTensorOffset, (float)
inputTensorScale); inputTensorScale);

View file

@@ -489,7 +489,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
onnxClient->getWidth (), "height", G_TYPE_INT, onnxClient->getWidth (), "height", G_TYPE_INT,
onnxClient->getHeight (), NULL); onnxClient->getHeight (), NULL);
if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_UINT8 && if (onnxClient->getInputImageDatatype() == GST_TENSOR_DATA_TYPE_UINT8 &&
onnxClient->getInputImageScale() == 1.0 && onnxClient->getInputImageScale() == 1.0 &&
onnxClient->getInputImageOffset() == 0.0) { onnxClient->getInputImageOffset() == 0.0) {
switch (onnxClient->getChannels()) { switch (onnxClient->getChannels()) {

View file

@@ -32,20 +32,20 @@
/** /**
* GstTensorDataType: * GstTensorDataType:
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data * @GST_TENSOR_DATA_TYPE_INT4: signed 4 bit integer tensor data
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data * @GST_TENSOR_DATA_TYPE_INT8: signed 8 bit integer tensor data
* @GST_TENSOR_TYPE_INT16 signed 16 bit integer tensor data * @GST_TENSOR_DATA_TYPE_INT16: signed 16 bit integer tensor data
* @GST_TENSOR_TYPE_INT32 signed 32 bit integer tensor data * @GST_TENSOR_DATA_TYPE_INT32: signed 32 bit integer tensor data
* @GST_TENSOR_TYPE_INT64 signed 64 bit integer tensor data * @GST_TENSOR_DATA_TYPE_INT64: signed 64 bit integer tensor data
* @GST_TENSOR_TYPE_UINT4 unsigned 4 bit integer tensor data * @GST_TENSOR_DATA_TYPE_UINT4: unsigned 4 bit integer tensor data
* @GST_TENSOR_TYPE_UINT8 unsigned 8 bit integer tensor data * @GST_TENSOR_DATA_TYPE_UINT8: unsigned 8 bit integer tensor data
* @GST_TENSOR_TYPE_UINT16 unsigned 16 bit integer tensor data * @GST_TENSOR_DATA_TYPE_UINT16: unsigned 16 bit integer tensor data
* @GST_TENSOR_TYPE_UINT32 unsigned 32 bit integer tensor data * @GST_TENSOR_DATA_TYPE_UINT32: unsigned 32 bit integer tensor data
* @GST_TENSOR_TYPE_UINT64 unsigned 64 bit integer tensor data * @GST_TENSOR_DATA_TYPE_UINT64: unsigned 64 bit integer tensor data
* @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data * @GST_TENSOR_DATA_TYPE_FLOAT16: 16 bit floating point tensor data
* @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data * @GST_TENSOR_DATA_TYPE_FLOAT32: 32 bit floating point tensor data
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data * @GST_TENSOR_DATA_TYPE_FLOAT64: 64 bit floating point tensor data
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data * @GST_TENSOR_DATA_TYPE_BFLOAT16: "brain" 16 bit floating point tensor data
* *
* Describe the type of data contain in the tensor. * Describe the type of data contain in the tensor.
* *
@@ -53,20 +53,20 @@
*/ */
typedef enum _GstTensorDataType typedef enum _GstTensorDataType
{ {
GST_TENSOR_TYPE_INT4, GST_TENSOR_DATA_TYPE_INT4,
GST_TENSOR_TYPE_INT8, GST_TENSOR_DATA_TYPE_INT8,
GST_TENSOR_TYPE_INT16, GST_TENSOR_DATA_TYPE_INT16,
GST_TENSOR_TYPE_INT32, GST_TENSOR_DATA_TYPE_INT32,
GST_TENSOR_TYPE_INT64, GST_TENSOR_DATA_TYPE_INT64,
GST_TENSOR_TYPE_UINT4, GST_TENSOR_DATA_TYPE_UINT4,
GST_TENSOR_TYPE_UINT8, GST_TENSOR_DATA_TYPE_UINT8,
GST_TENSOR_TYPE_UINT16, GST_TENSOR_DATA_TYPE_UINT16,
GST_TENSOR_TYPE_UINT32, GST_TENSOR_DATA_TYPE_UINT32,
GST_TENSOR_TYPE_UINT64, GST_TENSOR_DATA_TYPE_UINT64,
GST_TENSOR_TYPE_FLOAT16, GST_TENSOR_DATA_TYPE_FLOAT16,
GST_TENSOR_TYPE_FLOAT32, GST_TENSOR_DATA_TYPE_FLOAT32,
GST_TENSOR_TYPE_FLOAT64, GST_TENSOR_DATA_TYPE_FLOAT64,
GST_TENSOR_TYPE_BFLOAT16, GST_TENSOR_DATA_TYPE_BFLOAT16
} GstTensorDataType; } GstTensorDataType;
/** /**

View file

@@ -386,14 +386,14 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf)
guint index, TYPE * out) \ guint index, TYPE * out) \
{ \ { \
switch (tensor->data_type) { \ switch (tensor->data_type) { \
case GST_TENSOR_TYPE_FLOAT32: { \ case GST_TENSOR_DATA_TYPE_FLOAT32: { \
float *f = (float *) map->data; \ float *f = (float *) map->data; \
if (sizeof(*f) * (index + 1) > map->size) \ if (sizeof(*f) * (index + 1) > map->size) \
return FALSE; \ return FALSE; \
*out = f[index]; \ *out = f[index]; \
break; \ break; \
} \ } \
case GST_TENSOR_TYPE_UINT32: { \ case GST_TENSOR_DATA_TYPE_UINT32: { \
guint32 *u = (guint32 *) map->data; \ guint32 *u = (guint32 *) map->data; \
if (sizeof(*u) * (index + 1) > map->size) \ if (sizeof(*u) * (index + 1) > map->size) \
return FALSE; \ return FALSE; \