diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp index 4542f1ecd8..98f21a66fe 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.cpp @@ -76,6 +76,11 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren return height; } + int32_t GstOnnxClient::getChannels (void) + { + return channels; + } + bool GstOnnxClient::isFixedInputImageSize (void) { return fixedInputImageSize; @@ -446,11 +451,18 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren switch (inputDatatype) { case GST_TENSOR_TYPE_INT8: - convert_image_remove_alpha (dest, inputImageFormat , srcPtr, - srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset, - (uint8_t)inputTensorScale); + uint8_t *src_data; + if (inputTensorOffset == 00 && inputTensorScale == 1.0) { + src_data = img_data; + } else { + convert_image_remove_alpha ( + dest, inputImageFormat, srcPtr, srcSamplesPerPixel, stride, + (uint8_t)inputTensorOffset, (uint8_t)inputTensorScale); + src_data = dest; + } + inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > ( - memoryInfo, dest, inputTensorSize, inputDims.data (), + memoryInfo, src_data, inputTensorSize, inputDims.data (), inputDims.size ())); break; case GST_TENSOR_TYPE_FLOAT32: { diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h index f8aa38f122..0e4f50d68e 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxclient.h @@ -66,7 +66,9 @@ namespace GstOnnxNamespace { bool isFixedInputImageSize(void); int32_t getWidth(void); int32_t getHeight(void); - GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer); + int32_t getChannels (void); + GstTensorMeta *copy_tensors_to_meta (std::vector &outputs, + GstBuffer *buffer); void parseDimensions(GstVideoInfo vinfo); private: diff --git a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp index aa739055d8..21eaade85a 100644 --- a/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp +++ b/subprojects/gst-plugins-bad/ext/onnx/gstonnxinference.cpp @@ -474,34 +474,70 @@ gst_onnx_inference_transform_caps (GstBaseTransform * GstOnnxInference *self = GST_ONNX_INFERENCE (trans); auto onnxClient = GST_ONNX_CLIENT_MEMBER (self); GstCaps *other_caps; - guint i; + GstCaps *restrictions; if (!gst_onnx_inference_create_session (trans)) return NULL; GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps); - if (gst_base_transform_is_passthrough (trans) - || (!onnxClient->isFixedInputImageSize ())) + if (gst_base_transform_is_passthrough (trans)) return gst_caps_ref (caps); - other_caps = gst_caps_new_empty (); - for (i = 0; i < gst_caps_get_size (caps); ++i) { - GstStructure *structure, *new_structure; + restrictions = gst_caps_new_empty_simple ("video/x-raw"); + if (onnxClient->isFixedInputImageSize ()) + gst_caps_set_simple (restrictions, "width", G_TYPE_INT, + onnxClient->getWidth (), "height", G_TYPE_INT, + onnxClient->getHeight (), NULL); - structure = gst_caps_get_structure (caps, i); - new_structure = gst_structure_copy (structure); - gst_structure_set (new_structure, "width", G_TYPE_INT, - onnxClient->getWidth (), "height", G_TYPE_INT, - onnxClient->getHeight (), NULL); - GST_LOG_OBJECT (self, - "transformed structure %2d: %" GST_PTR_FORMAT " => %" - GST_PTR_FORMAT, i, structure, new_structure); - gst_caps_append_structure (other_caps, new_structure); + if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_INT8 && + onnxClient->getInputImageScale() == 1.0 && + onnxClient->getInputImageOffset() == 0.0) { + switch (onnxClient->getChannels()) { + case 1: + gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "GRAY8", + NULL); + break; + case 3: + switch (onnxClient->getInputImageFormat ()) { + case GST_ML_INPUT_IMAGE_FORMAT_HWC: + gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGB", + NULL); + break; + case GST_ML_INPUT_IMAGE_FORMAT_CHW: + gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBP", + NULL); + break; + } + break; + case 4: + switch (onnxClient->getInputImageFormat ()) { + case GST_ML_INPUT_IMAGE_FORMAT_HWC: + gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBA", + NULL); + break; + case GST_ML_INPUT_IMAGE_FORMAT_CHW: + gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBAP", + NULL); + break; + } + break; + default: + GST_ERROR_OBJECT (self, "Invalid number of channels %d", + onnxClient->getChannels()); + return NULL; + } } - if (!gst_caps_is_empty (other_caps) && filter_caps) { - GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps, - GST_CAPS_INTERSECT_FIRST); + GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT, + restrictions); + + other_caps = gst_caps_intersect_full (caps, restrictions, + GST_CAPS_INTERSECT_FIRST); + gst_caps_unref (restrictions); + + if (filter_caps) { + GstCaps *tmp = gst_caps_intersect_full ( + other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST); gst_caps_replace (&other_caps, tmp); gst_caps_unref (tmp); }