onnx: Use the element pointer for debug message

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/5885>
This commit is contained in:
Olivier Crête 2024-01-04 15:07:10 -05:00 committed by GStreamer Marge Bot
parent 54b361c554
commit 83c2d30438
3 changed files with 24 additions and 16 deletions

View file

@ -24,6 +24,8 @@
#include <cpu_provider_factory.h> #include <cpu_provider_factory.h>
#include <sstream> #include <sstream>
#define GST_CAT_DEFAULT onnx_inference_debug
namespace GstOnnxNamespace namespace GstOnnxNamespace
{ {
template < typename T > template < typename T >
@ -43,7 +45,8 @@ namespace GstOnnxNamespace
return os; return os;
} }
GstOnnxClient::GstOnnxClient ():session (nullptr), GstOnnxClient::GstOnnxClient (GstElement *debug_parent):debug_parent(debug_parent),
session (nullptr),
width (0), width (0),
height (0), height (0),
channels (0), channels (0),
@ -54,7 +57,8 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
inputDatatypeSize (sizeof (uint8_t)), inputDatatypeSize (sizeof (uint8_t)),
fixedInputImageSize (false), fixedInputImageSize (false),
inputTensorOffset (0.0), inputTensorOffset (0.0),
inputTensorScale (1.0) { inputTensorScale (1.0)
{
} }
GstOnnxClient::~GstOnnxClient () { GstOnnxClient::~GstOnnxClient () {
@ -224,7 +228,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
} }
fixedInputImageSize = width > 0 && height > 0; fixedInputImageSize = width > 0 && height > 0;
GST_DEBUG ("Number of Output Nodes: %d", GST_DEBUG_OBJECT (debug_parent, "Number of Output Nodes: %d",
(gint) session->GetOutputCount ()); (gint) session->GetOutputCount ());
ONNXTensorElementDataType elementType = ONNXTensorElementDataType elementType =
@ -238,17 +242,18 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32); setInputImageDatatype(GST_TENSOR_TYPE_FLOAT32);
break; break;
default: default:
GST_ERROR ("Only input tensors of type int8 and float are supported"); GST_ERROR_OBJECT (debug_parent,
"Only input tensors of type int8 and floatare supported");
return false; return false;
} }
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto input_name = session->GetInputNameAllocated (0, allocator); auto input_name = session->GetInputNameAllocated (0, allocator);
GST_DEBUG ("Input name: %s", input_name.get ()); GST_DEBUG_OBJECT (debug_parent, "Input name: %s", input_name.get ());
for (size_t i = 0; i < session->GetOutputCount (); ++i) { for (size_t i = 0; i < session->GetOutputCount (); ++i) {
auto output_name = session->GetOutputNameAllocated (i, allocator); auto output_name = session->GetOutputNameAllocated (i, allocator);
GST_DEBUG ("Output name %lu:%s", i, output_name.get ()); GST_DEBUG_OBJECT (debug_parent, "Output name %lu:%s", i, output_name.get ());
outputNames.push_back (std::move (output_name)); outputNames.push_back (std::move (output_name));
} }
genOutputNamesRaw (); genOutputNamesRaw ();
@ -261,7 +266,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
if (status) { if (status) {
// Handle the error case // Handle the error case
const char *errorString = Ort::GetApi ().GetErrorMessage (status); const char *errorString = Ort::GetApi ().GetErrorMessage (status);
GST_WARNING ("Failed to get allocator: %s", errorString); GST_WARNING_OBJECT (debug_parent, "Failed to get allocator: %s", errorString);
// Clean up the error status // Clean up the error status
Ort::GetApi ().ReleaseStatus (status); Ort::GetApi ().ReleaseStatus (status);
@ -276,14 +281,14 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
GQuark quark = g_quark_from_string (res.get ()); GQuark quark = g_quark_from_string (res.get ());
outputIds.push_back (quark); outputIds.push_back (quark);
} else { } else {
GST_ERROR ("Failed to look up id for key %s", name); GST_ERROR_OBJECT (debug_parent, "Failed to look up id for key %s", name);
return false; return false;
} }
} }
} }
catch (Ort::Exception & ortex) { catch (Ort::Exception & ortex) {
GST_ERROR ("%s", ortex.what ()); GST_ERROR_OBJECT (debug_parent, "%s", ortex.what ());
return false; return false;
} }
@ -296,7 +301,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
int32_t newHeight = fixedInputImageSize ? height : vinfo.height; int32_t newHeight = fixedInputImageSize ? height : vinfo.height;
if (!fixedInputImageSize) { if (!fixedInputImageSize) {
GST_WARNING ("Allocating before knowing model input size"); GST_WARNING_OBJECT (debug_parent, "Allocating before knowing model input size");
} }
if (!dest || width * height < newWidth * newHeight) { if (!dest || width * height < newWidth * newHeight) {
@ -356,7 +361,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
buffer_size); buffer_size);
tensor->type = GST_TENSOR_TYPE_INT32; tensor->type = GST_TENSOR_TYPE_INT32;
} else { } else {
GST_ERROR ("Output tensor is not FLOAT32 or INT32, not supported"); GST_ERROR_OBJECT (debug_parent, "Output tensor is not FLOAT32 or INT32, not supported");
gst_buffer_remove_meta (buffer, (GstMeta*) tmeta); gst_buffer_remove_meta (buffer, (GstMeta*) tmeta);
return NULL; return NULL;
} }
@ -396,7 +401,7 @@ GstOnnxClient::GstOnnxClient ():session (nullptr),
std::ostringstream buffer; std::ostringstream buffer;
buffer << inputDims; buffer << inputDims;
GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ()); GST_DEBUG_OBJECT (debug_parent, "Input dimensions: %s", buffer.str ().c_str ());
// copy video frame // copy video frame
uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 }; uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };

