mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-01-25 00:28:21 +00:00
onnx: Add support for float datatype
This is a bit of a hack solution has I think the correct solution is to expose model caps on sinkpad (eventually sinkpads). Till then I think this is reasonable. - Add a property to onnxinference to set datatype. - Fix internal buffer allocation size based on datatype. - Extract method to remove alphe channel and convert to planar image when requested. Also template the method to support writing to buffers of different datatype. Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5761>
This commit is contained in:
parent
f59219228f
commit
48e3836482
3 changed files with 138 additions and 26 deletions
|
@ -50,6 +50,8 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
dest (nullptr),
|
dest (nullptr),
|
||||||
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
|
m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
|
||||||
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
inputImageFormat (GST_ML_INPUT_IMAGE_FORMAT_HWC),
|
||||||
|
inputDatatype (GST_TENSOR_TYPE_INT8),
|
||||||
|
inputDatatypeSize (sizeof (uint8_t)),
|
||||||
fixedInputImageSize (false) {
|
fixedInputImageSize (false) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +85,33 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
return inputImageFormat;
|
return inputImageFormat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GstOnnxClient::setInputImageDatatype(GstTensorType datatype)
|
||||||
|
{
|
||||||
|
inputDatatype = datatype;
|
||||||
|
switch (inputDatatype) {
|
||||||
|
case GST_TENSOR_TYPE_INT8:
|
||||||
|
inputDatatypeSize = sizeof (uint8_t);
|
||||||
|
break;
|
||||||
|
case GST_TENSOR_TYPE_INT16:
|
||||||
|
inputDatatypeSize = sizeof (uint16_t);
|
||||||
|
break;
|
||||||
|
case GST_TENSOR_TYPE_INT32:
|
||||||
|
inputDatatypeSize = sizeof (uint32_t);
|
||||||
|
break;
|
||||||
|
case GST_TENSOR_TYPE_FLOAT16:
|
||||||
|
inputDatatypeSize = 2;
|
||||||
|
break;
|
||||||
|
case GST_TENSOR_TYPE_FLOAT32:
|
||||||
|
inputDatatypeSize = sizeof (float);
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
GstTensorType GstOnnxClient::getInputImageDatatype(void)
|
||||||
|
{
|
||||||
|
return inputDatatype;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void)
|
std::vector < const char *>GstOnnxClient::genOutputNamesRaw (void)
|
||||||
{
|
{
|
||||||
if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) {
|
if (!outputNames.empty () && outputNamesRaw.size () != outputNames.size ()) {
|
||||||
|
@ -229,9 +258,13 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
|
int32_t newWidth = fixedInputImageSize ? width : vinfo.width;
|
||||||
int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
|
int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
|
||||||
|
|
||||||
|
if (!fixedInputImageSize) {
|
||||||
|
GST_WARNING ("Allocating before knowing model input size");
|
||||||
|
}
|
||||||
|
|
||||||
if (!dest || width * height < newWidth * newHeight) {
|
if (!dest || width * height < newWidth * newHeight) {
|
||||||
delete[]dest;
|
delete[]dest;
|
||||||
dest = new uint8_t[newWidth * newHeight * channels];
|
dest = new uint8_t[newWidth * newHeight * channels * inputDatatypeSize];
|
||||||
}
|
}
|
||||||
width = newWidth;
|
width = newWidth;
|
||||||
height = newHeight;
|
height = newHeight;
|
||||||
|
@ -378,13 +411,54 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
size_t destIndex = 0;
|
|
||||||
uint32_t stride = vinfo.stride[0];
|
uint32_t stride = vinfo.stride[0];
|
||||||
|
const size_t inputTensorSize = width * height * channels * inputDatatypeSize;
|
||||||
|
auto memoryInfo =
|
||||||
|
Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
|
||||||
|
OrtMemType::OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::vector < Ort::Value > inputTensors;
|
||||||
|
|
||||||
|
switch (inputDatatype) {
|
||||||
|
case GST_TENSOR_TYPE_INT8:
|
||||||
|
convert_image_remove_alpha (dest, inputImageFormat , srcPtr,
|
||||||
|
srcSamplesPerPixel, stride);
|
||||||
|
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (
|
||||||
|
memoryInfo, dest, inputTensorSize, inputDims.data (),
|
||||||
|
inputDims.size ()));
|
||||||
|
break;
|
||||||
|
case GST_TENSOR_TYPE_FLOAT32: {
|
||||||
|
convert_image_remove_alpha ((float*)dest, inputImageFormat , srcPtr,
|
||||||
|
srcSamplesPerPixel, stride);
|
||||||
|
inputTensors.push_back (Ort::Value::CreateTensor < float > (
|
||||||
|
memoryInfo, (float*)dest, inputTensorSize, inputDims.data (),
|
||||||
|
inputDims.size ()));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector < const char *>inputNames { inputName.get () };
|
||||||
|
modelOutput = session->Run (Ort::RunOptions {nullptr},
|
||||||
|
inputNames.data (),
|
||||||
|
inputTensors.data (), 1, outputNamesRaw.data (),
|
||||||
|
outputNamesRaw.size ());
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template < typename T>
|
||||||
|
void GstOnnxClient::convert_image_remove_alpha (T *dst,
|
||||||
|
GstMlInputImageFormat hwc, uint8_t **srcPtr, uint32_t srcSamplesPerPixel,
|
||||||
|
uint32_t stride) {
|
||||||
|
size_t destIndex = 0;
|
||||||
|
|
||||||
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
if (inputImageFormat == GST_ML_INPUT_IMAGE_FORMAT_HWC) {
|
||||||
for (int32_t j = 0; j < height; ++j) {
|
for (int32_t j = 0; j < height; ++j) {
|
||||||
for (int32_t i = 0; i < width; ++i) {
|
for (int32_t i = 0; i < width; ++i) {
|
||||||
for (int32_t k = 0; k < channels; ++k) {
|
for (int32_t k = 0; k < channels; ++k) {
|
||||||
dest[destIndex++] = *srcPtr[k];
|
dst[destIndex++] = (T)*srcPtr[k];
|
||||||
srcPtr[k] += srcSamplesPerPixel;
|
srcPtr[k] += srcSamplesPerPixel;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -394,11 +468,11 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
size_t frameSize = width * height;
|
size_t frameSize = width * height;
|
||||||
uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
|
T *destPtr[3] = { dst, dst + frameSize, dst + 2 * frameSize };
|
||||||
for (int32_t j = 0; j < height; ++j) {
|
for (int32_t j = 0; j < height; ++j) {
|
||||||
for (int32_t i = 0; i < width; ++i) {
|
for (int32_t i = 0; i < width; ++i) {
|
||||||
for (int32_t k = 0; k < channels; ++k) {
|
for (int32_t k = 0; k < channels; ++k) {
|
||||||
destPtr[k][destIndex] = *srcPtr[k];
|
destPtr[k][destIndex] = (T)*srcPtr[k];
|
||||||
srcPtr[k] += srcSamplesPerPixel;
|
srcPtr[k] += srcSamplesPerPixel;
|
||||||
}
|
}
|
||||||
destIndex++;
|
destIndex++;
|
||||||
|
@ -408,25 +482,5 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
srcPtr[k] += stride - srcSamplesPerPixel * width;
|
srcPtr[k] += stride - srcSamplesPerPixel * width;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t inputTensorSize = width * height * channels;
|
|
||||||
auto memoryInfo =
|
|
||||||
Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
|
|
||||||
OrtMemType::OrtMemTypeDefault);
|
|
||||||
std::vector < Ort::Value > inputTensors;
|
|
||||||
inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
|
|
||||||
dest, inputTensorSize, inputDims.data (), inputDims.size ()));
|
|
||||||
std::vector < const char *>inputNames
|
|
||||||
{
|
|
||||||
inputName.get ()};
|
|
||||||
|
|
||||||
modelOutput = session->Run (Ort::RunOptions {
|
|
||||||
nullptr},
|
|
||||||
inputNames.data (),
|
|
||||||
inputTensors.data (), 1, outputNamesRaw.data (),
|
|
||||||
outputNamesRaw.size ());
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,6 +55,8 @@ namespace GstOnnxNamespace {
|
||||||
bool hasSession(void);
|
bool hasSession(void);
|
||||||
void setInputImageFormat(GstMlInputImageFormat format);
|
void setInputImageFormat(GstMlInputImageFormat format);
|
||||||
GstMlInputImageFormat getInputImageFormat(void);
|
GstMlInputImageFormat getInputImageFormat(void);
|
||||||
|
void setInputImageDatatype(GstTensorType datatype);
|
||||||
|
GstTensorType getInputImageDatatype(void);
|
||||||
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
|
std::vector < Ort::Value > run (uint8_t * img_data, GstVideoInfo vinfo);
|
||||||
std::vector < const char *> genOutputNamesRaw(void);
|
std::vector < const char *> genOutputNamesRaw(void);
|
||||||
bool isFixedInputImageSize(void);
|
bool isFixedInputImageSize(void);
|
||||||
|
@ -63,6 +65,10 @@ namespace GstOnnxNamespace {
|
||||||
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
|
GstTensorMeta* copy_tensors_to_meta(std::vector < Ort::Value > &outputs,GstBuffer* buffer);
|
||||||
void parseDimensions(GstVideoInfo vinfo);
|
void parseDimensions(GstVideoInfo vinfo);
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
template < typename T>
|
||||||
|
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||||
|
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride);
|
||||||
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
bool doRun(uint8_t * img_data, GstVideoInfo vinfo, std::vector < Ort::Value > &modelOutput);
|
||||||
Ort::Env env;
|
Ort::Env env;
|
||||||
Ort::Session * session;
|
Ort::Session * session;
|
||||||
|
@ -77,6 +83,8 @@ namespace GstOnnxNamespace {
|
||||||
std::vector < Ort::AllocatedStringPtr > outputNames;
|
std::vector < Ort::AllocatedStringPtr > outputNames;
|
||||||
std::vector < GQuark > outputIds;
|
std::vector < GQuark > outputIds;
|
||||||
GstMlInputImageFormat inputImageFormat;
|
GstMlInputImageFormat inputImageFormat;
|
||||||
|
GstTensorType inputDatatype;
|
||||||
|
size_t inputDatatypeSize;
|
||||||
bool fixedInputImageSize;
|
bool fixedInputImageSize;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,7 +122,8 @@ enum
|
||||||
PROP_MODEL_FILE,
|
PROP_MODEL_FILE,
|
||||||
PROP_INPUT_IMAGE_FORMAT,
|
PROP_INPUT_IMAGE_FORMAT,
|
||||||
PROP_OPTIMIZATION_LEVEL,
|
PROP_OPTIMIZATION_LEVEL,
|
||||||
PROP_EXECUTION_PROVIDER
|
PROP_EXECUTION_PROVIDER,
|
||||||
|
PROP_INPUT_IMAGE_DATATYPE
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
#define GST_ONNX_INFERENCE_DEFAULT_EXECUTION_PROVIDER GST_ONNX_EXECUTION_PROVIDER_CPU
|
||||||
|
@ -169,6 +170,9 @@ GType gst_onnx_execution_provider_get_type (void);
|
||||||
GType gst_ml_model_input_image_format_get_type (void);
|
GType gst_ml_model_input_image_format_get_type (void);
|
||||||
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
#define GST_TYPE_ML_MODEL_INPUT_IMAGE_FORMAT (gst_ml_model_input_image_format_get_type ())
|
||||||
|
|
||||||
|
GType gst_onnx_model_input_image_datatype_get_type (void);
|
||||||
|
#define GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE (gst_onnx_model_input_image_datatype_get_type ())
|
||||||
|
|
||||||
GType
|
GType
|
||||||
gst_onnx_optimization_level_get_type (void)
|
gst_onnx_optimization_level_get_type (void)
|
||||||
{
|
{
|
||||||
|
@ -247,6 +251,28 @@ gst_ml_model_input_image_format_get_type (void)
|
||||||
return ml_model_input_image_format;
|
return ml_model_input_image_format;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GType
|
||||||
|
gst_onnx_model_input_image_datatype_get_type (void)
|
||||||
|
{
|
||||||
|
static GType model_input_image_datatype = 0;
|
||||||
|
|
||||||
|
if (g_once_init_enter (&model_input_image_datatype)) {
|
||||||
|
static GEnumValue model_input_image_datatype_types[] = {
|
||||||
|
{GST_TENSOR_TYPE_INT8, "8 Bits integer", "int8"},
|
||||||
|
{GST_TENSOR_TYPE_FLOAT32, "32 Bits floating points", "float"},
|
||||||
|
{0, NULL, NULL},
|
||||||
|
};
|
||||||
|
|
||||||
|
GType temp = g_enum_register_static ("GstTensorType",
|
||||||
|
model_input_image_datatype_types);
|
||||||
|
|
||||||
|
g_once_init_leave (&model_input_image_datatype, temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return model_input_image_datatype;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
{
|
{
|
||||||
|
@ -320,6 +346,22 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
GST_ONNX_EXECUTION_PROVIDER_CPU, (GParamFlags)
|
||||||
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
|
/**
|
||||||
|
* GstOnnxInference:input-image-datatype
|
||||||
|
*
|
||||||
|
* Temporary hack, this should be discovered from the model and exposed
|
||||||
|
* on sinkpad caps based on model contrains.
|
||||||
|
*/
|
||||||
|
|
||||||
|
g_object_class_install_property (G_OBJECT_CLASS (klass),
|
||||||
|
PROP_INPUT_IMAGE_DATATYPE,
|
||||||
|
g_param_spec_enum ("input-image-datatype",
|
||||||
|
"Inference input image datatype",
|
||||||
|
"Datatype that will be used as an input for the inference",
|
||||||
|
GST_TYPE_ONNX_MODEL_INPUT_IMAGE_DATATYPE,
|
||||||
|
GST_TENSOR_TYPE_INT8,
|
||||||
|
(GParamFlags)(G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS)));
|
||||||
|
|
||||||
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
gst_element_class_set_static_metadata (element_class, "onnxinference",
|
||||||
"Filter/Effect/Video",
|
"Filter/Effect/Video",
|
||||||
"Apply neural network to video frames and create tensor output",
|
"Apply neural network to video frames and create tensor output",
|
||||||
|
@ -387,6 +429,10 @@ gst_onnx_inference_set_property (GObject * object, guint prop_id,
|
||||||
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
|
onnxClient->setInputImageFormat ((GstMlInputImageFormat)
|
||||||
g_value_get_enum (value));
|
g_value_get_enum (value));
|
||||||
break;
|
break;
|
||||||
|
case PROP_INPUT_IMAGE_DATATYPE:
|
||||||
|
onnxClient->setInputImageDatatype ((GstTensorType)
|
||||||
|
g_value_get_enum (value));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||||
break;
|
break;
|
||||||
|
@ -413,6 +459,9 @@ gst_onnx_inference_get_property (GObject * object, guint prop_id,
|
||||||
case PROP_INPUT_IMAGE_FORMAT:
|
case PROP_INPUT_IMAGE_FORMAT:
|
||||||
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
g_value_set_enum (value, onnxClient->getInputImageFormat ());
|
||||||
break;
|
break;
|
||||||
|
case PROP_INPUT_IMAGE_DATATYPE:
|
||||||
|
g_value_set_enum (value, onnxClient->getInputImageDatatype ());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
|
||||||
break;
|
break;
|
||||||
|
@ -547,6 +596,7 @@ gst_onnx_inference_process (GstBaseTransform * trans, GstBuffer * buf)
|
||||||
auto meta = client->copy_tensors_to_meta (outputs, buf);
|
auto meta = client->copy_tensors_to_meta (outputs, buf);
|
||||||
if (!meta)
|
if (!meta)
|
||||||
return FALSE;
|
return FALSE;
|
||||||
|
GST_TRACE_OBJECT (trans, "Num tensors:%d", meta->num_tensors);
|
||||||
meta->batch_size = 1;
|
meta->batch_size = 1;
|
||||||
}
|
}
|
||||||
catch (Ort::Exception & ortex) {
|
catch (Ort::Exception & ortex) {
|
||||||
|
|
Loading…
Reference in a new issue