mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2024-11-23 02:01:12 +00:00
onnxinference: Return caps based on model preference when possible
This should enable zero-copy when the model has the right type Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
This commit is contained in:
parent
83c2d30438
commit
b1ac114ca5
3 changed files with 73 additions and 23 deletions
|
@ -76,6 +76,11 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
return height;
|
||||
}
|
||||
|
||||
int32_t GstOnnxClient::getChannels (void)
|
||||
{
|
||||
return channels;
|
||||
}
|
||||
|
||||
bool GstOnnxClient::isFixedInputImageSize (void)
|
||||
{
|
||||
return fixedInputImageSize;
|
||||
|
@ -446,11 +451,18 @@ GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_paren
|
|||
|
||||
switch (inputDatatype) {
|
||||
case GST_TENSOR_TYPE_INT8:
|
||||
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
|
||||
srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset,
|
||||
(uint8_t)inputTensorScale);
|
||||
uint8_t *src_data;
|
||||
if (inputTensorOffset == 00 && inputTensorScale == 1.0) {
|
||||
src_data = img_data;
|
||||
} else {
|
||||
convert_image_remove_alpha (
|
||||
dest, inputImageFormat, srcPtr, srcSamplesPerPixel, stride,
|
||||
(uint8_t)inputTensorOffset, (uint8_t)inputTensorScale);
|
||||
src_data = dest;
|
||||
}
|
||||
|
||||
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
|
||||
memoryInfo, dest, inputTensorSize, inputDims.data (),
|
||||
memoryInfo, src_data, inputTensorSize, inputDims.data (),
|
||||
inputDims.size ()));
|
||||
break;
|
||||
case GST_TENSOR_TYPE_FLOAT32: {
|
||||
|
|
|
@ -66,7 +66,9 @@ namespace GstOnnxNamespace {
|
|||
bool isFixedInputImageSize(void);
|
||||
int32_t getWidth(void);
|
||||
int32_t getHeight(void);
|
||||
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
|
||||
int32_t getChannels (void);
|
||||
GstTensorMeta *copy_tensors_to_meta (std::vector<Ort::Value> &outputs,
|
||||
GstBuffer *buffer);
|
||||
void parseDimensions(GstVideoInfo vinfo);
|
||||
private:
|
||||
|
||||
|
|
|
@ -474,34 +474,70 @@ gst_onnx_inference_transform_caps (GstBaseTransform *
|
|||
GstOnnxInference *self = GST_ONNX_INFERENCE (trans);
|
||||
auto onnxClient = GST_ONNX_CLIENT_MEMBER (self);
|
||||
GstCaps *other_caps;
|
||||
guint i;
|
||||
GstCaps *restrictions;
|
||||
|
||||
if (!gst_onnx_inference_create_session (trans))
|
||||
return NULL;
|
||||
GST_LOG_OBJECT (self, "transforming caps %" GST_PTR_FORMAT, caps);
|
||||
|
||||
if (gst_base_transform_is_passthrough (trans)
|
||||
|| (!onnxClient->isFixedInputImageSize ()))
|
||||
if (gst_base_transform_is_passthrough (trans))
|
||||
return gst_caps_ref (caps);
|
||||
|
||||
other_caps = gst_caps_new_empty ();
|
||||
for (i = 0; i < gst_caps_get_size (caps); ++i) {
|
||||
GstStructure *structure, *new_structure;
|
||||
restrictions = gst_caps_new_empty_simple ("video/x-raw");
|
||||
if (onnxClient->isFixedInputImageSize ())
|
||||
gst_caps_set_simple (restrictions, "width", G_TYPE_INT,
|
||||
onnxClient->getWidth (), "height", G_TYPE_INT,
|
||||
onnxClient->getHeight (), NULL);
|
||||
|
||||
structure = gst_caps_get_structure (caps, i);
|
||||
new_structure = gst_structure_copy (structure);
|
||||
gst_structure_set (new_structure, "width", G_TYPE_INT,
|
||||
onnxClient->getWidth (), "height", G_TYPE_INT,
|
||||
onnxClient->getHeight (), NULL);
|
||||
GST_LOG_OBJECT (self,
|
||||
"transformed structure %2d: %" GST_PTR_FORMAT " => %"
|
||||
GST_PTR_FORMAT, i, structure, new_structure);
|
||||
gst_caps_append_structure (other_caps, new_structure);
|
||||
if (onnxClient->getInputImageDatatype() == GST_TENSOR_TYPE_INT8 &&
|
||||
onnxClient->getInputImageScale() == 1.0 &&
|
||||
onnxClient->getInputImageOffset() == 0.0) {
|
||||
switch (onnxClient->getChannels()) {
|
||||
case 1:
|
||||
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "GRAY8",
|
||||
NULL);
|
||||
break;
|
||||
case 3:
|
||||
switch (onnxClient->getInputImageFormat ()) {
|
||||
case GST_ML_INPUT_IMAGE_FORMAT_HWC:
|
||||
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGB",
|
||||
NULL);
|
||||
break;
|
||||
case GST_ML_INPUT_IMAGE_FORMAT_CHW:
|
||||
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBP",
|
||||
NULL);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
switch (onnxClient->getInputImageFormat ()) {
|
||||
case GST_ML_INPUT_IMAGE_FORMAT_HWC:
|
||||
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBA",
|
||||
NULL);
|
||||
break;
|
||||
case GST_ML_INPUT_IMAGE_FORMAT_CHW:
|
||||
gst_caps_set_simple (restrictions, "format", G_TYPE_STRING, "RGBAP",
|
||||
NULL);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GST_ERROR_OBJECT (self, "Invalid number of channels %d",
|
||||
onnxClient->getChannels());
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
if (!gst_caps_is_empty (other_caps) && filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (other_caps, filter_caps,
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
GST_DEBUG_OBJECT(self, "Applying caps restrictions: %" GST_PTR_FORMAT,
|
||||
restrictions);
|
||||
|
||||
other_caps = gst_caps_intersect_full (caps, restrictions,
|
||||
GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_unref (restrictions);
|
||||
|
||||
if (filter_caps) {
|
||||
GstCaps *tmp = gst_caps_intersect_full (
|
||||
other_caps, filter_caps, GST_CAPS_INTERSECT_FIRST);
|
||||
gst_caps_replace (&other_caps, tmp);
|
||||
gst_caps_unref (tmp);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue