onnxobjectdetector: gracefully handle Ort exceptions rather than dumping core

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/4765>
This commit is contained in:
Aaron Boxer 2023-06-02 13:21:06 -04:00 committed by GStreamer Marge Bot
parent a1c2df830b
commit e624e7c695
2 changed files with 63 additions and 48 deletions

View file

@ -203,6 +203,7 @@ bool GstOnnxClient::createSession (std::string modelFile,
break; break;
}; };
try {
Ort::SessionOptions sessionOptions; Ort::SessionOptions sessionOptions;
// for debugging // for debugging
//sessionOptions.SetIntraOpNumThreads (1); //sessionOptions.SetIntraOpNumThreads (1);
@ -214,6 +215,7 @@ bool GstOnnxClient::createSession (std::string modelFile,
Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
(sessionOptions, 0)); (sessionOptions, 0));
#else #else
GST_ERROR ("ONNX CUDA execution provider not supported");
return false; return false;
#endif #endif
break; break;
@ -221,7 +223,8 @@ bool GstOnnxClient::createSession (std::string modelFile,
break; break;
}; };
session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions); session =
new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
auto inputTypeInfo = session->GetInputTypeInfo (0); auto inputTypeInfo = session->GetInputTypeInfo (0);
std::vector < int64_t > inputDims = std::vector < int64_t > inputDims =
inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
@ -236,16 +239,17 @@ bool GstOnnxClient::createSession (std::string modelFile,
} }
fixedInputImageSize = width > 0 && height > 0; fixedInputImageSize = width > 0 && height > 0;
GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); GST_DEBUG ("Number of Output Nodes: %d",
(gint) session->GetOutputCount ());
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
auto input_name = session->GetInputNameAllocated (0, allocator); auto input_name = session->GetInputNameAllocated (0, allocator);
GST_DEBUG ("Input name: %s", input_name.get()); GST_DEBUG ("Input name: %s", input_name.get ());
for (size_t i = 0; i < session->GetOutputCount (); ++i) { for (size_t i = 0; i < session->GetOutputCount (); ++i) {
auto output_name = session->GetOutputNameAllocated (i, allocator); auto output_name = session->GetOutputNameAllocated (i, allocator);
GST_DEBUG("Output name %lu:%s", i, output_name.get()); GST_DEBUG ("Output name %lu:%s", i, output_name.get ());
outputNames.push_back (std::move(output_name)); outputNames.push_back (std::move (output_name));
auto type_info = session->GetOutputTypeInfo (i); auto type_info = session->GetOutputTypeInfo (i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();
@ -254,6 +258,11 @@ bool GstOnnxClient::createSession (std::string modelFile,
outputNodeInfo[function].type = tensor_info.GetElementType (); outputNodeInfo[function].type = tensor_info.GetElementType ();
} }
} }
}
catch (Ort::Exception & ortex) {
GST_ERROR ("%s", ortex.what ());
return false;
}
return true; return true;
} }

View file

@ -643,9 +643,15 @@ gst_onnx_object_detector_process (GstBaseTransform * trans, GstBuffer * buf)
} }
if (gst_buffer_map (buf, &info, GST_MAP_READ)) { if (gst_buffer_map (buf, &info, GST_MAP_READ)) {
GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans); GstOnnxObjectDetector *self = GST_ONNX_OBJECT_DETECTOR (trans);
auto boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta, std::vector < GstOnnxNamespace::GstMlBoundingBox > boxes;
self->label_file ? self->label_file : "", try {
self->score_threshold); boxes = GST_ONNX_MEMBER (self)->run (info.data, vmeta,
self->label_file ? self->label_file : "", self->score_threshold);
}
catch (Ort::Exception & ortex) {
GST_ERROR_OBJECT (self, "%s", ortex.what ());
return FALSE;
}
for (auto & b:boxes) { for (auto & b:boxes) {
auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf, auto vroi_meta = gst_buffer_add_video_region_of_interest_meta (buf,
GST_ONNX_OBJECT_DETECTOR_META_NAME, GST_ONNX_OBJECT_DETECTOR_META_NAME,