mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-02-02 12:32:29 +00:00
onnx: Add more tensor data types
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/6001>
This commit is contained in:
parent
e3d8168a5a
commit
13de5160be
3 changed files with 35 additions and 11 deletions
|
@ -53,7 +53,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
dest (nullptr),
|
||||
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
|
||||
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
||||
inputDatatype (GST_TENSOR_TYPE_INT8),
|
||||
inputDatatype (GST_TENSOR_TYPE_UINT8),
|
||||
inputDatatypeSize (sizeof (uint8_t)),
|
||||
fixedInputImageSize (false),
|
||||
inputTensorOffset (0.0),
|
||||
|
@ -100,21 +100,27 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
{
|
||||
inputDatatype = datatype;
|
||||
switch (inputDatatype) {
|
||||
case GST_TENSOR_TYPE_INT8:
|
||||
case GST_TENSOR_TYPE_UINT8:
|
||||
inputDatatypeSize = sizeof (uint8_t);
|
||||
break;
|
||||
case GST_TENSOR_TYPE_INT16:
|
||||
case GST_TENSOR_TYPE_UINT16:
|
||||
inputDatatypeSize = sizeof (uint16_t);
|
||||
break;
|
||||
case GST_TENSOR_TYPE_INT32:
|
||||
case GST_TENSOR_TYPE_UINT32:
|
||||
inputDatatypeSize = sizeof (uint32_t);
|
||||
break;
|
||||
case GST_TENSOR_TYPE_INT32:
|
||||
inputDatatypeSize = sizeof (int32_t);
|
||||
break;
|
||||
case GST_TENSOR_TYPE_FLOAT16:
|
||||
inputDatatypeSize = 2;
|
||||
break;
|
||||
case GST_TENSOR_TYPE_FLOAT32:
|
||||
inputDatatypeSize = sizeof (float);
|
||||
break;
|
||||
default:
|
||||
g_error ("Data type %d not handled", inputDatatype);
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -241,7 +247,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
|
||||
switch (elementType) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
setInputImageDatatype(GST_TENSOR_TYPE_INT8);
|
||||
setInputImageDatatype(GST_TENSOR_TYPE_UINT8);
|
||||
break;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
|
||||
|
@ -450,7 +456,7 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
std::vector < Ort::Value > inputTensors;
|
||||
|
||||
switch (inputDatatype) {
|
||||
case GST_TENSOR_TYPE_INT8:
|
||||
case GST_TENSOR_TYPE_UINT8:
|
||||
uint8_t *src_data;
|
||||
if (inputTensorOffset == 00 && inputTensorScale == 1.0) {
|
||||
src_data = img_data;
|
||||
|
|
|
@ -489,7 +489,7 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
|||
onnxClient->getWidth (), "height", G_TYPE_INT,
|
||||
onnxClient->getHeight (), NULL);
|
||||
|
||||
if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_INT8 &&
|
||||
if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_UINT8 &&
|
||||
onnxClient->getInputImageScale() == 1.0 &&
|
||||
onnxClient->getInputImageOffset() == 0.0) {
|
||||
switch (onnxClient->getChannels()) {
|
||||
|
|
|
@ -27,21 +27,39 @@
|
|||
/**
|
||||
* GstTensorType:
|
||||
*
|
||||
* @GST_TENSOR_TYPE_INT8 8 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT16 16 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT32 32 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT4 signed 4 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT8 signed 8 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT16 signed 16 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT32 signed 32 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_INT64 signed 64 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_UINT4 unsigned 4 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_UINT8 unsigned 8 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_UINT16 unsigned 16 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_UINT32 unsigned 32 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_UINT64 unsigned 64 bit integer tensor data
|
||||
* @GST_TENSOR_TYPE_FLOAT16 16 bit floating point tensor data
|
||||
* @GST_TENSOR_TYPE_FLOAT32 32 bit floating point tensor data
|
||||
* @GST_TENSOR_TYPE_FLOAT64 64 bit floating point tensor data
|
||||
* @GST_TENSOR_TYPE_BFLOAT16 "brain" 16 bit floating point tensor data
|
||||
*
|
||||
* Since: 1.24
|
||||
*/
|
||||
typedef enum _GstTensorType
|
||||
{
|
||||
GST_TENSOR_TYPE_INT4,
|
||||
GST_TENSOR_TYPE_INT8,
|
||||
GST_TENSOR_TYPE_INT16,
|
||||
GST_TENSOR_TYPE_INT32,
|
||||
GST_TENSOR_TYPE_INT64,
|
||||
GST_TENSOR_TYPE_UINT4,
|
||||
GST_TENSOR_TYPE_UINT8,
|
||||
GST_TENSOR_TYPE_UINT16,
|
||||
GST_TENSOR_TYPE_UINT32,
|
||||
GST_TENSOR_TYPE_UINT64,
|
||||
GST_TENSOR_TYPE_FLOAT16,
|
||||
GST_TENSOR_TYPE_FLOAT32
|
||||
GST_TENSOR_TYPE_FLOAT32,
|
||||
GST_TENSOR_TYPE_FLOAT64,
|
||||
GST_TENSOR_TYPE_BFLOAT16,
|
||||
} GstTensorType;
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue