From ad4fa052e516b76567ec986dcec4ae44dad0bc37 Mon Sep 17 00:00:00 2001 From: Daniel Morin Date: Sun, 19 Jan 2025 15:10:53 -0500 Subject: [PATCH] gst-python: Test for GstAnalyticsRelationMeta iterator Part-of: --- .../gst-python/testsuite/test_analytics.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/subprojects/gst-python/testsuite/test_analytics.py b/subprojects/gst-python/testsuite/test_analytics.py index f65ece65ff..bf83827722 100644 --- a/subprojects/gst-python/testsuite/test_analytics.py +++ b/subprojects/gst-python/testsuite/test_analytics.py @@ -37,6 +37,7 @@ from gi.repository import Gst from gi.repository import GstAnalytics from gi.repository import GstVideo Gst.init(None) +GstAnalytics.init() class TestAnalyticsODMtd(TestCase): @@ -245,3 +246,48 @@ class TestAnalyticsTensorMeta(TestCase): GstAnalytics.TensorDimOrder.ROW_MAJOR, [0, 2, 5]) self.assertIsNotNone(tensor3) + + +class TestAnalyticsRelationMetaIterator(TestCase): + def test(self): + buf = Gst.Buffer() + self.assertIsNotNone(buf) + + rmeta = GstAnalytics.buffer_add_analytics_relation_meta(buf) + self.assertIsNotNone(rmeta) + + mask_buf = Gst.Buffer.new_allocate(None, 100, None) + GstVideo.buffer_add_video_meta(mask_buf, + GstVideo.VideoFrameFlags.NONE, + GstVideo.VideoFormat.GRAY8, 10, 10) + + (_, od_mtd) = rmeta.add_od_mtd(GLib.quark_from_string("od"), 1, 1, 2, 2, 0.1) + (_, cls_mtd) = rmeta.add_one_cls_mtd(0.1, GLib.quark_from_string("cls")) + (_, trk_mtd) = rmeta.add_tracking_mtd(1, 10) + (_, seg_mtd) = rmeta.add_segmentation_mtd(mask_buf, + GstAnalytics.SegmentationType.SEMANTIC, + [7, 4, 2], 0, 0, 7, 13) + + mtds = [ + (od_mtd, GstAnalytics.ODMtd.get_mtd_type()), + (cls_mtd, GstAnalytics.ClsMtd.get_mtd_type()), + (trk_mtd, GstAnalytics.TrackingMtd.get_mtd_type()), + (seg_mtd, GstAnalytics.SegmentationMtd.get_mtd_type()) + ] + + mtds_from_iter = list(rmeta) + + self.assertEqual(len(mtds), len(mtds_from_iter)) + + for e, i in zip(mtds, rmeta): + assert e[0].id == i.id + assert e[0].meta == i.meta + assert e[1] == i.get_mtd_type() + + # Validate that the object is really a ODMtd + location = mtds_from_iter[0].get_location() + self.assertEqual(location[1], 1) + self.assertEqual(location[2], 1) + self.assertEqual(location[3], 2) + self.assertEqual(location[4], 2) + self.assertAlmostEqual(location[5], 0.1, 3)