From 4295386804f3440a05e21f02c7990008ae20c341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Cr=C3=AAte?= Date: Mon, 4 Nov 2024 17:04:28 -0500 Subject: [PATCH] tensors: Use full GstTensorDataType type name in type members Part-of: --- .../ext/onnx/gstonnxclient.cpp | 26 ++++----- .../ext/onnx/gstonnxinference.cpp | 2 +- .../gst-libs/gst/analytics/gsttensor.h | 56 +++++++++---------- .../gst/tensordecoders/gstssdobjectdetector.c | 4 +- 4 files changed, 44 insertions(+), 44 deletions(-) diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index b35765d23b..8536281953 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -53,7 +53,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren dest (nullptr), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC), - inputDatatype (GST_TENSOR_TYPE_UINT8), + inputDatatype (GST_TENSOR_DATA_TYPE_UINT8), inputDatatypeSize (sizeof (uint8_t)), fixedInputImageSize (false), inputTensorOffset (0.0), @@ -100,22 +100,22 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren { inputDatatype = datatype; switch (inputDatatype) { - case GST_TENSOR_TYPE_UINT8: + case GST_TENSOR_DATA_TYPE_UINT8: inputDatatypeSize = sizeof (uint8_t); break; - case GST_TENSOR_TYPE_UINT16: + case GST_TENSOR_DATA_TYPE_UINT16: inputDatatypeSize = sizeof (uint16_t); break; - case GST_TENSOR_TYPE_UINT32: + case GST_TENSOR_DATA_TYPE_UINT32: inputDatatypeSize = sizeof (uint32_t); break; - case GST_TENSOR_TYPE_INT32: + case GST_TENSOR_DATA_TYPE_INT32: inputDatatypeSize = sizeof (int32_t); break; - case GST_TENSOR_TYPE_FLOAT16: + case GST_TENSOR_DATA_TYPE_FLOAT16: inputDatatypeSize = 2; break; - case GST_TENSOR_TYPE_FLOAT32: + case GST_TENSOR_DATA_TYPE_FLOAT32: inputDatatypeSize = sizeof (float); break; default: @@ -247,10 +247,10 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren switch (elementType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - setInputImageDatatype(GST_TENSOR_TYPE_UINT8); + setInputImageDatatype(GST_TENSOR_DATA_TYPE_UINT8); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32); + setInputImageDatatype(GST_TENSOR_DATA_TYPE_FLOAT32); break; default: GST_ERROR_OBJECT (debug_parent, @@ -363,7 +363,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL); gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(), buffer_size); - tensor->data_type = GST_TENSOR_TYPE_FLOAT32; + tensor->data_type = GST_TENSOR_DATA_TYPE_FLOAT32; } else if (tensorType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { size_t buffer_size = 0; @@ -371,7 +371,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren tensor->data = gst_buffer_new_allocate (NULL, buffer_size, NULL); gst_buffer_fill (tensor->data, 0, outputTensor.GetTensorData < float >(), buffer_size); - tensor->data_type = GST_TENSOR_TYPE_INT32; + tensor->data_type = GST_TENSOR_DATA_TYPE_INT32; } else { GST_ERROR_OBJECT (debug_parent, "Output tensor is not FLOAT32 or INT32, not supported"); @@ -459,7 +459,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren std::vector < Ort::Value > inputTensors; switch (inputDatatype) { - case GST_TENSOR_TYPE_UINT8: + case GST_TENSOR_DATA_TYPE_UINT8: uint8_t *src_data; if (inputTensorOffset == 00 && inputTensorScale == 1.0) { src_data = img_data; @@ -474,7 +474,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren memoryInfo, src_data, inputTensorSize, inputDims.data (), inputDims.size ())); break; - case GST_TENSOR_TYPE_FLOAT32: { + case GST_TENSOR_DATA_TYPE_FLOAT32: { convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr, srcSamplesPerPixel, stride, (float)inputTensorOffset, (float) inputTensorScale); diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp index d53a1d8872..3abdb3645e 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp @@ -489,7 +489,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform * onnxClient->getWidth (), "height", G_TYPE_INT, onnxClient->getHeight (), NULL); - if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_UINT8 && + if (onnxClient->getInputImageDatatype() == GST_TENSOR_DATA_TYPE_UINT8 && onnxClient->getInputImageScale() == 1.0 && onnxClient->getInputImageOffset() == 0.0) { switch (onnxClient->getChannels()) { diff --git a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h index 297c70f931..40819a2119 100644 --- a/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h +++ b/subprojects/gst-plugins-bad/gst-libs/gst/analytics/gsttensor.h @@ -32,20 +32,20 @@ /** * GstTensorDataType: - * @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data - * @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data - * @GST_TENSOR_TYPE_INT16 signed 16 bit integer tensor data - * @GST_TENSOR_TYPE_INT32 signed 32 bit integer tensor data - * @GST_TENSOR_TYPE_INT64 signed 64 bit integer tensor data - * @GST_TENSOR_TYPE_UINT4 unsigned 4 bit integer tensor data - * @GST_TENSOR_TYPE_UINT8 unsigned 8 bit integer tensor data - * @GST_TENSOR_TYPE_UINT16 unsigned 16 bit integer tensor data - * @GST_TENSOR_TYPE_UINT32 unsigned 32 bit integer tensor data - * @GST_TENSOR_TYPE_UINT64 unsigned 64 bit integer tensor data - * @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data - * @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data - * @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data - * @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data + * @GST_TENSOR_DATA_TYPE_INT4: signed 4 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_INT8: signed 8 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_INT16: signed 16 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_INT32: signed 32 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_INT64: signed 64 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_UINT4: unsigned 4 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_UINT8: unsigned 8 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_UINT16: unsigned 16 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_UINT32: unsigned 32 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_UINT64: unsigned 64 bit integer tensor data + * @GST_TENSOR_DATA_TYPE_FLOAT16: 16 bit floating point tensor data + * @GST_TENSOR_DATA_TYPE_FLOAT32: 32 bit floating point tensor data + * @GST_TENSOR_DATA_TYPE_FLOAT64: 64 bit floating point tensor data + * @GST_TENSOR_DATA_TYPE_BFLOAT16: "brain" 16 bit floating point tensor data * * Describe the type of data contain in the tensor. * @@ -53,20 +53,20 @@ */ typedef enum _GstTensorDataType { - GST_TENSOR_TYPE_INT4, - GST_TENSOR_TYPE_INT8, - GST_TENSOR_TYPE_INT16, - GST_TENSOR_TYPE_INT32, - GST_TENSOR_TYPE_INT64, - GST_TENSOR_TYPE_UINT4, - GST_TENSOR_TYPE_UINT8, - GST_TENSOR_TYPE_UINT16, - GST_TENSOR_TYPE_UINT32, - GST_TENSOR_TYPE_UINT64, - GST_TENSOR_TYPE_FLOAT16, - GST_TENSOR_TYPE_FLOAT32, - GST_TENSOR_TYPE_FLOAT64, - GST_TENSOR_TYPE_BFLOAT16, + GST_TENSOR_DATA_TYPE_INT4, + GST_TENSOR_DATA_TYPE_INT8, + GST_TENSOR_DATA_TYPE_INT16, + GST_TENSOR_DATA_TYPE_INT32, + GST_TENSOR_DATA_TYPE_INT64, + GST_TENSOR_DATA_TYPE_UINT4, + GST_TENSOR_DATA_TYPE_UINT8, + GST_TENSOR_DATA_TYPE_UINT16, + GST_TENSOR_DATA_TYPE_UINT32, + GST_TENSOR_DATA_TYPE_UINT64, + GST_TENSOR_DATA_TYPE_FLOAT16, + GST_TENSOR_DATA_TYPE_FLOAT32, + GST_TENSOR_DATA_TYPE_FLOAT64, + GST_TENSOR_DATA_TYPE_BFLOAT16 } GstTensorDataType; /** diff --git a/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c b/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c index 8901ea643c..1f69fea77a 100644 --- a/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c +++ b/subprojects/gst-plugins-bad/gst/tensordecoders/gstssdobjectdetector.c @@ -386,14 +386,14 @@ gst_ssd_object_detector_transform_ip (GstBaseTransform * trans, GstBuffer * buf) guint index, TYPE * out) \ { \ switch (tensor->data_type) { \ - case GST_TENSOR_TYPE_FLOAT32: { \ + case GST_TENSOR_DATA_TYPE_FLOAT32: { \ float *f = (float *) map->data; \ if (sizeof(*f) * (index + 1) > map->size) \ return FALSE; \ *out = f[index]; \ break; \ } \ - case GST_TENSOR_TYPE_UINT32: { \ + case GST_TENSOR_DATA_TYPE_UINT32: { \ guint32 *u = (guint32 *) map->data; \ if (sizeof(*u) * (index + 1) > map->size) \ return FALSE; \