onnx: add offset and scale properties

- Offset each datapoint by the value set on the offset property.
- Scale (divide) each datapoint by the value set on the scale property (see the example below).

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5761>
Daniel Morin 2023-12-04 19:57:39 -05:00 committed by GStreamer Marge Bot
parent 3793619a02
commit 15e5866e51
3 changed files with 74 additions and 8 deletions
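For reference, every tensor value now goes through dst = (src + offset) / scale, with offset defaulting to 0.0 and scale to 1.0. A minimal standalone sketch of that arithmetic, using the illustrative (non-default) values offset = -127.5 and scale = 127.5 to map 8-bit samples onto [-1, 1]:

#include <cstdint>
#include <cstdio>

int main ()
{
  /* Illustrative values only; the properties default to 0.0 and 1.0. */
  const float offset = -127.5f;
  const float scale = 127.5f;
  const uint8_t samples[] = { 0, 127, 255 };

  for (uint8_t s : samples) {
    /* Same order of operations as the element: add the offset first,
     * then divide by the scale. */
    float normalized = (s + offset) / scale;
    printf ("%3d -> %+.3f\n", s, normalized);  /* 0 -> -1.000, 255 -> +1.000 */
  }
  return 0;
}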


@@ -52,7 +52,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
inputDatatype (GST_TENSOR_TYPE_INT8),
inputDatatypeSize (sizeof (uint8_t)),
-fixedInputImageSize (false) {
+fixedInputImageSize (false),
+inputTensorOffset (0.0),
+inputTensorScale (1.0) {
}
GstOnnxClient::~GstOnnxClient () {
@@ -107,6 +109,26 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
};
}
+void GstOnnxClient::setInputImageOffset (float offset)
+{
+inputTensorOffset = offset;
+}
+float GstOnnxClient::getInputImageOffset ()
+{
+return inputTensorOffset;
+}
+void GstOnnxClient::setInputImageScale (float scale)
+{
+inputTensorScale = scale;
+}
+float GstOnnxClient::getInputImageScale ()
+{
+return inputTensorScale;
+}
GstTensorType GstOnnxClient::getInputImageDatatype(void)
{
return inputDatatype;
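For completeness, a minimal fragment showing how the new GstOnnxClient accessors could be driven directly (illustrative values; construction of the client and all session/model setup are omitted, and in practice the onnxinference element sets these from the properties added below):

/* Sketch only: assumes a constructed GstOnnxNamespace::GstOnnxClient
 * that is otherwise configured elsewhere. */
GstOnnxNamespace::GstOnnxClient client;
client.setInputImageOffset (0.0f);             /* added by this commit */
client.setInputImageScale (255.0f);            /* every sample divided by 255 */
float offset = client.getInputImageOffset ();  /* 0.0 */
float scale = client.getInputImageScale ();    /* 255.0 */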
@@ -422,14 +444,16 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
switch (inputDatatype) {
case GST_TENSOR_TYPE_INT8:
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
-srcSamplesPerPixel, stride);
+srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset,
+(uint8_t)inputTensorScale);
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
memoryInfo, dest, inputTensorSize, inputDims.data (),
inputDims.size ()));
break;
case GST_TENSOR_TYPE_FLOAT32: {
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
-srcSamplesPerPixel, stride);
+srcSamplesPerPixel, stride, (float)inputTensorOffset, (float)
+inputTensorScale);
inputTensors.push_back (Ort::Value::CreateTensor < float > (
memoryInfo, (float*)dest, inputTensorSize, inputDims.data (),
inputDims.size ()));
@@ -451,14 +475,17 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
template < typename T>
void GstOnnxClient::convert_image_remove_alpha (T *dst,
GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
-uint32_t stride) {
+uint32_t stride, T offset, T div) {
size_t destIndex = 0;
+T tmp;
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
for (int32_t j = 0; j < height; ++j) {
for (int32_t i = 0; i < width; ++i) {
for (int32_t k = 0; k < channels; ++k) {
-dst[destIndex++] = (T)*srcPtr[k];
+tmp = *srcPtr[k];
+tmp += offset;
+dst[destIndex++] = (T)(tmp / div);
srcPtr[k] += srcSamplesPerPixel;
}
}
@@ -472,7 +499,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
for (int32_t j = 0; j < height; ++j) {
for (int32_t i = 0; i < width; ++i) {
for (int32_t k = 0; k < channels; ++k) {
-destPtr[k][destIndex] = (T)*srcPtr[k];
+tmp = *srcPtr[k];
+tmp += offset;
+destPtr[k][destIndex] = (T)(tmp / div);
srcPtr[k] += srcSamplesPerPixel;
}
destIndex++;
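Both the interleaved (HWC) and planar branches now run each sample through the same offset-then-divide step, with the arithmetic carried out in the tensor's element type T. A small self-contained sketch (a hypothetical normalize_sample helper, not part of the patch) showing the practical difference between the two instantiations:

#include <cstdint>
#include <cstdio>

/* Mirrors one step of convert_image_remove_alpha(): add the offset,
 * then divide, in the element type T. */
template <typename T>
static T
normalize_sample (uint8_t src, T offset, T div)
{
  T tmp = src;
  tmp += offset;
  return (T) (tmp / div);
}

int main ()
{
  /* FLOAT32 tensors keep fractional precision (200 / 255 is about 0.784)... */
  printf ("float32: %f\n", normalize_sample<float> (200, 0.0f, 255.0f));
  /* ...whereas INT8 tensors compute in integer arithmetic, so the same
   * division truncates to 0. */
  printf ("int8:    %d\n", normalize_sample<uint8_t> (200, 0, 255));
  return 0;
}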


@@ -57,6 +57,10 @@ namespace GstOnnxNamespace {
GstMlInputImageFormat getInputImageFormat(void);
void setInputImageDatatype(GstTensorType datatype);
GstTensorType getInputImageDatatype(void);
+void setInputImageOffset (float offset);
+float getInputImageOffset ();
+void setInputImageScale (float offset);
+float getInputImageScale ();
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
std::vector < const char *> genOutputNamesRaw(void);
bool isFixedInputImageSize(void);
@@ -68,7 +72,7 @@ namespace GstOnnxNamespace {
template < typename T>
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
-uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride);
+uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
Ort::Env env;
Ort::Session * session;
@@ -86,6 +90,8 @@ namespace GstOnnxNamespace {
GstTensorType inputDatatype;
size_t inputDatatypeSize;
bool fixedInputImageSize;
+float inputTensorOffset;
+float inputTensorScale;
};
}


@@ -123,7 +123,9 @@ enum
PROP_INPUT_IMAGE_FORMAT,
PROP_OPTIMIZATION_LEVEL,
PROP_EXECUTION_PROVIDER,
-PROP_INPUT_IMAGE_DATATYPE
+PROP_INPUT_IMAGE_DATATYPE,
+PROP_INPUT_OFFSET,
+PROP_INPUT_SCALE
};
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
@@ -362,6 +364,23 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
GST_TENSOR_TYPE_INT8,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+g_object_class_install_property (G_OBJECT_CLASS (klass),
+PROP_INPUT_OFFSET,
+g_param_spec_float ("input-tensor-offset",
+"Input tensor offset",
+"offset each tensor value by this value",
+-G_MAXFLOAT, G_MAXFLOAT, 0.0,
+(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
+g_object_class_install_property (G_OBJECT_CLASS (klass),
+PROP_INPUT_SCALE,
+g_param_spec_float ("input-tensor-scale",
+"Input tensor scale",
+"Divide each tensor value by this value",
+G_MINFLOAT, G_MAXFLOAT, 1.0,
+(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "onnxinference",
"Filter/Effect/Video",
"Apply neural network to video frames and create tensor output",
@@ -433,6 +452,12 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
onnxClient->setInputImageDatatype ((GstTensorType)
g_value_get_enum (value));
break;
+case PROP_INPUT_OFFSET:
+onnxClient->setInputImageOffset (g_value_get_float (value));
+break;
+case PROP_INPUT_SCALE:
+onnxClient->setInputImageScale (g_value_get_float (value));
+break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
@@ -462,6 +487,12 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
case PROP_INPUT_IMAGE_DATATYPE:
g_value_set_enum (value, onnxClient->getInputImageDatatype ());
break;
+case PROP_INPUT_OFFSET:
+g_value_set_float (value, onnxClient->getInputImageOffset ());
+break;
+case PROP_INPUT_SCALE:
+g_value_set_float (value, onnxClient->getInputImageScale ());
+break;
default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break;
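And since gst_onnx_inference_get_property() wires up the getters, the values can be read back like any other property. A minimal sketch (assuming the onnx plugin is available; the defaults 0.0 and 1.0 are expected unless something set them earlier):

#include <gst/gst.h>

int main (int argc, char **argv)
{
  gst_init (&argc, &argv);

  GstElement *onnx = gst_element_factory_make ("onnxinference", NULL);
  if (onnx == NULL)
    return 1;

  /* Read the current values back through the new get_property paths. */
  gfloat offset = -1.0f, scale = -1.0f;
  g_object_get (onnx,
      "input-tensor-offset", &offset,
      "input-tensor-scale", &scale,
      NULL);
  g_print ("input-tensor-offset=%f input-tensor-scale=%f\n", offset, scale);

  gst_object_unref (onnx);
  return 0;
}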