mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2024-11-23 10:11:08 +00:00
onnx: Extract data type from the model itself
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
This commit is contained in:
parent
2539bb0b1d
commit
54b361c554
3 changed files with 26 additions and 60 deletions
|
@ -227,6 +227,21 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
GST_DEBUG ("Number of Output Nodes: %d",
|
GST_DEBUG ("Number of Output Nodes: %d",
|
||||||
(gint) session->GetOutputCount ());
|
(gint) session->GetOutputCount ());
|
||||||
|
|
||||||
|
ONNXTensorElementDataType elementType =
|
||||||
|
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetElementType ();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||||
|
setInputImageDatatype(GST_TENSOR_TYPE_INT8);
|
||||||
|
break;
|
||||||
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||||
|
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GST_ERROR ("Only input tensors of type int8 and float are supported");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto input_name = session->GetInputNameAllocated (0, allocator);
|
auto input_name = session->GetInputNameAllocated (0, allocator);
|
||||||
GST_DEBUG ("Input name: %s", input_name.get ());
|
GST_DEBUG ("Input name: %s", input_name.get ());
|
||||||
|
@ -252,20 +267,20 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
Ort::GetApi ().ReleaseStatus (status);
|
Ort::GetApi ().ReleaseStatus (status);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
} else {
|
}
|
||||||
for (auto & name:outputNamesRaw) {
|
for (auto & name:outputNamesRaw) {
|
||||||
Ort::AllocatedStringPtr res =
|
Ort::AllocatedStringPtr res =
|
||||||
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
|
metaData.LookupCustomMetadataMapAllocated (name, ortAllocator);
|
||||||
if (res) {
|
if (res)
|
||||||
GQuark quark = g_quark_from_string (res.get ());
|
{
|
||||||
outputIds.push_back (quark);
|
GQuark quark = g_quark_from_string (res.get ());
|
||||||
} else {
|
outputIds.push_back (quark);
|
||||||
|
} else {
|
||||||
GST_ERROR ("Failed to look up id for key %s", name);
|
GST_ERROR ("Failed to look up id for key %s", name);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
catch (Ort::Exception & ortex) {
|
catch (Ort::Exception & ortex) {
|
||||||
GST_ERROR ("%s", ortex.what ());
|
GST_ERROR ("%s", ortex.what ());
|
||||||
|
|
|
@ -55,7 +55,6 @@ namespace GstOnnxNamespace {
|
||||||
bool hasSession(void);
|
bool hasSession(void);
|
||||||
void setInputImageFormat(GstMlInputImageFormat format);
|
void setInputImageFormat(GstMlInputImageFormat format);
|
||||||
GstMlInputImageFormat getInputImageFormat(void);
|
GstMlInputImageFormat getInputImageFormat(void);
|
||||||
void setInputImageDatatype(GstTensorType datatype);
|
|
||||||
GstTensorType getInputImageDatatype(void);
|
GstTensorType getInputImageDatatype(void);
|
||||||
void setInputImageOffset (float offset);
|
void setInputImageOffset (float offset);
|
||||||
float getInputImageOffset ();
|
float getInputImageOffset ();
|
||||||
|
@ -70,6 +69,7 @@ namespace GstOnnxNamespace {
|
||||||
void parseDimensions(GstVideoInfo vinfo);
|
void parseDimensions(GstVideoInfo vinfo);
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
void setInputImageDatatype(GstTensorType datatype);
|
||||||
template < typename T>
|
template < typename T>
|
||||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||||
|
|
|
@ -36,9 +36,9 @@
|
||||||
* https://gitlab.collabora.com/gstreamer/onnx-models
|
* https://gitlab.collabora.com/gstreamer/onnx-models
|
||||||
*
|
*
|
||||||
* GST_DEBUG=ssdobjectdetector:5 \
|
* GST_DEBUG=ssdobjectdetector:5 \
|
||||||
* gst-launch-1.0 multifilesrc location=onnx-models/images/bus.jpg ! \
|
* gst-launch-1.0 filesrc location=onnx-models/images/bus.jpg ! \
|
||||||
* jpegdec ! videoconvert ! onnxinference execution-provider=cpu model-file=onnx-models/models/ssd_mobilenet_v1_coco.onnx ! \
|
* jpegdec ! videoconvert ! onnxinference execution-provider=cpu model-file=onnx-models/models/ssd_mobilenet_v1_coco.onnx ! \
|
||||||
* ssdobjectdetector label-file=onnx-models/labels/COCO_classes.txt ! videoconvert ! autovideosink
|
* ssdobjectdetector label-file=onnx-models/labels/COCO_classes.txt ! videoconvert ! imagefreeze ! autovideosink
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
* Note: in order for downstream tensor decoders to correctly parse the tensor
|
* Note: in order for downstream tensor decoders to correctly parse the tensor
|
||||||
|
@ -100,7 +100,6 @@ enum
|
||||||
PROP_INPUT_IMAGE_FORMAT,
|
PROP_INPUT_IMAGE_FORMAT,
|
||||||
PROP_OPTIMIZATION_LEVEL,
|
PROP_OPTIMIZATION_LEVEL,
|
||||||
PROP_EXECUTION_PROVIDER,
|
PROP_EXECUTION_PROVIDER,
|
||||||
PROP_INPUT_IMAGE_DATATYPE,
|
|
||||||
PROP_INPUT_OFFSET,
|
PROP_INPUT_OFFSET,
|
||||||
PROP_INPUT_SCALE
|
PROP_INPUT_SCALE
|
||||||
};
|
};
|
||||||
|
@ -149,9 +148,6 @@ GType gst_onnx_execution_provider_get_type (void);
|
||||||
GType gst_ml_model_input_image_format_get_type (void);
|
GType gst_ml_model_input_image_format_get_type (void);
|
||||||
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
||||||
|
|
||||||
GType gst_onnx_model_input_image_datatype_get_type (void);
|
|
||||||
#define GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE (gst_onnx_model_input_image_datatype_get_type ())
|
|
||||||
|
|
||||||
GType
|
GType
|
||||||
gst_onnx_optimization_level_get_type (void)
|
gst_onnx_optimization_level_get_type (void)
|
||||||
{
|
{
|
||||||
|
@ -230,28 +226,6 @@ gst_ml_model_input_image_format_get_type (void)
|
||||||
return ml_model_input_image_format;
|
return ml_model_input_image_format;
|
||||||
}
|
}
|
||||||
|
|
||||||
GType
|
|
||||||
gst_onnx_model_input_image_datatype_get_type (void)
|
|
||||||
{
|
|
||||||
static GType model_input_image_datatype = 0;
|
|
||||||
|
|
||||||
if (g_once_init_enter (&model_input_image_datatype)) {
|
|
||||||
static GEnumValue model_input_image_datatype_types[] = {
|
|
||||||
{GST_TENSOR_TYPE_INT8, "8 Bits integer", "int8"},
|
|
||||||
{GST_TENSOR_TYPE_FLOAT32, "32 Bits floating points", "float"},
|
|
||||||
{0, NULL, NULL},
|
|
||||||
};
|
|
||||||
|
|
||||||
GType temp = g_enum_register_static ("GstTensorType",
|
|
||||||
model_input_image_datatype_types);
|
|
||||||
|
|
||||||
g_once_init_leave (&model_input_image_datatype, temp);
|
|
||||||
}
|
|
||||||
|
|
||||||
return model_input_image_datatype;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
static void
|
static void
|
||||||
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
{
|
{
|
||||||
|
@ -325,22 +299,6 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
||||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
/**
|
|
||||||
* GstOnnxInference:input-image-datatype
|
|
||||||
*
|
|
||||||
* Temporary hack, this should be discovered from the model and exposed
|
|
||||||
* on sinkpad caps based on model contrains.
|
|
||||||
*/
|
|
||||||
|
|
||||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
|
||||||
PROP_INPUT_IMAGE_DATATYPE,
|
|
||||||
g_param_spec_enum ("input-image-datatype",
|
|
||||||
"Inference input image datatype",
|
|
||||||
"Datatype that will be used as an input for the inference",
|
|
||||||
GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE,
|
|
||||||
GST_TENSOR_TYPE_INT8,
|
|
||||||
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
|
||||||
|
|
||||||
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||||
PROP_INPUT_OFFSET,
|
PROP_INPUT_OFFSET,
|
||||||
g_param_spec_float ("input-tensor-offset",
|
g_param_spec_float ("input-tensor-offset",
|
||||||
|
@ -425,10 +383,6 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||||
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
|
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
|
||||||
g_value_get_enum (value));
|
g_value_get_enum (value));
|
||||||
break;
|
break;
|
||||||
case PROP_INPUT_IMAGE_DATATYPE:
|
|
||||||
onnxClient->setInputImageDatatype ((GstTensorType)
|
|
||||||
g_value_get_enum (value));
|
|
||||||
break;
|
|
||||||
case PROP_INPUT_OFFSET:
|
case PROP_INPUT_OFFSET:
|
||||||
onnxClient->setInputImageOffset (g_value_get_float (value));
|
onnxClient->setInputImageOffset (g_value_get_float (value));
|
||||||
break;
|
break;
|
||||||
|
@ -461,9 +415,6 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||||
case PROP_INPUT_IMAGE_FORMAT:
|
case PROP_INPUT_IMAGE_FORMAT:
|
||||||
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
||||||
break;
|
break;
|
||||||
case PROP_INPUT_IMAGE_DATATYPE:
|
|
||||||
g_value_set_enum (value, onnxClient->getInputImageDatatype ());
|
|
||||||
break;
|
|
||||||
case PROP_INPUT_OFFSET:
|
case PROP_INPUT_OFFSET:
|
||||||
g_value_set_float (value, onnxClient->getInputImageOffset ());
|
g_value_set_float (value, onnxClient->getInputImageOffset ());
|
||||||
break;
|
break;
|
||||||
|
|
Loading…
Reference in a new issue