mirror of
https://gitlab.freedesktop.org/gstreamer/gstreamer.git
synced 2025-04-26 06:54:49 +00:00
onnx: Use the element pointer for debug message
Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
This commit is contained in:
parent
54b361c554
commit
83c2d30438
3 changed files with 24 additions and 16 deletions
|
@ -24,6 +24,8 @@
|
||||||
#include <cpu_provider_factory.h>
|
#include <cpu_provider_factory.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#define GST_CAT_DEFAULT onnx_inference_debug
|
||||||
|
|
||||||
namespace GstOnnxNamespace
|
namespace GstOnnxNamespace
|
||||||
{
|
{
|
||||||
template < typename T >
|
template < typename T >
|
||||||
|
@ -43,7 +45,8 @@ namespace GstOnnxNamespace
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
GstOnnxClient::GstOnnxClient ():session (nullptr),
|
GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
|
||||||
|
session (nullptr),
|
||||||
width (0),
|
width (0),
|
||||||
height (0),
|
height (0),
|
||||||
channels (0),
|
channels (0),
|
||||||
|
@ -54,7 +57,8 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
inputDatatypeSize (sizeof (uint8_t)),
|
inputDatatypeSize (sizeof (uint8_t)),
|
||||||
fixedInputImageSize (false),
|
fixedInputImageSize (false),
|
||||||
inputTensorOffset (0.0),
|
inputTensorOffset (0.0),
|
||||||
inputTensorScale (1.0) {
|
inputTensorScale (1.0)
|
||||||
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
GstOnnxClient::~GstOnnxClient () {
|
GstOnnxClient::~GstOnnxClient () {
|
||||||
|
@ -224,7 +228,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
}
|
}
|
||||||
|
|
||||||
fixedInputImageSize = width > 0 && height > 0;
|
fixedInputImageSize = width > 0 && height > 0;
|
||||||
GST_DEBUG ("Number of Output Nodes: %d",
|
GST_DEBUG_OBJECT (debug_parent, "Number of Output Nodes: %d",
|
||||||
(gint) session->GetOutputCount ());
|
(gint) session->GetOutputCount ());
|
||||||
|
|
||||||
ONNXTensorElementDataType elementType =
|
ONNXTensorElementDataType elementType =
|
||||||
|
@ -238,17 +242,18 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
|
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GST_ERROR ("Only input tensors of type int8 and float are supported");
|
GST_ERROR_OBJECT (debug_parent,
|
||||||
|
"Only input tensors of type int8 and floatare supported");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
auto input_name = session->GetInputNameAllocated (0, allocator);
|
auto input_name = session->GetInputNameAllocated (0, allocator);
|
||||||
GST_DEBUG ("Input name: %s", input_name.get ());
|
GST_DEBUG_OBJECT (debug_parent, "Input name: %s", input_name.get ());
|
||||||
|
|
||||||
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
|
for (size_t i = 0; i < session->GetOutputCount (); ++i) {
|
||||||
auto output_name = session->GetOutputNameAllocated (i, allocator);
|
auto output_name = session->GetOutputNameAllocated (i, allocator);
|
||||||
GST_DEBUG ("Output name %lu:%s", i, output_name.get ());
|
GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i, output_name.get ());
|
||||||
outputNames.push_back (std::move (output_name));
|
outputNames.push_back (std::move (output_name));
|
||||||
}
|
}
|
||||||
genOutputNamesRaw ();
|
genOutputNamesRaw ();
|
||||||
|
@ -261,7 +266,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
if (status) {
|
if (status) {
|
||||||
// Handle the error case
|
// Handle the error case
|
||||||
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
|
const char *errorString = Ort::GetApi ().GetErrorMessage (status);
|
||||||
GST_WARNING ("Failed to get allocator: %s", errorString);
|
GST_WARNING_OBJECT (debug_parent, "Failed to get allocator: %s", errorString);
|
||||||
|
|
||||||
// Clean up the error status
|
// Clean up the error status
|
||||||
Ort::GetApi ().ReleaseStatus (status);
|
Ort::GetApi ().ReleaseStatus (status);
|
||||||
|
@ -276,14 +281,14 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
GQuark quark = g_quark_from_string (res.get ());
|
GQuark quark = g_quark_from_string (res.get ());
|
||||||
outputIds.push_back (quark);
|
outputIds.push_back (quark);
|
||||||
} else {
|
} else {
|
||||||
GST_ERROR ("Failed to look up id for key %s", name);
|
GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (Ort::Exception & ortex) {
|
catch (Ort::Exception & ortex) {
|
||||||
GST_ERROR ("%s", ortex.what ());
|
GST_ERROR_OBJECT (debug_parent, "%s", ortex.what ());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,7 +301,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
|
int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
|
||||||
|
|
||||||
if (!fixedInputImageSize) {
|
if (!fixedInputImageSize) {
|
||||||
GST_WARNING ("Allocating before knowing model input size");
|
GST_WARNING_OBJECT (debug_parent, "Allocating before knowing model input size");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!dest || width * height < newWidth * newHeight) {
|
if (!dest || width * height < newWidth * newHeight) {
|
||||||
|
@ -356,7 +361,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
buffer_size);
|
buffer_size);
|
||||||
tensor->type = GST_TENSOR_TYPE_INT32;
|
tensor->type = GST_TENSOR_TYPE_INT32;
|
||||||
} else {
|
} else {
|
||||||
GST_ERROR ("Output tensor is not FLOAT32 or INT32, not supported");
|
GST_ERROR_OBJECT (debug_parent, "Output tensor is not FLOAT32 or INT32, not supported");
|
||||||
gst_buffer_remove_meta (buffer, (GstMeta*) tmeta);
|
gst_buffer_remove_meta (buffer, (GstMeta*) tmeta);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
@ -396,7 +401,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
|
||||||
|
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << inputDims;
|
buffer << inputDims;
|
||||||
GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());
|
GST_DEBUG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
|
||||||
|
|
||||||
// copy video frame
|
// copy video frame
|
||||||
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
|
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include "gstml.h"
|
#include "gstml.h"
|
||||||
#include "tensor/gsttensormeta.h"
|
#include "tensor/gsttensormeta.h"
|
||||||
|
|
||||||
|
GST_DEBUG_CATEGORY_EXTERN (onnx_inference_debug);
|
||||||
|
|
||||||
typedef enum
|
typedef enum
|
||||||
{
|
{
|
||||||
|
@ -48,7 +49,7 @@ namespace GstOnnxNamespace {
|
||||||
|
|
||||||
class GstOnnxClient {
|
class GstOnnxClient {
|
||||||
public:
|
public:
|
||||||
GstOnnxClient(void);
|
GstOnnxClient(GstElement *debug_parent);
|
||||||
~GstOnnxClient(void);
|
~GstOnnxClient(void);
|
||||||
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
|
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
|
||||||
GstOnnxExecutionProvider provider);
|
GstOnnxExecutionProvider provider);
|
||||||
|
@ -69,7 +70,8 @@ namespace GstOnnxNamespace {
|
||||||
void parseDimensions(GstVideoInfo vinfo);
|
void parseDimensions(GstVideoInfo vinfo);
|
||||||
private:
|
private:
|
||||||
|
|
||||||
void setInputImageDatatype(GstTensorType datatype);
|
GstElement *debug_parent;
|
||||||
|
void setInputImageDatatype (GstTensorType datatype);
|
||||||
template < typename T>
|
template < typename T>
|
||||||
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
|
||||||
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
|
||||||
|
|
|
@ -86,7 +86,8 @@ struct _GstOnnxInference
|
||||||
GstVideoInfo video_info;
|
GstVideoInfo video_info;
|
||||||
};
|
};
|
||||||
|
|
||||||
GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug);
|
GST_DEBUG_CATEGORY (onnx_inference_debug);
|
||||||
|
|
||||||
#define GST_CAT_DEFAULT onnx_inference_debug
|
#define GST_CAT_DEFAULT onnx_inference_debug
|
||||||
#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
|
#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
|
||||||
GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
|
GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
|
||||||
|
@ -335,7 +336,7 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
|
||||||
static void
|
static void
|
||||||
gst_onnx_inference_init (GstOnnxInference * self)
|
gst_onnx_inference_init (GstOnnxInference * self)
|
||||||
{
|
{
|
||||||
self->onnx_client = new GstOnnxNamespace::GstOnnxClient ();
|
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
|
||||||
self->onnx_disabled = TRUE;
|
self->onnx_disabled = TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue