onnxinference: Return caps based on model preference when possible

This should enable zero-copy when the model has the right type

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
This commit is contained in:
Olivier Crête 2024-01-04 15:23:54 -05:00 committed by GStreamer Marge Bot
parent 83c2d30438
commit b1ac114ca5
3 changed files with 73 additions and 23 deletions

View file

@ -76,6 +76,11 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
return height;
}
int32_t GstOnnxClient::getChannels (void)
{
return channels;
}
bool GstOnnxClient::isFixedInputImageSize (void)
{
return fixedInputImageSize;
@ -446,11 +451,18 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
switch (inputDatatype) {
case GST_TENSOR_TYPE_INT8:
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset,
(uint8_t)inputTensorScale);
uint8_t *src_data;
if (inputTensorOffset == 00 && inputTensorScale == 1.0) {
src_data = img_data;
} else {
convert_image_remove_alpha (
dest, inputImageFormat, srcPtr, srcSamplesPerPixel, stride,
(uint8_t)inputTensorOffset, (uint8_t)inputTensorScale);
src_data = dest;
}
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
memoryInfo, dest, inputTensorSize, inputDims.data (),
memoryInfo, src_data, inputTensorSize, inputDims.data (),
inputDims.size ()));
break;
case GST_TENSOR_TYPE_FLOAT32: {

View file

@ -66,7 +66,9 @@ namespace GstOnnxNamespace {
bool isFixedInputImageSize(void);
int32_t getWidth(void);
int32_t getHeight(void);
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
int32_t getChannels (void);
GstTensorMeta *copy_tensors_to_meta (std::vector<Ort::Value> &outputs,
GstBuffer *buffer);
void parseDimensions(GstVideoInfo vinfo);
private:

View file

@ -474,34 +474,70 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
GstCaps *other_caps;
guint i;
GstCaps *restrictions;
if (!gst_onnx_inference_create_session (trans))
return NULL;
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
if (gst_base_transform_is_passthrough (trans)
|| (!onnxClient->isFixedInputImageSize ()))
if (gst_base_transform_is_passthrough (trans))
return gst_caps_ref (caps);
other_caps = gst_caps_new_empty ();
for (i = 0; i < gst_caps_get_size (caps); ++i) {
GstStructure *structure, *new_structure;
structure = gst_caps_get_structure (caps, i);
new_structure = gst_structure_copy (structure);
gst_structure_set (new_structure, "width", G_TYPE_INT,
restrictions = gst_caps_new_empty_simple ("video/x-raw");
if (onnxClient->isFixedInputImageSize ())
gst_caps_set_simple (restrictions, "width", G_TYPE_INT,
onnxClient->getWidth (), "height", G_TYPE_INT,
onnxClient->getHeight (), NULL);
GST_LOG_OBJECT (self,
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
GST_PTR_FORMAT, i, structure, new_structure);
gst_caps_append_structure (other_caps, new_structure);
if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_INT8 &&
onnxClient->getInputImageScale() == 1.0 &&
onnxClient->getInputImageOffset() == 0.0) {
switch (onnxClient->getChannels()) {
case 1:
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "GRAY8",
NULL);
break;
case 3:
switch (onnxClient->getInputImageFormat ()) {
case GST_ML_INPUT_IMAGE_FORMAT_HWC:
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGB",
NULL);
break;
case GST_ML_INPUT_IMAGE_FORMAT_CHW:
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBP",
NULL);
break;
}
break;
case 4:
switch (onnxClient->getInputImageFormat ()) {
case GST_ML_INPUT_IMAGE_FORMAT_HWC:
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBA",
NULL);
break;
case GST_ML_INPUT_IMAGE_FORMAT_CHW:
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBAP",
NULL);
break;
}
break;
default:
GST_ERROR_OBJECT (self, "Invalid number of channels %d",
onnxClient->getChannels());
return NULL;
}
}
if (!gst_caps_is_empty (other_caps) && filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT,
restrictions);
other_caps = gst_caps_intersect_full (caps, restrictions,
GST_CAPS_INTERSECT_FIRST);
gst_caps_unref (restrictions);
if (filter_caps) {
GstCaps *tmp = gst_caps_intersect_full (
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
gst_caps_replace (&other_caps, tmp);
gst_caps_unref (tmp);
}