From 54b361c55409defca86f20a129e917a7cd65c2ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Cr=C3=AAte?= Date: Thu, 4 Jan 2024 14:56:41 -0500 Subject: [PATCH] onnx: Extract data type from the model itself Part-of: --- .../ext/onnx/gstonnxclient.cpp | 31 ++++++++--- .../gst-plugins-bad/ext/onnx/gstonnxclient.h | 2 +- .../ext/onnx/gstonnxinference.cpp | 53 +------------------ 3 files changed, 26 insertions(+), 60 deletions(-) diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index f3459dbe06..8526415e12 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -227,6 +227,21 @@ GstOnnxClient::GstOnnxClient ():session (nullptr), GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); + ONNXTensorElementDataType elementType = + inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType (); + + switch (elementType) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + setInputImageDatatype(GST_TENSOR_TYPE_INT8); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32); + break; + default: + GST_ERROR ("Only input tensors of type int8 and float are supported"); + return false; + } + Ort::AllocatorWithDefaultOptions allocator; auto input_name = session->GetInputNameAllocated (0, allocator); GST_DEBUG ("Input name: %s", input_name.get ()); @@ -252,20 +267,20 @@ GstOnnxClient::GstOnnxClient ():session (nullptr), Ort::GetApi ().ReleaseStatus (status); return false; - } else { + } for (auto & name:outputNamesRaw) { - Ort::AllocatedStringPtr res = - metaData.LookupCustomMetadataMapAllocated (name, ortAllocator); - if (res) { - GQuark quark = g_quark_from_string (res.get ()); - outputIds.push_back (quark); - } else { + Ort::AllocatedStringPtr res = + metaData.LookupCustomMetadataMapAllocated (name, ortAllocator); + if (res) + { + GQuark quark = g_quark_from_string (res.get ()); + outputIds.push_back (quark); + } else { GST_ERROR ("Failed to look up id for key %s", name); return false; } } - } } catch (Ort::Exception & ortex) { GST_ERROR ("%s", ortex.what ()); diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h index f0a9b8bb89..d6d2390210 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h @@ -55,7 +55,6 @@ namespace GstOnnxNamespace { bool hasSession(void); void setInputImageFormat(GstMlInputImageFormat format); GstMlInputImageFormat getInputImageFormat(void); - void setInputImageDatatype(GstTensorType datatype); GstTensorType getInputImageDatatype(void); void setInputImageOffset (float offset); float getInputImageOffset (); @@ -70,6 +69,7 @@ namespace GstOnnxNamespace { void parseDimensions(GstVideoInfo vinfo); private: + void setInputImageDatatype(GstTensorType datatype); template < typename T> void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div); diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp index 349b3108f0..c182b1da22 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp @@ -36,9 +36,9 @@ * https://gitlab.collabora.com/gstreamer/onnx-models * * GST_DEBUG=ssdobjectdetector:5 \ - * gst-launch-1.0 multifilesrc location=onnx-models/images/bus.jpg ! \ + * gst-launch-1.0 filesrc location=onnx-models/images/bus.jpg ! \ * jpegdec ! videoconvert ! onnxinference execution-provider=cpu model-file=onnx-models/models/ssd_mobilenet_v1_coco.onnx ! \ - * ssdobjectdetector label-file=onnx-models/labels/COCO_classes.txt ! videoconvert ! autovideosink + * ssdobjectdetector label-file=onnx-models/labels/COCO_classes.txt ! videoconvert ! imagefreeze ! autovideosink * * * Note: in order for downstream tensor decoders to correctly parse the tensor @@ -100,7 +100,6 @@ enum PROP_INPUT_IMAGE_FORMAT, PROP_OPTIMIZATION_LEVEL, PROP_EXECUTION_PROVIDER, - PROP_INPUT_IMAGE_DATATYPE, PROP_INPUT_OFFSET, PROP_INPUT_SCALE }; @@ -149,9 +148,6 @@ GType gst_onnx_execution_provider_get_type (void); GType gst_ml_model_input_image_format_get_type (void); #define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ()) -GType gst_onnx_model_input_image_datatype_get_type (void); -#define GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE (gst_onnx_model_input_image_datatype_get_type ()) - GType gst_onnx_optimization_level_get_type (void) { @@ -230,28 +226,6 @@ gst_ml_model_input_image_format_get_type (void) return ml_model_input_image_format; } -GType -gst_onnx_model_input_image_datatype_get_type (void) -{ - static GType model_input_image_datatype = 0; - - if (g_once_init_enter (&model_input_image_datatype)) { - static GEnumValue model_input_image_datatype_types[] = { - {GST_TENSOR_TYPE_INT8, "8 Bits integer", "int8"}, - {GST_TENSOR_TYPE_FLOAT32, "32 Bits floating points", "float"}, - {0, NULL, NULL}, - }; - - GType temp = g_enum_register_static ("GstTensorType", - model_input_image_datatype_types); - - g_once_init_leave (&model_input_image_datatype, temp); - } - - return model_input_image_datatype; - -} - static void gst_onnx_inference_class_init (GstOnnxInferenceClass * klass) { @@ -325,22 +299,6 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass) GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - /** - * GstOnnxInference:input-image-datatype - * - * Temporary hack, this should be discovered from the model and exposed - * on sinkpad caps based on model contrains. - */ - - g_object_class_install_property (G_OBJECT_CLASS (klass), - PROP_INPUT_IMAGE_DATATYPE, - g_param_spec_enum ("input-image-datatype", - "Inference input image datatype", - "Datatype that will be used as an input for the inference", - GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE, - GST_TENSOR_TYPE_INT8, - (GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); - g_object_class_install_property (G_OBJECT_CLASS (klass), PROP_INPUT_OFFSET, g_param_spec_float ("input-tensor-offset", @@ -425,10 +383,6 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id, onnxClient->setInputImageFormat ((GstMlInputImageFormat) g_value_get_enum (value)); break; - case PROP_INPUT_IMAGE_DATATYPE: - onnxClient->setInputImageDatatype ((GstTensorType) - g_value_get_enum (value)); - break; case PROP_INPUT_OFFSET: onnxClient->setInputImageOffset (g_value_get_float (value)); break; @@ -461,9 +415,6 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id, case PROP_INPUT_IMAGE_FORMAT: g_value_set_enum (value, onnxClient->getInputImageFormat ()); break; - case PROP_INPUT_IMAGE_DATATYPE: - g_value_set_enum (value, onnxClient->getInputImageDatatype ()); - break; case PROP_INPUT_OFFSET: g_value_set_float (value, onnxClient->getInputImageOffset ()); break;