From fba15f71c6d80ad6773e29a5c8d7ed723587e8d0 Mon Sep 17 00:00:00 2001
From: Thibault Saunier <tsaunier@gnome.org>
Date: Fri, 25 Apr 2014 11:31:01 +0200
Subject: [PATCH] validate:launcher: Cleanup media descriptor usage

---
 validate/tools/launcher/apps/gst-validate.py | 114 ++++++++++++-------
 1 file changed, 76 insertions(+), 38 deletions(-)

diff --git a/validate/tools/launcher/apps/gst-validate.py b/validate/tools/launcher/apps/gst-validate.py
index 0603edc3d6d..4e895199984 100644
--- a/validate/tools/launcher/apps/gst-validate.py
+++ b/validate/tools/launcher/apps/gst-validate.py
@@ -29,11 +29,45 @@ from utils import MediaFormatCombination, get_profile,\
     path2url, DEFAULT_TIMEOUT, which, GST_SECOND, Result, \
     compare_rendered_with_original, Protocols
 
-def is_image(media_xml):
-    for stream in media_xml.findall("streams")[0].findall("stream"):
-        if stream.attrib["type"] == "image":
-            return True
-    return False
+class MediaDescriptor(object):
+    def __init__(self, xml_path):
+        self.media_xml = ET.parse(xml_path).getroot()
+
+        # Sanity checks
+        self.media_xml.attrib["duration"]
+        self.media_xml.attrib["seekable"]
+
+    def get_caps(self):
+        return self.media_xml.findall("streams")[0].attrib["caps"]
+
+    def get_uri(self):
+        return self.media_xml.attrib["uri"]
+
+    def get_duration(self):
+        return self.media_xml.attrib["duration"]
+
+    def set_protocol(self, protocol):
+        self.media_xml.attrib["protocol"] = protocol
+
+    def get_protocol(self):
+        return self.media_xml.attrib["protocol"]
+
+    def is_seekable(self):
+        return self.media_xml.attrib["seekable"]
+
+    def is_image(self):
+        for stream in self.media_xml.findall("streams")[0].findall("stream"):
+            if stream.attrib["type"] == "image":
+                return True
+        return False
+
+    def num_audio_tracks(media_xml):
+        naudio = 0
+        for stream in media_xml.findall("streams")[0].findall("stream"):
+            if stream.attrib["type"] == "audio":
+                naudio += 1
+
+        return naudio
 
 class PipelineDescriptor(object):
     def __init__(self, name, pipeline):
@@ -138,9 +172,9 @@ G_V_BLACKLISTED_TESTS = \
 
 class GstValidateLaunchTest(GstValidateTest):
     def __init__(self, classname, options, reporter, pipeline_desc,
-                 timeout=DEFAULT_TIMEOUT, scenario=None, media_xml=None):
+                 timeout=DEFAULT_TIMEOUT, scenario=None, media_descriptor=None):
         try:
-            timeout = G_V_PROTOCOL_TIMEOUTS[media_xml.attrib["protocol"]]
+            timeout = G_V_PROTOCOL_TIMEOUTS[media_descriptor.get_protocol()]
         except KeyError:
             pass
 
@@ -150,7 +184,7 @@ class GstValidateLaunchTest(GstValidateTest):
                                               timeout=timeout)
 
         self.pipeline_desc = pipeline_desc
-        self.media_xml = media_xml
+        self.media_descriptor = media_descriptor
 
     def build_arguments(self):
         GstValidateTest.build_arguments(self)
@@ -162,7 +196,7 @@ class GstValidateLaunchTest(GstValidateTest):
             if sent_eos is not None:
                 t = time.time()
                 if ((t - sent_eos)) > 30:
