onnx: Use the element pointer for debug message

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
Author: Olivier Crête, 2024-01-04 15:07:10 -05:00 (committed by GStreamer Marge Bot)
parent 54b361c554
commit 83c2d30438
3 changed files with 24 additions and 16 deletions
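
The change itself is mechanical: each GST_DEBUG/GST_WARNING/GST_ERROR call in the ONNX client gains a GstElement pointer and becomes the corresponding _OBJECT variant. A minimal sketch of the difference, for a hypothetical "myfilter" element (the names below are illustrative, not part of the patch):

#include <gst/gst.h>

GST_DEBUG_CATEGORY_STATIC (my_filter_debug);
#define GST_CAT_DEFAULT my_filter_debug

/* Assumes GST_DEBUG_CATEGORY_INIT (my_filter_debug, "myfilter", 0, "example")
 * has run, e.g. from plugin_init(), before these macros are used. */
static void
my_filter_process (GstElement * element)
{
  /* Category-only logging: the line cannot be tied to an instance. */
  GST_DEBUG ("processing a buffer");

  /* Object logging: the line is attributed to the element and printed
   * with its name, e.g. <myfilter0>, so instances can be told apart
   * and filtered individually. */
  GST_DEBUG_OBJECT (element, "processing a buffer");
}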

gstonnxclient.cpp

@@ -24,6 +24,8 @@
 #include <cpu_provider_factory.h>
 #include <sstream>
 
+#define GST_CAT_DEFAULT onnx_inference_debug
+
 namespace GstOnnxNamespace
 {
   template < typename T >
@@ -43,7 +45,8 @@ namespace GstOnnxNamespace
     return os;
   }
 
-GstOnnxClient::GstOnnxClient ():session (nullptr),
+GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
+    session (nullptr),
     width (0),
     height (0),
     channels (0),
@@ -54,7 +57,8 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
     inputDatatypeSize (sizeof (uint8_t)),
     fixedInputImageSize (false),
     inputTensorOffset (0.0),
-    inputTensorScale (1.0) {
+    inputTensorScale (1.0)
+{
 }
 
 GstOnnxClient::~GstOnnxClient () {
@@ -224,7 +228,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
   }
 
   fixedInputImageSize = width > 0 && height > 0;
-  GST_DEBUG ("Number of Output Nodes: %d",
+  GST_DEBUG_OBJECT (debug_parent, "Number of Output Nodes: %d",
       (gint) session->GetOutputCount ());
 
   ONNXTensorElementDataType elementType =
@@ -238,17 +242,18 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
       setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
      break;
    default:
-      GST_ERROR ("Only input tensors of type int8 and float are supported");
+      GST_ERROR_OBJECT (debug_parent,
+          "Only input tensors of type int8 and float are supported");
      return false;
  }
 
  Ort::AllocatorWithDefaultOptions allocator;
  auto input_name = session->GetInputNameAllocated (0, allocator);
-  GST_DEBUG ("Input name: %s", input_name.get ());
+  GST_DEBUG_OBJECT (debug_parent, "Input name: %s", input_name.get ());
 
  for (size_t i = 0; i < session->GetOutputCount (); ++i) {
    auto output_name = session->GetOutputNameAllocated (i, allocator);
-    GST_DEBUG ("Output name %lu:%s", i, output_name.get ());
+    GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i, output_name.get ());
    outputNames.push_back (std::move (output_name));
  }
  genOutputNamesRaw ();
@@ -261,7 +266,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
    if (status) {
      // Handle the error case
      const char *errorString = Ort::GetApi ().GetErrorMessage (status);
-      GST_WARNING ("Failed to get allocator: %s", errorString);
+      GST_WARNING_OBJECT (debug_parent, "Failed to get allocator: %s", errorString);
 
      // Clean up the error status
      Ort::GetApi ().ReleaseStatus (status);
@@ -276,14 +281,14 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
        GQuark quark = g_quark_from_string (res.get ());
        outputIds.push_back (quark);
      } else {
-        GST_ERROR ("Failed to look up id for key %s", name);
+        GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
        return false;
      }
    }
  }
  catch (Ort::Exception & ortex) {
-    GST_ERROR ("%s", ortex.what ());
+    GST_ERROR_OBJECT (debug_parent, "%s", ortex.what ());
    return false;
  }
@@ -296,7 +301,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
  int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
 
  if (!fixedInputImageSize) {
-    GST_WARNING ("Allocating before knowing model input size");
+    GST_WARNING_OBJECT (debug_parent, "Allocating before knowing model input size");
  }
 
  if (!dest || width * height < newWidth * newHeight) {
@@ -356,7 +361,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
          buffer_size);
      tensor->type = GST_TENSOR_TYPE_INT32;
    } else {
-      GST_ERROR ("Output tensor is not FLOAT32 or INT32, not supported");
+      GST_ERROR_OBJECT (debug_parent, "Output tensor is not FLOAT32 or INT32, not supported");
      gst_buffer_remove_meta (buffer, (GstMeta*) tmeta);
      return NULL;
    }
@@ -396,7 +401,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
  std::ostringstream buffer;
  buffer << inputDims;
-  GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());
+  GST_DEBUG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
 
  // copy video frame
  uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };

gstonnxclient.h

@@ -28,6 +28,7 @@
 #include "gstml.h"
 #include "tensor/gsttensormeta.h"
 
+GST_DEBUG_CATEGORY_EXTERN (onnx_inference_debug);
 
 typedef enum
 {
@@ -48,7 +49,7 @@ namespace GstOnnxNamespace {
   class GstOnnxClient {
   public:
-    GstOnnxClient(void);
+    GstOnnxClient(GstElement *debug_parent);
     ~GstOnnxClient(void);
     bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
         GstOnnxExecutionProvider provider);
@@ -69,7 +70,8 @@ namespace GstOnnxNamespace {
     void parseDimensions(GstVideoInfo vinfo);
 
   private:
-    void setInputImageDatatype(GstTensorType datatype);
+    GstElement *debug_parent;
+    void setInputImageDatatype (GstTensorType datatype);
     template < typename T>
     void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
         uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);
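
The header change follows the usual pattern for a plain C++ class that is not itself a GObject but wants element-scoped logging: it keeps a borrowed pointer to the element that created it and hands that pointer to the _OBJECT logging macros. A minimal sketch of that pattern, assuming (as the patch does) that the element outlives the client object; the class and method names here are illustrative, not taken from the tree:

#include <gst/gst.h>

GST_DEBUG_CATEGORY_EXTERN (onnx_inference_debug);
#define GST_CAT_DEFAULT onnx_inference_debug

class InferenceHelper {
public:
  explicit InferenceHelper (GstElement * debug_parent)
      : debug_parent (debug_parent) {}

  void prepare () {
    /* Attributed to the owning element, not to this helper. */
    GST_DEBUG_OBJECT (debug_parent, "preparing the session");
  }

private:
  GstElement *debug_parent;   /* borrowed, not ref'd */
};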

gstonnxinference.cpp

@@ -86,7 +86,8 @@ struct _GstOnnxInference
   GstVideoInfo video_info;
 };
 
-GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug);
+GST_DEBUG_CATEGORY (onnx_inference_debug);
 #define GST_CAT_DEFAULT onnx_inference_debug
 
 #define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
 
 GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
@@ -335,7 +336,7 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
 static void
 gst_onnx_inference_init (GstOnnxInference * self)
 {
-  self->onnx_client = new GstOnnxNamespace::GstOnnxClient ();
+  self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
   self->onnx_disabled = TRUE;
 }
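
Taken together, the three diffs also rewire how the debug category is shared so the client's messages land in the element's category: the element source defines it with external linkage, the header exports it, and the client maps its GST_CAT_DEFAULT onto it. A condensed sketch of that wiring:

/* gstonnxinference.cpp: define the category with external linkage
 * (GST_DEBUG_CATEGORY instead of GST_DEBUG_CATEGORY_STATIC). */
GST_DEBUG_CATEGORY (onnx_inference_debug);

/* gstonnxclient.h: declare it for other translation units. */
GST_DEBUG_CATEGORY_EXTERN (onnx_inference_debug);

/* gstonnxclient.cpp: route this file's logging macros to that category. */
#define GST_CAT_DEFAULT onnx_inference_debug

With the element pointer threaded through, the client's log lines are attributed to a concrete instance. Assuming the category is registered under the nickname "onnxinference" (its GST_DEBUG_CATEGORY_INIT call is outside this diff), running with something like GST_DEBUG=onnxinference:6 now prints those messages with an <onnxinference0> object prefix instead of anonymous output.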