View file

@ -28,6 +28,7 @@
#include "gstml.h" #include "gstml.h"
#include "tensor/gsttensormeta.h" #include "tensor/gsttensormeta.h"
GST_DEBUG_CATEGORY_EXTERN (onnx_inference_debug);
typedef enum typedef enum
{ {
@ -48,7 +49,7 @@ namespace GstOnnxNamespace {
class GstOnnxClient { class GstOnnxClient {
public: public:
GstOnnxClient(void); GstOnnxClient(GstElement *debug_parent);
~GstOnnxClient(void); ~GstOnnxClient(void);
bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim, bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim,
GstOnnxExecutionProvider provider); GstOnnxExecutionProvider provider);
@ -69,7 +70,8 @@ namespace GstOnnxNamespace {
void parseDimensions(GstVideoInfo vinfo); void parseDimensions(GstVideoInfo vinfo);
private: private:
void setInputImageDatatype(GstTensorType datatype); GstElement *debug_parent;
void setInputImageDatatype (GstTensorType datatype);
template < typename T> template < typename T>
void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc, void convert_image_remove_alpha (T *dest, GstMlInputImageFormat hwc,
uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div); uint8_t **srcPtr, uint32_t srcSamplesPerPixel, uint32_t stride, T offset, T div);

View file

@ -86,7 +86,8 @@ struct _GstOnnxInference
GstVideoInfo video_info; GstVideoInfo video_info;
}; };
GST_DEBUG_CATEGORY_STATIC (onnx_inference_debug); GST_DEBUG_CATEGORY (onnx_inference_debug);
#define GST_CAT_DEFAULT onnx_inference_debug #define GST_CAT_DEFAULT onnx_inference_debug
#define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client)) #define GST_ONNX_CLIENT_MEMBER( self ) ((GstOnnxNamespace::GstOnnxClient *) (self->onnx_client))
GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference", GST_ELEMENT_REGISTER_DEFINE (onnx_inference, "onnxinference",
@ -335,7 +336,7 @@ gst_onnx_inference_class_init (GstOnnxInferenceClass * klass)
static void static void
gst_onnx_inference_init (GstOnnxInference * self) gst_onnx_inference_init (GstOnnxInference * self)
{ {
self->onnx_client = new GstOnnxNamespace::GstOnnxClient (); self->onnx_client = new GstOnnxNamespace::GstOnnxClient (GST_ELEMENT(self));
self->onnx_disabled = TRUE; self->onnx_disabled = TRUE;
} }