-                    if self.media_xml.attrib["protocol"] == Protocols.HLS:
+                    if self.media_descriptor.get_protocol() == Protocols.HLS:
                         self.set_result(Result.PASSED,
                                         """Got no EOS 30 seconds after sending EOS,
                                         in HLS known and tolerated issue:
@@ -175,13 +209,13 @@ class GstValidateLaunchTest(GstValidateTest):
 
 
 class GstValidateMediaCheckTest(Test):
-    def __init__(self, classname, options, reporter, media_xml, uri, minfo_path,
+    def __init__(self, classname, options, reporter, media_descriptor, uri, minfo_path,
                  timeout=DEFAULT_TIMEOUT):
         super(GstValidateMediaCheckTest, self).__init__(G_V_DISCOVERER_COMMAND, classname,
                                               options, reporter,
                                               timeout=timeout)
         self._uri = uri
-        self.media_xml = media_xml
+        self.media_descriptor = media_descriptor
         self._media_info_path = minfo_path
 
     def build_arguments(self):
@@ -192,12 +226,12 @@ class GstValidateMediaCheckTest(Test):
 class GstValidateTranscodingTest(GstValidateTest):
     _scenarios = ScenarioManager()
     def __init__(self, classname, options, reporter,
-                 combination, uri, media_xml, timeout=DEFAULT_TIMEOUT,
+                 combination, uri, media_descriptor, timeout=DEFAULT_TIMEOUT,
                  scenario_name="play_15s"):
 
         Loggable.__init__(self)
 
-        file_dur = long(media_xml.attrib["duration"]) / GST_SECOND
+        file_dur = long(media_descriptor.get_duration()) / GST_SECOND
         if  file_dur < 30:
             self.debug("%s is short (%ds< 30 secs) playing it all" % (uri, file_dur))
             scenario = None
@@ -205,7 +239,7 @@ class GstValidateTranscodingTest(GstValidateTest):
             self.debug("%s is long (%ds > 30 secs) playing it all" % (uri, file_dur))
             scenario = self._scenarios.get_scenario(scenario_name)
         try:
-            timeout = G_V_PROTOCOL_TIMEOUTS[media_xml.attrib["protocol"]]
+            timeout = G_V_PROTOCOL_TIMEOUTS[media_descriptor.get_protocol()]
         except KeyError:
             pass
 
@@ -219,7 +253,7 @@ class GstValidateTranscodingTest(GstValidateTest):
             options, reporter, scenario=scenario, timeout=timeout,
             hard_timeout=hard_timeout)
 
-        self.media_xml = media_xml
+        self.media_descriptor = media_descriptor
         self.uri = uri
         self.combination = combination
         self.dest_file = ""
@@ -233,7 +267,7 @@ class GstValidateTranscodingTest(GstValidateTest):
             self.dest_file = path2url(self.dest_file)
 
         try:
-            video_restriction = G_V_PROTOCOL_VIDEO_RESTRICTION_CAPS[self.media_xml.attrib["protocol"]]
+            video_restriction = G_V_PROTOCOL_VIDEO_RESTRICTION_CAPS[self.media_descriptor.get_protocol()]
         except KeyError:
             video_restriction = None
 
@@ -251,7 +285,7 @@ class GstValidateTranscodingTest(GstValidateTest):
             if sent_eos is not None:
                 t = time.time()
                 if ((t - sent_eos)) > 30:
-                    if self.media_xml.attrib["protocol"] == Protocols.HLS:
+                    if self.media_descriptor.get_protocol() == Protocols.HLS:
                         self.set_result(Result.PASSED,
                                         """Got no EOS 30 seconds after sending EOS,
                                         in HLS known and tolerated issue:
@@ -265,7 +299,7 @@ class GstValidateTranscodingTest(GstValidateTest):
 
     def check_results(self):
         if self.result is Result.PASSED and not self.scenario:
-            orig_duration = long(self.media_xml.attrib["duration"])
+            orig_duration = long(self.media_descriptor.get_duration())
             res, msg = compare_rendered_with_original(orig_duration, self.dest_file)
             self.set_result(res, msg)
         elif self.message == "":
@@ -297,7 +331,7 @@ class GstValidateManager(TestsManager, Loggable):
             self._add_playback_test(test_pipeline)
 
         for uri, mediainfo in self._list_uris():
-            protocol = mediainfo.media_xml.attrib["protocol"]
+            protocol = mediainfo.media_descriptor.get_protocol()
             try:
                 timeout = G_V_PROTOCOL_TIMEOUTS[protocol]
             except KeyError:
@@ -308,7 +342,7 @@ class GstValidateManager(TestsManager, Loggable):
             self.add_test(GstValidateMediaCheckTest(classname,
                                                     self.options,
                                                     self.reporter,
-                                                    mediainfo.media_xml,
+                                                    mediainfo.media_descriptor,
                                                     uri,
                                                     mediainfo.path,
                                                     timeout=timeout))
@@ -316,37 +350,36 @@ class GstValidateManager(TestsManager, Loggable):
         for uri, mediainfo in self._list_uris():
 
 
-            if is_image(mediainfo.media_xml):
+            if mediainfo.media_descriptor.is_image():
                 continue
             for comb in G_V_ENCODING_TARGET_COMBINATIONS:
-                classname = "validate.%s.transcode.to_%s.%s" % (mediainfo.media_xml.attrib["protocol"],
+                classname = "validate.%s.transcode.to_%s.%s" % (mediainfo.media_descriptor.get_protocol(),
                                                                 str(comb).replace(' ', '_'),
                                                                 os.path.basename(uri).replace(".", "_"))
                 self.add_test(GstValidateTranscodingTest(classname,
                                                          self.options,
                                                          self.reporter,
                                                          comb, uri,
-                                                         mediainfo.media_xml))
+                                                         mediainfo.media_descriptor))
         return self.tests
 
     def _check_discovering_info(self, media_info, uri=None):
         self.debug("Checking %s", media_info)
