diff --git a/validate/tools/launcher/apps/gst-validate.py b/validate/tools/launcher/apps/gst-validate.py index 0603edc3d6..4e89519998 100644 --- a/validate/tools/launcher/apps/gst-validate.py +++ b/validate/tools/launcher/apps/gst-validate.py @@ -29,11 +29,45 @@ from utils import MediaFormatCombination, get_profile,\ path2url, DEFAULT_TIMEOUT, which, GST_SECOND, Result, \ compare_rendered_with_original, Protocols -def is_image(media_xml): - for stream in media_xml.findall("streams")[0].findall("stream"): - if stream.attrib["type"] == "image": - return True - return False +class MediaDescriptor(object): + def __init__(self, xml_path): + self.media_xml = ET.parse(xml_path).getroot() + + # Sanity checks + self.media_xml.attrib["duration"] + self.media_xml.attrib["seekable"] + + def get_caps(self): + return self.media_xml.findall("streams")[0].attrib["caps"] + + def get_uri(self): + return self.media_xml.attrib["uri"] + + def get_duration(self): + return self.media_xml.attrib["duration"] + + def set_protocol(self, protocol): + self.media_xml.attrib["protocol"] = protocol + + def get_protocol(self): + return self.media_xml.attrib["protocol"] + + def is_seekable(self): + return self.media_xml.attrib["seekable"] + + def is_image(self): + for stream in self.media_xml.findall("streams")[0].findall("stream"): + if stream.attrib["type"] == "image": + return True + return False + + def num_audio_tracks(media_xml): + naudio = 0 + for stream in media_xml.findall("streams")[0].findall("stream"): + if stream.attrib["type"] == "audio": + naudio += 1 + + return naudio class PipelineDescriptor(object): def __init__(self, name, pipeline): @@ -138,9 +172,9 @@ G_V_BLACKLISTED_TESTS = \ class GstValidateLaunchTest(GstValidateTest): def __init__(self, classname, options, reporter, pipeline_desc, - timeout=DEFAULT_TIMEOUT, scenario=None, media_xml=None): + timeout=DEFAULT_TIMEOUT, scenario=None, media_descriptor=None): try: - timeout = G_V_PROTOCOL_TIMEOUTS[media_xml.attrib["protocol"]] + timeout = G_V_PROTOCOL_TIMEOUTS[media_descriptor.get_protocol()] except KeyError: pass @@ -150,7 +184,7 @@ class GstValidateLaunchTest(GstValidateTest): timeout=timeout) self.pipeline_desc = pipeline_desc - self.media_xml = media_xml + self.media_descriptor = media_descriptor def build_arguments(self): GstValidateTest.build_arguments(self) @@ -162,7 +196,7 @@ class GstValidateLaunchTest(GstValidateTest): if sent_eos is not None: t = time.time() if ((t - sent_eos)) > 30: - if self.media_xml.attrib["protocol"] == Protocols.HLS: + if self.media_descriptor.get_protocol() == Protocols.HLS: self.set_result(Result.PASSED, """Got no EOS 30 seconds after sending EOS, in HLS known and tolerated issue: @@ -175,13 +209,13 @@ class GstValidateLaunchTest(GstValidateTest): class GstValidateMediaCheckTest(Test): - def __init__(self, classname, options, reporter, media_xml, uri, minfo_path, + def __init__(self, classname, options, reporter, media_descriptor, uri, minfo_path, timeout=DEFAULT_TIMEOUT): super(GstValidateMediaCheckTest, self).__init__(G_V_DISCOVERER_COMMAND, classname, options, reporter, timeout=timeout) self._uri = uri - self.media_xml = media_xml + self.media_descriptor = media_descriptor self._media_info_path = minfo_path def build_arguments(self): @@ -192,12 +226,12 @@ class GstValidateMediaCheckTest(Test): class GstValidateTranscodingTest(GstValidateTest): _scenarios = ScenarioManager() def __init__(self, classname, options, reporter, - combination, uri, media_xml, timeout=DEFAULT_TIMEOUT, + combination, uri, media_descriptor, timeout=DEFAULT_TIMEOUT, scenario_name="play_15s"): Loggable.__init__(self) - file_dur = long(media_xml.attrib["duration"]) / GST_SECOND + file_dur = long(media_descriptor.get_duration()) / GST_SECOND if file_dur < 30: self.debug("%s is short (%ds< 30 secs) playing it all" % (uri, file_dur)) scenario = None @@ -205,7 +239,7 @@ class GstValidateTranscodingTest(GstValidateTest): self.debug("%s is long (%ds > 30 secs) playing it all" % (uri, file_dur)) scenario = self._scenarios.get_scenario(scenario_name) try: - timeout = G_V_PROTOCOL_TIMEOUTS[media_xml.attrib["protocol"]] + timeout = G_V_PROTOCOL_TIMEOUTS[media_descriptor.get_protocol()] except KeyError: pass @@ -219,7 +253,7 @@ class GstValidateTranscodingTest(GstValidateTest): options, reporter, scenario=scenario, timeout=timeout, hard_timeout=hard_timeout) - self.media_xml = media_xml + self.media_descriptor = media_descriptor self.uri = uri self.combination = combination self.dest_file = "" @@ -233,7 +267,7 @@ class GstValidateTranscodingTest(GstValidateTest): self.dest_file = path2url(self.dest_file) try: - video_restriction = G_V_PROTOCOL_VIDEO_RESTRICTION_CAPS[self.media_xml.attrib["protocol"]] + video_restriction = G_V_PROTOCOL_VIDEO_RESTRICTION_CAPS[self.media_descriptor.get_protocol()] except KeyError: video_restriction = None @@ -251,7 +285,7 @@ class GstValidateTranscodingTest(GstValidateTest): if sent_eos is not None: t = time.time() if ((t - sent_eos)) > 30: - if self.media_xml.attrib["protocol"] == Protocols.HLS: + if self.media_descriptor.get_protocol() == Protocols.HLS: self.set_result(Result.PASSED, """Got no EOS 30 seconds after sending EOS, in HLS known and tolerated issue: @@ -265,7 +299,7 @@ class GstValidateTranscodingTest(GstValidateTest): def check_results(self): if self.result is Result.PASSED and not self.scenario: - orig_duration = long(self.media_xml.attrib["duration"]) + orig_duration = long(self.media_descriptor.get_duration()) res, msg = compare_rendered_with_original(orig_duration, self.dest_file) self.set_result(res, msg) elif self.message == "": @@ -297,7 +331,7 @@ class GstValidateManager(TestsManager, Loggable): self._add_playback_test(test_pipeline) for uri, mediainfo in self._list_uris(): - protocol = mediainfo.media_xml.attrib["protocol"] + protocol = mediainfo.media_descriptor.get_protocol() try: timeout = G_V_PROTOCOL_TIMEOUTS[protocol] except KeyError: @@ -308,7 +342,7 @@ class GstValidateManager(TestsManager, Loggable): self.add_test(GstValidateMediaCheckTest(classname, self.options, self.reporter, - mediainfo.media_xml, + mediainfo.media_descriptor, uri, mediainfo.path, timeout=timeout)) @@ -316,37 +350,36 @@ class GstValidateManager(TestsManager, Loggable): for uri, mediainfo in self._list_uris(): - if is_image(mediainfo.media_xml): + if mediainfo.media_descriptor.is_image(): continue for comb in G_V_ENCODING_TARGET_COMBINATIONS: - classname = "validate.%s.transcode.to_%s.%s" % (mediainfo.media_xml.attrib["protocol"], + classname = "validate.%s.transcode.to_%s.%s" % (mediainfo.media_descriptor.get_protocol(), str(comb).replace(' ', '_'), os.path.basename(uri).replace(".", "_")) self.add_test(GstValidateTranscodingTest(classname, self.options, self.reporter, comb, uri, - mediainfo.media_xml)) + mediainfo.media_descriptor)) return self.tests def _check_discovering_info(self, media_info, uri=None): self.debug("Checking %s", media_info) - media_xml = ET.parse(media_info).getroot() + media_descriptor = MediaDescriptor(media_info) try: # Just testing that the vairous mandatory infos are present - caps = media_xml.findall("streams")[0].attrib["caps"] - media_xml.attrib["duration"] - media_xml.attrib["seekable"] + caps = media_descriptor.get_caps() if uri is None: - uri = media_xml.attrib["uri"] - media_xml.attrib["protocol"] = urlparse.urlparse(uri).scheme + uri = media_descriptor.get_uri() + + media_descriptor.set_protocol(urlparse.urlparse(uri).scheme) for caps2, prot in G_V_CAPS_TO_PROTOCOL: if caps2 == caps: - media_xml.attrib["protocol"] = prot + media_descriptor.set_protocol(prot) break self._uris.append((uri, NamedDic({"path": media_info, - "media_xml": media_xml}))) + "media_descriptor": media_descriptor}))) except ConfigParser.NoOptionError as e: self.debug("Exception: %s for %s", e, media_info) @@ -405,14 +438,19 @@ class GstValidateManager(TestsManager, Loggable): def _add_playback_test(self, pipe_descriptor): if pipe_descriptor.needs_uri(): for uri, minfo in self._list_uris(): - protocol = minfo.media_xml.attrib["protocol"] - for scenario_name in G_V_SCENARIOS[protocol]: - scenario = self._scenarios.get_scenario(scenario_name) + protocol = minfo.media_descriptor.get_protocol() + if self._run_defaults: + scenarios = [self._scenarios.get_scenario(scenario_name) + for scenario_name in G_V_SCENARIOS[protocol]] + else: + scenarios = self._scenarios.get_scenario(None) + + for scenario in scenarios: npipe = pipe_descriptor.get_pipeline(self.options, protocol, scenario, uri) - if not minfo.media_xml.attrib["seekable"] or is_image(minfo.media_xml): + if not minfo.media_descriptor.is_seekable() or minfo.media_descriptor.is_image(): self.debug("Do not run %s as %s does not support seeking", scenario, uri) continue @@ -427,7 +465,7 @@ class GstValidateManager(TestsManager, Loggable): self.reporter, npipe, scenario=scenario, - media_xml=minfo.media_xml) + media_descriptor=minfo.media_descriptor) ) else: self.add_test(GstValidateLaunchTest(self._get_fname(scenario, "testing"), @@ -439,8 +477,8 @@ class GstValidateManager(TestsManager, Loggable): def needs_http_server(self): for test in self.list_tests(): if self._is_test_wanted(test): - protocol = test.media_xml.attrib["protocol"] - uri = test.media_xml.attrib["uri"] + protocol = test.media_descriptor.get_protocol() + uri = test.media_descriptor.get_uri() if protocol == Protocols.HTTP and \ "127.0.0.1:%s" % (self.options.http_server_port) in uri: