/* * GStreamer gstreamer-onnxclient * Copyright (C) 2021 Collabora Ltd * * gstonnxclient.cpp * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Library General Public * License as published by the Free Software Foundation; either * version 2 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Library General Public License for more details. * * You should have received a copy of the GNU Library General Public * License along with this library; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301, USA. */ #include "gstonnxclient.h" #include #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA #include #endif #include #include #include #include #include #include #include namespace GstOnnxNamespace { template < typename T > std::ostream & operator<< (std::ostream & os, const std::vector < T > &v) { os << "["; for (size_t i = 0; i < v.size (); ++i) { os << v[i]; if (i != v.size () - 1) { os << ", "; } } os << "]"; return os; } GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index (GST_ML_NODE_INDEX_DISABLED), type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { } GstOnnxClient::GstOnnxClient ():session (nullptr), width (0), height (0), channels (0), dest (nullptr), m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU), inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC), fixedInputImageSize (true) { for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i) outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i; } GstOnnxClient::~GstOnnxClient () { delete session; delete[]dest; } Ort::Env & GstOnnxClient::getEnv (void) { static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "GstOnnxNamespace"); return env; } int32_t GstOnnxClient::getWidth (void) { return width; } int32_t GstOnnxClient::getHeight (void) { return height; } bool GstOnnxClient::isFixedInputImageSize (void) { return fixedInputImageSize; } std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType) { switch (nodeType) { case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION: return "detection"; break; case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX: return "bounding box"; break; case GST_ML_OUTPUT_NODE_FUNCTION_SCORE: return "score"; break; case GST_ML_OUTPUT_NODE_FUNCTION_CLASS: return "label"; break; }; return ""; } void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format) { inputImageFormat = format; } GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void) { return inputImageFormat; } std::vector < const char *>GstOnnxClient::getOutputNodeNames (void) { return outputNames; } void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node, gint index) { g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF); outputNodeInfo[node].index = index; if (index != GST_ML_NODE_INDEX_DISABLED) outputNodeIndexToFunction[index] = node; } gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node) { return outputNodeInfo[node].index; } void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node, ONNXTensorElementDataType type) { outputNodeInfo[node].type = type; } ONNXTensorElementDataType GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node) { return outputNodeInfo[node].type; } bool GstOnnxClient::hasSession (void) { return session != nullptr; } bool GstOnnxClient::createSession (std::string modelFile, GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider) { if (session) return true; GraphOptimizationLevel onnx_optim; switch (optim) { case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL: onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL; break; case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC: onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC; break; case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED: onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED; break; case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL: onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL; break; default: onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED; break; }; Ort::SessionOptions sessionOptions; // for debugging //sessionOptions.SetIntraOpNumThreads (1); sessionOptions.SetGraphOptimizationLevel (onnx_optim); m_provider = provider; switch (m_provider) { case GST_ONNX_EXECUTION_PROVIDER_CUDA: #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA (sessionOptions, 0)); #else return false; #endif break; default: break; }; session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions); auto inputTypeInfo = session->GetInputTypeInfo (0); std::vector < int64_t > inputDims = inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { height = inputDims[1]; width = inputDims[2]; channels = inputDims[3]; } else { channels = inputDims[1]; height = inputDims[2]; width = inputDims[3]; } fixedInputImageSize = width > 0 && height > 0; GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ()); Ort::AllocatorWithDefaultOptions allocator; GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator)); for (size_t i = 0; i < session->GetOutputCount (); ++i) { auto output_name = session->GetOutputName (i, allocator); outputNames.push_back (output_name); auto type_info = session->GetOutputTypeInfo (i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo (); if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) { auto function = outputNodeIndexToFunction[i]; outputNodeInfo[function].type = tensor_info.GetElementType (); } } return true; } std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold) { auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS); return (type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ? doRun < float >(img_data, vmeta, labelPath, scoreThreshold) : doRun < int >(img_data, vmeta, labelPath, scoreThreshold); } void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta) { int32_t newWidth = fixedInputImageSize ? width : vmeta->width; int32_t newHeight = fixedInputImageSize ? height : vmeta->height; if (!dest || width * height < newWidth * newHeight) { delete[] dest; dest = new uint8_t[newWidth * newHeight * channels]; } width = newWidth; height = newHeight; } template < typename T > std::vector < GstMlBoundingBox > GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold) { std::vector < GstMlBoundingBox > boundingBoxes; if (!img_data) return boundingBoxes; parseDimensions (vmeta); Ort::AllocatorWithDefaultOptions allocator; auto inputName = session->GetInputName (0, allocator); auto inputTypeInfo = session->GetInputTypeInfo (0); std::vector < int64_t > inputDims = inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape (); inputDims[0] = 1; if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { inputDims[1] = height; inputDims[2] = width; } else { inputDims[2] = height; inputDims[3] = width; } std::ostringstream buffer; buffer << inputDims; GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ()); // copy video frame uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 }; uint32_t srcSamplesPerPixel = 3; switch (vmeta->format) { case GST_VIDEO_FORMAT_RGBA: srcSamplesPerPixel = 4; break; case GST_VIDEO_FORMAT_BGRA: srcSamplesPerPixel = 4; srcPtr[0] = img_data + 2; srcPtr[1] = img_data + 1; srcPtr[2] = img_data + 0; break; case GST_VIDEO_FORMAT_ARGB: srcSamplesPerPixel = 4; srcPtr[0] = img_data + 1; srcPtr[1] = img_data + 2; srcPtr[2] = img_data + 3; break; case GST_VIDEO_FORMAT_ABGR: srcSamplesPerPixel = 4; srcPtr[0] = img_data + 3; srcPtr[1] = img_data + 2; srcPtr[2] = img_data + 1; break; case GST_VIDEO_FORMAT_BGR: srcPtr[0] = img_data + 2; srcPtr[1] = img_data + 1; srcPtr[2] = img_data + 0; break; default: break; } size_t destIndex = 0; uint32_t stride = vmeta->stride[0]; if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) { for (int32_t j = 0; j < height; ++j) { for (int32_t i = 0; i < width; ++i) { for (int32_t k = 0; k < channels; ++k) { dest[destIndex++] = *srcPtr[k]; srcPtr[k] += srcSamplesPerPixel; } } // correct for stride for (uint32_t k = 0; k < 3; ++k) srcPtr[k] += stride - srcSamplesPerPixel * width; } } else { size_t frameSize = width * height; uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize }; for (int32_t j = 0; j < height; ++j) { for (int32_t i = 0; i < width; ++i) { for (int32_t k = 0; k < channels; ++k) { destPtr[k][destIndex] = *srcPtr[k]; srcPtr[k] += srcSamplesPerPixel; } destIndex++; } // correct for stride for (uint32_t k = 0; k < 3; ++k) srcPtr[k] += stride - srcSamplesPerPixel * width; } } const size_t inputTensorSize = width * height * channels; auto memoryInfo = Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); std::vector < Ort::Value > inputTensors; inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo, dest, inputTensorSize, inputDims.data (), inputDims.size ())); std::vector < const char *>inputNames { inputName }; std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr}, inputNames.data (), inputTensors.data (), 1, outputNames.data (), outputNames.size ()); auto numDetections = modelOutput[getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >(); auto bboxes = modelOutput[getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >(); auto scores = modelOutput[getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >(); T *labelIndex = nullptr; if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) != GST_ML_NODE_INDEX_DISABLED) { labelIndex = modelOutput[getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > (); } if (labels.empty () && !labelPath.empty ()) labels = ReadLabels (labelPath); for (int i = 0; i < numDetections[0]; ++i) { if (scores[i] > scoreThreshold) { std::string label = ""; if (labelIndex && !labels.empty ()) label = labels[labelIndex[i] - 1]; auto score = scores[i]; auto y0 = bboxes[i * 4] * height; auto x0 = bboxes[i * 4 + 1] * width; auto bheight = bboxes[i * 4 + 2] * height - y0; auto bwidth = bboxes[i * 4 + 3] * width - x0; boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth, bheight)); } } return boundingBoxes; } std::vector < std::string > GstOnnxClient::ReadLabels (const std::string & labelsFile) { std::vector < std::string > labels; std::string line; std::ifstream fp (labelsFile); while (std::getline (fp, line)) labels.push_back (line); return labels; } }