-        media_xml = ET.parse(media_info).getroot()
+        media_descriptor = MediaDescriptor(media_info)
         try:
             # Just testing that the vairous mandatory infos are present
-            caps = media_xml.findall("streams")[0].attrib["caps"]
-            media_xml.attrib["duration"]
-            media_xml.attrib["seekable"]
+            caps = media_descriptor.get_caps()
             if uri is None:
-                uri = media_xml.attrib["uri"]
-            media_xml.attrib["protocol"] = urlparse.urlparse(uri).scheme
+                uri = media_descriptor.get_uri()
+
+            media_descriptor.set_protocol(urlparse.urlparse(uri).scheme)
             for caps2, prot in G_V_CAPS_TO_PROTOCOL:
                 if caps2 == caps:
-                    media_xml.attrib["protocol"] = prot
+                    media_descriptor.set_protocol(prot)
                     break
             self._uris.append((uri,
                                NamedDic({"path": media_info,
-                                         "media_xml": media_xml})))
+                                         "media_descriptor": media_descriptor})))
         except ConfigParser.NoOptionError as e:
             self.debug("Exception: %s for %s", e, media_info)
 
@@ -405,14 +438,19 @@ class GstValidateManager(TestsManager, Loggable):
     def _add_playback_test(self, pipe_descriptor):
         if pipe_descriptor.needs_uri():
             for uri, minfo in self._list_uris():
-                protocol = minfo.media_xml.attrib["protocol"]
-                for scenario_name in G_V_SCENARIOS[protocol]:
-                    scenario = self._scenarios.get_scenario(scenario_name)
+                protocol = minfo.media_descriptor.get_protocol()
+                if self._run_defaults:
+                    scenarios = [self._scenarios.get_scenario(scenario_name)
+                                 for scenario_name in G_V_SCENARIOS[protocol]]
+                else:
+                    scenarios = self._scenarios.get_scenario(None)
+
+                for scenario in scenarios:
                     npipe = pipe_descriptor.get_pipeline(self.options,
                                                          protocol,
                                                          scenario,
                                                          uri)
-                    if not minfo.media_xml.attrib["seekable"] or is_image(minfo.media_xml):
+                    if not minfo.media_descriptor.is_seekable() or minfo.media_descriptor.is_image():
                         self.debug("Do not run %s as %s does not support seeking",
                                    scenario, uri)
                         continue
@@ -427,7 +465,7 @@ class GstValidateManager(TestsManager, Loggable):
                                                         self.reporter,
                                                         npipe,
                                                         scenario=scenario,
-                                                        media_xml=minfo.media_xml)
+                                                        media_descriptor=minfo.media_descriptor)
                                  )
         else:
             self.add_test(GstValidateLaunchTest(self._get_fname(scenario, "testing"),
@@ -439,8 +477,8 @@ class GstValidateManager(TestsManager, Loggable):
     def needs_http_server(self):
         for test in self.list_tests():
             if self._is_test_wanted(test):
-                protocol = test.media_xml.attrib["protocol"]
-                uri = test.media_xml.attrib["uri"]
+                protocol = test.media_descriptor.get_protocol()
+                uri = test.media_descriptor.get_uri()
 
                 if protocol == Protocols.HTTP and \
                     "127.0.0.1:%s" % (self.options.http_server_port) in uri: