validate: Use a single TCPServer for subprocess communication

Instead of creating a separate TCPServer for each test, just create
one which handles all connections in a threaded fashion.

Shaves off ~500ms per test

https://bugzilla.gnome.org/show_bug.cgi?id=791159
This commit is contained in:
Edward Hervey 2017-12-03 10:42:49 +01:00 committed by Edward Hervey
parent 58e62f651c
commit 92285ef261
2 changed files with 62 additions and 37 deletions

View file

@ -452,7 +452,7 @@ done:
void void
gst_validate_report_init (void) gst_validate_report_init (void)
{ {
const gchar *var, *file_env, *server_env; const gchar *var, *file_env, *server_env, *uuid;
const GDebugKey keys[] = { const GDebugKey keys[] = {
{"fatal_criticals", GST_VALIDATE_FATAL_CRITICALS}, {"fatal_criticals", GST_VALIDATE_FATAL_CRITICALS},
{"fatal_warnings", GST_VALIDATE_FATAL_WARNINGS}, {"fatal_warnings", GST_VALIDATE_FATAL_WARNINGS},
@ -481,7 +481,11 @@ gst_validate_report_init (void)
} }
server_env = g_getenv ("GST_VALIDATE_SERVER"); server_env = g_getenv ("GST_VALIDATE_SERVER");
if (server_env) { uuid = g_getenv ("GST_VALIDATE_UUID");
if (server_env && !uuid) {
GST_ERROR ("No GST_VALIDATE_UUID specified !");
} else if (server_env) {
GstUri *server_uri = gst_uri_from_string (server_env); GstUri *server_uri = gst_uri_from_string (server_env);
if (server_uri && !g_strcmp0 (gst_uri_get_scheme (server_uri), "tcp")) { if (server_uri && !g_strcmp0 (gst_uri_get_scheme (server_uri), "tcp")) {
@ -502,6 +506,8 @@ gst_validate_report_init (void)
g_io_stream_get_output_stream (G_IO_STREAM (server_connection)); g_io_stream_get_output_stream (G_IO_STREAM (server_connection));
jbuilder = json_builder_new (); jbuilder = json_builder_new ();
json_builder_begin_object (jbuilder); json_builder_begin_object (jbuilder);
json_builder_set_member_name (jbuilder, "uuid");
json_builder_add_string_value (jbuilder, uuid);
json_builder_set_member_name (jbuilder, "started"); json_builder_set_member_name (jbuilder, "started");
json_builder_add_boolean_value (jbuilder, TRUE); json_builder_add_boolean_value (jbuilder, TRUE);
json_builder_end_object (jbuilder); json_builder_end_object (jbuilder);

View file

@ -36,6 +36,7 @@ import queue
import configparser import configparser
import xml import xml
import random import random
import uuid
from . import reporters from . import reporters
from . import loggable from . import loggable
@ -95,6 +96,7 @@ class Test(Loggable):
self.queue = None self.queue = None
self.duration = duration self.duration = duration
self.stack_trace = None self.stack_trace = None
self._uuid = None
if expected_failures is None: if expected_failures is None:
self.expected_failures = [] self.expected_failures = []
elif not isinstance(expected_failures, list): elif not isinstance(expected_failures, list):
@ -208,6 +210,11 @@ class Test(Loggable):
def get_name(self): def get_name(self):
return self.classname.split('.')[-1] return self.classname.split('.')[-1]
def get_uuid(self):
if self._uuid is None:
self._uuid = self.classname + str(uuid.uuid4())
return self._uuid
def add_arguments(self, *args): def add_arguments(self, *args):
self.command += args self.command += args
@ -514,10 +521,13 @@ class Test(Loggable):
return self.result return self.result
class GstValidateTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
class GstValidateListener(socketserver.BaseRequestHandler): class GstValidateListener(socketserver.BaseRequestHandler):
def handle(self): def handle(self):
"""Implements BaseRequestHandler handle method""" """Implements BaseRequestHandler handle method"""
test = None
while True: while True:
raw_len = self.request.recv(4) raw_len = self.request.recv(4)
if raw_len == b'': if raw_len == b'':
@ -528,7 +538,20 @@ class GstValidateListener(socketserver.BaseRequestHandler):
return return
obj = json.loads(msg) obj = json.loads(msg)
test = getattr(self.server, "test")
if test is None:
# First message must contain the uuid
uuid = obj.get("uuid", None)
if uuid is None:
return
# Find test from launcher
for t in self.server.launcher.tests:
if uuid == t.get_uuid():
test = t
break
if test is None:
self.server.launcher.error("Could not find test for UUID %s" % uuid)
return
obj_type = obj.get("type", '') obj_type = obj.get("type", '')
if obj_type == 'position': if obj_type == 'position':
@ -617,16 +640,8 @@ class GstValidateTest(Test):
else: else:
self.scenario = scenario self.scenario = scenario
def stop_server(self):
if self.server:
self.server.shutdown()
self.server_thread.join()
self.server.server_close()
self.server = None
def kill_subprocess(self): def kill_subprocess(self):
Test.kill_subprocess(self) Test.kill_subprocess(self)
self.stop_server()
def add_report(self, report): def add_report(self, report):
self.reports.append(report) self.reports.append(report)
@ -642,31 +657,6 @@ class GstValidateTest(Test):
self._sent_eos_time = time.time() self._sent_eos_time = time.time()
self.actions_infos.append(action_infos) self.actions_infos.append(action_infos)
def server_wrapper(self, ready):
self.server = socketserver.TCPServer(('localhost', 0), GstValidateListener)
self.server.socket.settimeout(None)
self.server.test = self
self.serverport = self.server.socket.getsockname()[1]
self.info("%s server port: %s" % (self, self.serverport))
ready.set()
self.server.serve_forever()
def test_start(self, queue):
ready = threading.Event()
self.server_thread = threading.Thread(target=self.server_wrapper,
kwargs={'ready': ready})
self.server_thread.start()
ready.wait()
Test.test_start(self, queue)
def test_end(self):
res = Test.test_end(self)
self.stop_server()
return res
def get_override_file(self, media_descriptor): def get_override_file(self, media_descriptor):
if media_descriptor: if media_descriptor:
if media_descriptor.get_path(): if media_descriptor.get_path():
@ -701,7 +691,7 @@ class GstValidateTest(Test):
def get_subproc_env(self): def get_subproc_env(self):
subproc_env = os.environ.copy() subproc_env = os.environ.copy()
subproc_env["GST_VALIDATE_SERVER"] = "tcp://localhost:%s" % self.serverport subproc_env["GST_VALIDATE_UUID"] = self.get_uuid()
if 'GST_DEBUG' in os.environ and not self.options.redirect_logs: if 'GST_DEBUG' in os.environ and not self.options.redirect_logs:
gstlogsfile = self.logfile + '.gstdebug' gstlogsfile = self.logfile + '.gstdebug'
@ -1294,6 +1284,7 @@ class _TestsLauncher(Loggable):
self.queue = queue.Queue() self.queue = queue.Queue()
self.jobs = [] self.jobs = []
self.total_num_tests = 0 self.total_num_tests = 0
self.server = None
def _list_app_dirs(self): def _list_app_dirs(self):
app_dirs = [] app_dirs = []
@ -1551,6 +1542,32 @@ class _TestsLauncher(Loggable):
cur_test_num = self.tests.index(test) + 1 cur_test_num = self.tests.index(test) + 1
sys.stdout.write("[%d / %d] " % (cur_test_num, self.total_num_tests)) sys.stdout.write("[%d / %d] " % (cur_test_num, self.total_num_tests))
def server_wrapper(self, ready):
self.server = GstValidateTCPServer(('localhost', 0), GstValidateListener)
self.server.socket.settimeout(None)
self.server.launcher = self
self.serverport = self.server.socket.getsockname()[1]
self.info("%s server port: %s" % (self, self.serverport))
ready.set()
self.server.serve_forever(poll_interval=0.05)
def _start_server(self):
self.info("Starting TCP Server")
ready = threading.Event()
self.server_thread = threading.Thread(target=self.server_wrapper,
kwargs={'ready': ready})
self.server_thread.start()
ready.wait()
os.environ["GST_VALIDATE_SERVER"] = "tcp://localhost:%s" % self.serverport
def _stop_server(self):
if self.server:
self.server.shutdown()
self.server_thread.join()
self.server.server_close()
self.server = None
def test_wait(self): def test_wait(self):
while True: while True:
# Check process every second for timeout # Check process every second for timeout
@ -1640,8 +1657,10 @@ class _TestsLauncher(Loggable):
def clean_tests(self): def clean_tests(self):
for test in self.tests: for test in self.tests:
test.clean() test.clean()
self._stop_server()
def run_tests(self): def run_tests(self):
self._start_server()
if self.options.forever: if self.options.forever:
r = 1 r = 1
while True: while True: