diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
index f9111a976a..308c65eff8 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp
@@ -50,6 +50,8 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
     dest (nullptr),
     m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
     inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
+    inputDatatype (GST_TENSOR_TYPE_INT8),
+    inputDatatypeSize (sizeof (uint8_t)),
     fixedInputImageSize (false)
 {
 }
@@ -83,6 +85,55 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
     return inputImageFormat;
   }
 
+  /**
+   * Select the element type of the model input tensor and cache the size in
+   * bytes of one element. A type without a known size is rejected and the
+   * previous selection is kept, so inputDatatype and inputDatatypeSize
+   * always describe the same type.
+   */
+  void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
+  {
+    size_t datatypeSize;
+
+    switch (datatype) {
+      case GST_TENSOR_TYPE_INT8:
+        datatypeSize = sizeof (uint8_t);
+        break;
+      case GST_TENSOR_TYPE_INT16:
+        datatypeSize = sizeof (uint16_t);
+        break;
+      case GST_TENSOR_TYPE_INT32:
+        datatypeSize = sizeof (uint32_t);
+        break;
+      case GST_TENSOR_TYPE_FLOAT16:
+        /* No standard half-precision type in C++17 to take sizeof () of */
+        datatypeSize = 2;
+        break;
+      case GST_TENSOR_TYPE_FLOAT32:
+        datatypeSize = sizeof (float);
+        break;
+      default:
+        GST_ERROR ("Unsupported input image datatype %d", (int) datatype);
+        return;
+    }
+
+    /* dest is sized with the element size in use when it was allocated, so
+     * grow it before a wider type is written into it */
+    if (dest && datatypeSize > inputDatatypeSize) {
+      delete[]dest;
+      dest = new uint8_t[datatypeSize * width * height * channels];
+    }
+
+    inputDatatype = datatype;
+    inputDatatypeSize = datatypeSize;
+  }
+
+  /** Return the element type currently selected for the model input tensor. */
+  GstTensorType GstOnnxClient::getInputImageDatatype(void)
+  {
+    return inputDatatype;
+  }
+
   std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void)
   {
     if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) {
@@ -229,9 +280,13 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
     int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
     int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
 
+    if (!fixedInputImageSize) {
+      GST_WARNING ("Allocating before knowing model input size");
+    }
+
     if (!dest || width * height < newWidth * newHeight) {
       delete[]dest;
-      dest = new uint8_t[newWidth * newHeight * channels];
+      dest = new uint8_t[newWidth * newHeight * channels * inputDatatypeSize];
     }
     width = newWidth;
     height = newHeight;
@@ -378,13 +433,65 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
       default:
         break;
     }
-    size_t destIndex = 0;
     uint32_t stride = vinfo.stride[0];
+    /* CreateTensor () expects a number of elements, not a size in bytes */
+    const size_t inputElementCount =
+        static_cast < size_t > (width) * height * channels;
+    auto memoryInfo =
+        Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
+        OrtMemType::OrtMemTypeDefault);
+
+    std::vector < Ort::Value > inputTensors;
+
+    switch (inputDatatype) {
+      case GST_TENSOR_TYPE_INT8:
+        convert_image_remove_alpha (dest, inputImageFormat, srcPtr,
+            srcSamplesPerPixel, stride);
+        inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
+              memoryInfo, dest, inputElementCount, inputDims.data (),
+              inputDims.size ()));
+        break;
+      case GST_TENSOR_TYPE_FLOAT32:
+        convert_image_remove_alpha ((float*)dest, inputImageFormat, srcPtr,
+            srcSamplesPerPixel, stride);
+        inputTensors.push_back (Ort::Value::CreateTensor < float > (
+              memoryInfo, (float*)dest, inputElementCount, inputDims.data (),
+              inputDims.size ()));
+        break;
+      default:
+        /* Nothing was converted, so there is no input tensor to run with */
+        GST_ERROR ("Unsupported input image datatype %d", (int) inputDatatype);
+        return false;
+    }
+
+    std::vector < const char *>inputNames { inputName.get () };
+    modelOutput = session->Run (Ort::RunOptions {nullptr},
+        inputNames.data (),
+        inputTensors.data (), 1, outputNamesRaw.data (),
+        outputNamesRaw.size ());
+
+    return true;
+  }
+
+  /**
+   * Copy one sample per channel for every pixel, read through srcPtr[k],
+   * into dst, converting each 8-bit sample to T. When hwc is
+   * GST_ML_INPUT_IMAGE_FORMAT_HWC the samples of a pixel are stored next to
+   * each other, otherwise every channel is written to its own plane of
+   * width * height elements. The pointers in srcPtr are advanced as the
+   * image is read.
+   */
+  template < typename T>
+  void GstOnnxClient::convert_image_remove_alpha (T *dst,
+      GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
+      uint32_t stride) {
+    size_t destIndex = 0;
+
-    if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
+    if (hwc == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
       for (int32_t j = 0; j < height; ++j) {
         for (int32_t i = 0; i < width; ++i) {
           for (int32_t k = 0; k < channels; ++k) {
-            dest[destIndex++] = *srcPtr[k];
+            dst[destIndex++] = (T)*srcPtr[k];
             srcPtr[k] += srcSamplesPerPixel;
           }
         }
@@ -394,11 +501,11 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
       }
     } else {
       size_t frameSize = width * height;
-      uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
+      T *destPtr[3] = { dst, dst + frameSize, dst + 2 * frameSize };
       for (int32_t j = 0; j < height; ++j) {
         for (int32_t i = 0; i < width; ++i) {
           for (int32_t k = 0; k < channels; ++k) {
-            destPtr[k][destIndex] = *srcPtr[k];
+            destPtr[k][destIndex] = (T)*srcPtr[k];
             srcPtr[k] += srcSamplesPerPixel;
           }
           destIndex++;
@@ -408,25 +515,5 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
           srcPtr[k] += stride - srcSamplesPerPixel * width;
       }
     }
-
-    const size_t inputTensorSize = width * height * channels;
-    auto memoryInfo =
-        Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
-        OrtMemType::OrtMemTypeDefault);
-    std::vector < Ort::Value > inputTensors;
-    inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
-            dest, inputTensorSize, inputDims.data (), inputDims.size ()));
-    std::vector < const char *>inputNames
-    {
-    inputName.get ()};
-
-    modelOutput = session->Run (Ort::RunOptions {
-          nullptr},
-        inputNames.data (),
-        inputTensors.data (), 1, outputNamesRaw.data (),
-        outputNamesRaw.size ());
-
-    return true;
   }
-
 }
diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
index feed54a826..3721b8968a 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h
@@ -55,6 +55,8 @@ namespace GstOnnxNamespace {
     bool hasSession(void);
     void setInputImageFormat(GstMlInputImageFormat format);
     GstMlInputImageFormat getInputImageFormat(void);
+    void setInputImageDatatype(GstTensorType datatype);
+    GstTensorType getInputImageDatatype(void);
     std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
     std::vector < const char *> genOutputNamesRaw(void);
     bool isFixedInputImageSize(void);
@@ -63,6 +65,10 @@ namespace GstOnnxNamespace {
     GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
     void parseDimensions(GstVideoInfo vinfo);
   private:
+
+    template < typename T>
+    void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
+        uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride);
     bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
     Ort::Env env;
     Ort::Session * session;
@@ -77,6 +83,8 @@ namespace GstOnnxNamespace {
     std::vector < Ort::AllocatedStringPtr > outputNames;
     std::vector < GQuark > outputIds;
     GstMlInputImageFormat inputImageFormat;
+    GstTensorType inputDatatype;
+    size_t inputDatatypeSize;
     bool fixedInputImageSize;
   };
 }
diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp
index 54363a72ca..3256ea31bd 100644
--- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp
+++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp
@@ -122,7 +122,8 @@ enum
   PROP_MODEL_FILE,
   PROP_INPUT_IMAGE_FORMAT,
   PROP_OPTIMIZATION_LEVEL,
-  PROP_EXECUTION_PROVIDER
+  PROP_EXECUTION_PROVIDER,
+  PROP_INPUT_IMAGE_DATATYPE
 };
 
 #define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
@@ -169,6 +170,9 @@ GType gst_onnx_execution_provider_get_type (void);
 GType gst_ml_model_input_image_format_get_type (void);
 #define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
 
+GType gst_onnx_model_input_image_datatype_get_type (void);
+#define GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE (gst_onnx_model_input_image_datatype_get_type ())
+
 GType
 gst_onnx_optimization_level_get_type (void)
 {
@@ -247,6 +251,31 @@ gst_ml_model_input_image_format_get_type (void)
   return ml_model_input_image_format;
 }
 
+/* Enum GType behind the "input-image-datatype" property; only the tensor
+ * element types listed in the table below can be selected.
+ * NOTE(review): the GType is registered as "GstTensorType"; this would clash
+ * if the complete GstTensorType enum ever gets a GType of its own - confirm. */
+GType
+gst_onnx_model_input_image_datatype_get_type (void)
+{
+  static GType model_input_image_datatype = 0;
+
+  if (g_once_init_enter (&model_input_image_datatype)) {
+    static GEnumValue model_input_image_datatype_types[] = {
+      {GST_TENSOR_TYPE_INT8, "8 Bits integer", "int8"},
+      {GST_TENSOR_TYPE_FLOAT32, "32 Bits floating points", "float"},
+      {0, NULL, NULL},
+    };
+
+    GType temp = g_enum_register_static ("GstTensorType",
+        model_input_image_datatype_types);
+
+    g_once_init_leave (&model_input_image_datatype, temp);
+  }
+
+  return model_input_image_datatype;
+}
+
 static void
 gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
 {
@@ -320,6 +349,22 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
           GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
           (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
 
+  /**
+   * GstOnnxInference:input-image-datatype
+   *
+   * Temporary hack, this should be discovered from the model and exposed
+   * on sinkpad caps based on model constraints.
+   */
+
+  g_object_class_install_property (G_OBJECT_CLASS (klass),
+      PROP_INPUT_IMAGE_DATATYPE,
+      g_param_spec_enum ("input-image-datatype",
+          "Inference input image datatype",
+          "Datatype that will be used as an input for the inference",
+          GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE,
+          GST_TENSOR_TYPE_INT8,
+          (GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+
   gst_element_class_set_static_metadata (element_class, "onnxinference",
       "Filter/Effect/Video",
       "Apply neural network to video frames and create tensor output",
@@ -387,6 +432,10 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
       onnxClient->setInputImageFormat ((GstMlInputImageFormat)
           g_value_get_enum (value));
       break;
+    case PROP_INPUT_IMAGE_DATATYPE:
+      onnxClient->setInputImageDatatype ((GstTensorType)
+          g_value_get_enum (value));
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
       break;
@@ -413,6 +462,9 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
     case PROP_INPUT_IMAGE_FORMAT:
       g_value_set_enum (value, onnxClient->getInputImageFormat ());
       break;
+    case PROP_INPUT_IMAGE_DATATYPE:
+      g_value_set_enum (value, onnxClient->getInputImageDatatype ());
+      break;
     default:
       G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
       break;
@@ -547,6 +599,7 @@ gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
       auto meta = client->copy_tensors_to_meta (outputs, buf);
       if (!meta)
         return FALSE;
+      GST_TRACE_OBJECT (trans, "Num tensors:%d", meta->num_tensors);
       meta->batch_size = 1;
     }
     catch (Ort::Exception & ortex) {