onnx: add offset and scale properties

- Offset each datapoint by the value set on offset property.
- Scale each datapoint by the value set on scale property.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5761>
This commit is contained in:
Daniel Morin 2023-12-04 19:57:39 -05:00 committed by GStreamer Marge Bot
parent 3793619a02
commit 15e5866e51
3 changed files with 74 additions and 8 deletions

View file

@ -52,7 +52,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC), inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
inputDatatype (GST_TENSOR_TYPE_INT8), inputDatatype (GST_TENSOR_TYPE_INT8),
inputDatatypeSize (sizeof (uint8_t)), inputDatatypeSize (sizeof (uint8_t)),
fixedInputImageSize (false) { fixedInputImageSize (false),
inputTensorOffset (0.0),
inputTensorScale (1.0) {
} }
GstOnnxClient::~GstOnnxClient () { GstOnnxClient::~GstOnnxClient () {
@ -107,6 +109,26 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
}; };
} }
void GstOnnxClient::setInputImageOffset (float offset)
{
inputTensorOffset = offset;
}
float GstOnnxClient::getInputImageOffset ()
{
return inputTensorOffset;
}
void GstOnnxClient::setInputImageScale (float scale)
{
inputTensorScale = scale;
}
float GstOnnxClient::getInputImageScale ()
{
return inputTensorScale;
}
GstTensorType GstOnnxClient::getInputImageDatatype(void) GstTensorType GstOnnxClient::getInputImageDatatype(void)
{ {
return inputDatatype; return inputDatatype;
@ -422,14 +444,16 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
switch (inputDatatype) { switch (inputDatatype) {
case GST_TENSOR_TYPE_INT8: case GST_TENSOR_TYPE_INT8:
convert_image_remove_alpha (dest, inputImageFormat , srcPtr, convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
srcSamplesPerPixel, stride); srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset,
(uint8_t)inputTensorScale);
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > ( inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
memoryInfo, dest, inputTensorSize, inputDims.data (), memoryInfo, dest, inputTensorSize, inputDims.data (),
inputDims.size ())); inputDims.size ()));
break; break;
case GST_TENSOR_TYPE_FLOAT32: { case GST_TENSOR_TYPE_FLOAT32: {
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr, convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
srcSamplesPerPixel, stride); srcSamplesPerPixel, stride, (float)inputTensorOffset, (float)
inputTensorScale);
inputTensors.push_back (Ort::Value::CreateTensor < float > ( inputTensors.push_back (Ort::Value::CreateTensor < float > (
memoryInfo, (float*)dest, inputTensorSize, inputDims.data (), memoryInfo, (float*)dest, inputTensorSize, inputDims.data (),
inputDims.size ())); inputDims.size ()));
@ -451,14 +475,17 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
template < typename T> template < typename T>
void GstOnnxClient::convert_image_remove_alpha (T *dst, void GstOnnxClient::convert_image_remove_alpha (T *dst,
GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel, GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
uint32_t stride) { uint32_t stride, T offset, T div) {
size_t destIndex = 0; size_t destIndex = 0;
T tmp;
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) { if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
for (int32_t j = 0; j < height; ++j) { for (int32_t j = 0; j < height; ++j) {
for (int32_t i = 0; i < width; ++i) { for (int32_t i = 0; i < width; ++i) {
for (int32_t k = 0; k < channels; ++k) { for (int32_t k = 0; k < channels; ++k) {
dst[destIndex++] = (T)*srcPtr[k]; tmp = *srcPtr[k];
tmp += offset;
dst[destIndex++] = (T)(tmp / div);
srcPtr[k] += srcSamplesPerPixel; srcPtr[k] += srcSamplesPerPixel;
} }
} }
@ -472,7 +499,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
for (int32_t j = 0; j < height; ++j) { for (int32_t j = 0; j < height; ++j) {
for (int32_t i = 0; i < width; ++i) { for (int32_t i = 0; i < width; ++i) {
for (int32_t k = 0; k < channels; ++k) { for (int32_t k = 0; k < channels; ++k) {
destPtr[k][destIndex] = (T)*srcPtr[k]; tmp = *srcPtr[k];
tmp += offset;
destPtr[k][destIndex] = (T)(tmp / div);
srcPtr[k] += srcSamplesPerPixel; srcPtr[k] += srcSamplesPerPixel;
} }
destIndex++; destIndex++;

View file

@ -57,6 +57,10 @@ namespace GstOnnxNamespace {
GstMlInputImageFormat getInputImageFormat(void); GstMlInputImageFormat getInputImageFormat(void);
void setInputImageDatatype(GstTensorType datatype); void setInputImageDatatype(GstTensorType datatype);
GstTensorType getInputImageDatatype(void); GstTensorType getInputImageDatatype(void);
void setInputImageOffset (float offset);
float getInputImageOffset ();
void setInputImageScale (float offset);
float getInputImageScale ();
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo); std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
std::vector < const char *> genOutputNamesRaw(void); std::vector < const char *> genOutputNamesRaw(void);
bool isFixedInputImageSize(void); bool isFixedInputImageSize(void);
@ -68,7 +72,7 @@ namespace GstOnnxNamespace {
template < typename T> template < typename T>
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc, void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride); uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput); bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
Ort::Env env; Ort::Env env;
Ort::Session * session; Ort::Session * session;
@ -86,6 +90,8 @@ namespace GstOnnxNamespace {
GstTensorType inputDatatype; GstTensorType inputDatatype;
size_t inputDatatypeSize; size_t inputDatatypeSize;
bool fixedInputImageSize; bool fixedInputImageSize;
float inputTensorOffset;
float inputTensorScale;
}; };
} }

View file

@ -123,7 +123,9 @@ enum
PROP_INPUT_IMAGE_FORMAT, PROP_INPUT_IMAGE_FORMAT,
PROP_OPTIMIZATION_LEVEL, PROP_OPTIMIZATION_LEVEL,
PROP_EXECUTION_PROVIDER, PROP_EXECUTION_PROVIDER,
PROP_INPUT_IMAGE_DATATYPE PROP_INPUT_IMAGE_DATATYPE,
PROP_INPUT_OFFSET,
PROP_INPUT_SCALE
}; };
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU #define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
@ -362,6 +364,23 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
GST_TENSOR_TYPE_INT8, GST_TENSOR_TYPE_INT8,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS))); (GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_INPUT_OFFSET,
g_param_spec_float ("input-tensor-offset",
"Input tensor offset",
"offset each tensor value by this value",
-G_MAXFLOAT, G_MAXFLOAT, 0.0,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
g_object_class_install_property (G_OBJECT_CLASS (klass),
PROP_INPUT_SCALE,
g_param_spec_float ("input-tensor-scale",
"Input tensor scale",
"Divide each tensor value by this value",
G_MINFLOAT, G_MAXFLOAT, 1.0,
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
gst_element_class_set_static_metadata (element_class, "onnxinference", gst_element_class_set_static_metadata (element_class, "onnxinference",
"Filter/Effect/Video", "Filter/Effect/Video",
"Apply neural network to video frames and create tensor output", "Apply neural network to video frames and create tensor output",
@ -433,6 +452,12 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
onnxClient->setInputImageDatatype ((GstTensorType) onnxClient->setInputImageDatatype ((GstTensorType)
g_value_get_enum (value)); g_value_get_enum (value));
break; break;
case PROP_INPUT_OFFSET:
onnxClient->setInputImageOffset (g_value_get_float (value));
break;
case PROP_INPUT_SCALE:
onnxClient->setInputImageScale (g_value_get_float (value));
break;
default: default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break; break;
@ -462,6 +487,12 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
case PROP_INPUT_IMAGE_DATATYPE: case PROP_INPUT_IMAGE_DATATYPE:
g_value_set_enum (value, onnxClient->getInputImageDatatype ()); g_value_set_enum (value, onnxClient->getInputImageDatatype ());
break; break;
case PROP_INPUT_OFFSET:
g_value_set_float (value, onnxClient->getInputImageOffset ());
break;
case PROP_INPUT_SCALE:
g_value_set_float (value, onnxClient->getInputImageScale ());
break;
default: default:
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec); G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
break; break;