mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-01-10 17:35:59 +00:00
onnx: add offset and scale properties
- Offset each datapoint by the value set on the offset property. - Scale each datapoint by dividing it by the value set on the scale property. Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5761>
This commit is contained in:
parent
3793619a02
commit
15e5866e51
3 changed files with 74 additions and 8 deletions
|
@ -52,7 +52,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
||||||
inputDatatype (GST_TENSOR_TYPE_INT8),
|
inputDatatype (GST_TENSOR_TYPE_INT8),
|
||||||
inputDatatypeSize (sizeof (uint8_t)),
|
inputDatatypeSize (sizeof (uint8_t)),
|
||||||
fixedInputImageSize (false) {
|
fixedInputImageSize (false),
|
||||||
|
inputTensorOffset (0.0),
|
||||||
|
inputTensorScale (1.0) {
|
||||||
}
|
}
|
||||||
|
|
||||||
GstOnnxClient::~GstOnnxClient () {
|
GstOnnxClient::~GstOnnxClient () {
|
||||||
|
@ -107,6 +109,26 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores the offset that is added to every source sample when the input
// image is converted into the input tensor (the addition happens before
// the division by the scale).
void GstOnnxClient::setInputImageOffset (float offset)
|
||||||
|
{
|
||||||
|
inputTensorOffset = offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the currently stored input tensor offset.
float GstOnnxClient::getInputImageOffset ()
|
||||||
|
{
|
||||||
|
return inputTensorOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores the scale used as the divisor for every (offset-adjusted) source
// sample when the input image is converted into the input tensor.
// NOTE(review): the INT8 conversion path casts this value to uint8_t before
// dividing, so a scale below 1 would truncate to 0 and become a zero
// divisor; no validation is done here -- confirm callers guard against it.
void GstOnnxClient::setInputImageScale (float scale)
|
||||||
|
{
|
||||||
|
inputTensorScale = scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the currently stored input tensor scale (the divisor).
float GstOnnxClient::getInputImageScale ()
|
||||||
|
{
|
||||||
|
return inputTensorScale;
|
||||||
|
}
|
||||||
|
|
||||||
GstTensorType GstOnnxClient::getInputImageDatatype(void)
|
GstTensorType GstOnnxClient::getInputImageDatatype(void)
|
||||||
{
|
{
|
||||||
return inputDatatype;
|
return inputDatatype;
|
||||||
|
@ -422,14 +444,16 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
switch (inputDatatype) {
|
switch (inputDatatype) {
|
||||||
case GST_TENSOR_TYPE_INT8:
|
case GST_TENSOR_TYPE_INT8:
|
||||||
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
|
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
|
||||||
srcSamplesPerPixel, stride);
|
srcSamplesPerPixel, stride, (uint8_t)inputTensorOffset,
|
||||||
|
(uint8_t)inputTensorScale);
|
||||||
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
|
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
|
||||||
memoryInfo, dest, inputTensorSize, inputDims.data (),
|
memoryInfo, dest, inputTensorSize, inputDims.data (),
|
||||||
inputDims.size ()));
|
inputDims.size ()));
|
||||||
break;
|
break;
|
||||||
case GST_TENSOR_TYPE_FLOAT32: {
|
case GST_TENSOR_TYPE_FLOAT32: {
|
||||||
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
|
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
|
||||||
srcSamplesPerPixel, stride);
|
srcSamplesPerPixel, stride, (float)inputTensorOffset, (float)
|
||||||
|
inputTensorScale);
|
||||||
inputTensors.push_back (Ort::Value::CreateTensor < float > (
|
inputTensors.push_back (Ort::Value::CreateTensor < float > (
|
||||||
memoryInfo, (float*)dest, inputTensorSize, inputDims.data (),
|
memoryInfo, (float*)dest, inputTensorSize, inputDims.data (),
|
||||||
inputDims.size ()));
|
inputDims.size ()));
|
||||||
|
@ -451,14 +475,17 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
template < typename T>
|
template < typename T>
|
||||||
void GstOnnxClient::convert_image_remove_alpha (T *dst,
|
void GstOnnxClient::convert_image_remove_alpha (T *dst,
|
||||||
GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
|
GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
|
||||||
uint32_t stride) {
|
uint32_t stride, T offset, T div) {
|
||||||
size_t destIndex = 0;
|
size_t destIndex = 0;
|
||||||
|
T tmp;
|
||||||
|
|
||||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||||
for (int32_t j = 0; j < height; ++j) {
|
for (int32_t j = 0; j < height; ++j) {
|
||||||
for (int32_t i = 0; i < width; ++i) {
|
for (int32_t i = 0; i < width; ++i) {
|
||||||
for (int32_t k = 0; k < channels; ++k) {
|
for (int32_t k = 0; k < channels; ++k) {
|
||||||
dst[destIndex++] = (T)*srcPtr[k];
|
tmp = *srcPtr[k];
|
||||||
|
tmp += offset;
|
||||||
|
dst[destIndex++] = (T)(tmp / div);
|
||||||
srcPtr[k] += srcSamplesPerPixel;
|
srcPtr[k] += srcSamplesPerPixel;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -472,7 +499,9 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
for (int32_t j = 0; j < height; ++j) {
|
for (int32_t j = 0; j < height; ++j) {
|
||||||
for (int32_t i = 0; i < width; ++i) {
|
for (int32_t i = 0; i < width; ++i) {
|
||||||
for (int32_t k = 0; k < channels; ++k) {
|
for (int32_t k = 0; k < channels; ++k) {
|
||||||
destPtr[k][destIndex] = (T)*srcPtr[k];
|
tmp = *srcPtr[k];
|
||||||
|
tmp += offset;
|
||||||
|
destPtr[k][destIndex] = (T)(tmp / div);
|
||||||
srcPtr[k] += srcSamplesPerPixel;
|
srcPtr[k] += srcSamplesPerPixel;
|
||||||
}
|
}
|
||||||
destIndex++;
|
destIndex++;
|
||||||
|
|
|
@ -57,6 +57,10 @@ namespace GstOnnxNamespace {
|
||||||
GstMlInputImageFormat getInputImageFormat(void);
|
GstMlInputImageFormat getInputImageFormat(void);
|
||||||
void setInputImageDatatype(GstTensorType datatype);
|
void setInputImageDatatype(GstTensorType datatype);
|
||||||
GstTensorType getInputImageDatatype(void);
|
GstTensorType getInputImageDatatype(void);
|
||||||
|
// Sets the offset added to each input sample during tensor conversion.
void setInputImageOffset (float offset);
|
||||||
|
// Returns the stored input tensor offset.
float getInputImageOffset ();
|
||||||
|
// Sets the scale by which each (offset-adjusted) input sample is divided
// during tensor conversion.
void setInputImageScale (float scale);
|
||||||
|
// Returns the stored input tensor scale (the divisor).
float getInputImageScale ();
|
||||||
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
|
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
|
||||||
std::vector < const char *> genOutputNamesRaw(void);
|
std::vector < const char *> genOutputNamesRaw(void);
|
||||||
bool isFixedInputImageSize(void);
|
bool isFixedInputImageSize(void);
|
||||||
|
@ -68,7 +72,7 @@ namespace GstOnnxNamespace {
|
||||||
|
|
||||||
template < typename T>
|
template < typename T>
|
||||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride);
|
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||||
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
||||||
Ort::Env env;
|
Ort::Env env;
|
||||||
Ort::Session * session;
|
Ort::Session * session;
|
||||||
|
@ -86,6 +90,8 @@ namespace GstOnnxNamespace {
|
||||||
GstTensorType inputDatatype;
|
GstTensorType inputDatatype;
|
||||||
size_t inputDatatypeSize;
|
size_t inputDatatypeSize;
|
||||||
bool fixedInputImageSize;
|
bool fixedInputImageSize;
|
||||||
|
float inputTensorOffset;
|
||||||
|
float inputTensorScale;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,7 +123,9 @@ enum
|
||||||
PROP_INPUT_IMAGE_FORMAT,
|
PROP_INPUT_IMAGE_FORMAT,
|
||||||
PROP_OPTIMIZATION_LEVEL,
|
PROP_OPTIMIZATION_LEVEL,
|
||||||
PROP_EXECUTION_PROVIDER,
|
PROP_EXECUTION_PROVIDER,
|
||||||
PROP_INPUT_IMAGE_DATATYPE
|
PROP_INPUT_IMAGE_DATATYPE,
|
||||||
|
PROP_INPUT_OFFSET,
|
||||||
|
PROP_INPUT_SCALE
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
||||||
|
@ -362,6 +364,23 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
GST_TENSOR_TYPE_INT8,
|
GST_TENSOR_TYPE_INT8,
|
||||||
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||||
|
PROP_INPUT_OFFSET,
|
||||||
|
g_param_spec_float ("input-tensor-offset",
|
||||||
|
"Input tensor offset",
|
||||||
|
"offset each tensor value by this value",
|
||||||
|
-G_MAXFLOAT, G_MAXFLOAT, 0.0,
|
||||||
|
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||||
|
PROP_INPUT_SCALE,
|
||||||
|
g_param_spec_float ("input-tensor-scale",
|
||||||
|
"Input tensor scale",
|
||||||
|
"Divide each tensor value by this value",
|
||||||
|
G_MINFLOAT, G_MAXFLOAT, 1.0,
|
||||||
|
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
|
|
||||||
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
||||||
"Filter/Effect/Video",
|
"Filter/Effect/Video",
|
||||||
"Apply neural network to video frames and create tensor output",
|
"Apply neural network to video frames and create tensor output",
|
||||||
|
@ -433,6 +452,12 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||||
onnxClient->setInputImageDatatype ((GstTensorType)
|
onnxClient->setInputImageDatatype ((GstTensorType)
|
||||||
g_value_get_enum (value));
|
g_value_get_enum (value));
|
||||||
break;
|
break;
|
||||||
|
case PROP_INPUT_OFFSET:
|
||||||
|
onnxClient->setInputImageOffset (g_value_get_float (value));
|
||||||
|
break;
|
||||||
|
case PROP_INPUT_SCALE:
|
||||||
|
onnxClient->setInputImageScale (g_value_get_float (value));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||||
break;
|
break;
|
||||||
|
@ -462,6 +487,12 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||||
case PROP_INPUT_IMAGE_DATATYPE:
|
case PROP_INPUT_IMAGE_DATATYPE:
|
||||||
g_value_set_enum (value, onnxClient->getInputImageDatatype ());
|
g_value_set_enum (value, onnxClient->getInputImageDatatype ());
|
||||||
break;
|
break;
|
||||||
|
case PROP_INPUT_OFFSET:
|
||||||
|
g_value_set_float (value, onnxClient->getInputImageOffset ());
|
||||||
|
break;
|
||||||
|
case PROP_INPUT_SCALE:
|
||||||
|
g_value_set_float (value, onnxClient->getInputImageScale ());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||||
break;
|
break;
|
||||||
|
|
Loading…
Reference in a new issue