gstreamer/webrtc/signalling/simple_server.py

291 lines
12 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
#
# Example 1-1 call signalling server
#
# Copyright (C) 2017 Centricular Ltd.
#
# Author: Nirbheek Chauhan <nirbheek@centricular.com>
#
import os
import sys
import ssl
import logging
import asyncio
import websockets
import argparse
import http
import concurrent
class WebRTCSimpleServer(object):
    """Example signalling server for 1-1 calls and multi-peer rooms.

    Peers connect over a WebSocket, register with ``HELLO <uid>`` and then
    either request a 1-1 session (``SESSION <uid>``) or join a room
    (``ROOM <room_id>``). Once in a session every message is relayed
    verbatim to the other peer; once in a room only ``ROOM_PEER_MSG`` and
    ``ROOM_PEER_LIST`` are accepted.
    """

    def __init__(self, addr, port, keepalive_timeout, disable_ssl, certpath, health_path=None):
        """Store the configuration and create the empty peer/session/room tables.

        addr, port: where the WebSocket server listens.
        keepalive_timeout: seconds of silence before a ping is sent to a peer.
        disable_ssl: when true, serve plain WebSockets without TLS.
        certpath: directory containing the certificate chain and key.
        health_path: HTTP path answered with "OK"; a falsy value disables
            the health check.
        """
        ############### Global data ###############

        # Format: {uid: [Peer WebSocketServerProtocol,
        #                remote_address,
        #                <'session'|room_id|None>]}
        # The record is a list because the status slot is updated in place.
        self.peers = dict()
        # Format: {caller_uid: callee_uid,
        #          callee_uid: caller_uid}
        # Bidirectional mapping between the two peers
        self.sessions = dict()
        # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
        # Room dict with a set of peers in each room
        self.rooms = dict()

        self.keepalive_timeout = keepalive_timeout
        self.addr = addr
        self.port = port
        self.disable_ssl = disable_ssl
        self.certpath = certpath
        self.health_path = health_path

    ############### Helper functions ###############

    async def health_check(self, path, request_headers):
        """Answer the health-check path with ``200 OK``.

        Used as the ``process_request`` hook of ``websockets.serve``.
        Returning None lets the normal WebSocket handshake continue.
        """
        if path == self.health_path:
            return http.HTTPStatus.OK, [], b"OK\n"
        return None

    async def recv_msg_ping(self, ws, raddr):
        '''
        Wait for a message forever, and send a regular ping to prevent bad routers
        from closing the connection.
        '''
        msg = None
        while msg is None:
            try:
                msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
            except asyncio.TimeoutError:
                # asyncio.wait_for() always reports a timeout with
                # asyncio.TimeoutError, so no private aliases are needed.
                print('Sending keepalive ping to {!r} in recv'.format(raddr))
                await ws.ping()
        return msg

    async def cleanup_session(self, uid):
        """Tear down the 1-1 session of *uid*, if any, and disconnect its partner."""
        if uid not in self.sessions:
            return
        other_id = self.sessions.pop(uid)
        print("Cleaned up {} session".format(uid))
        if other_id not in self.sessions:
            return
        del self.sessions[other_id]
        print("Also cleaned up {} session".format(other_id))
        # If there was a session with this peer, also
        # close the connection to reset its state.
        if other_id in self.peers:
            print("Closing connection to {}".format(other_id))
            wso = self.peers.pop(other_id)[0]
            await wso.close()

    async def cleanup_room(self, uid, room_id):
        """Remove *uid* from *room_id* and tell the remaining peers it left."""
        room_peers = self.rooms.get(room_id)
        if room_peers is None or uid not in room_peers:
            return
        room_peers.remove(uid)
        if not room_peers:
            # Last peer left: drop the room so the table does not grow forever.
            del self.rooms[room_id]
            return
        # Iterate over a snapshot: the set may change while we await below.
        for pid in list(room_peers):
            peer = self.peers.get(pid)
            if peer is None:
                continue
            msg = 'ROOM_PEER_LEFT {}'.format(uid)
            print('room {}: {} -> {}: {}'.format(room_id, uid, pid, msg))
            try:
                await peer[0].send(msg)
            except websockets.ConnectionClosed:
                # That peer is disconnecting too; keep notifying the others.
                continue

    async def remove_peer(self, uid):
        """Forget *uid*: end its session or room membership and close its socket."""
        await self.cleanup_session(uid)
        peer = self.peers.get(uid)
        if peer is None:
            return
        ws, raddr, status = peer
        try:
            if status and status != 'session':
                await self.cleanup_room(uid, status)
        finally:
            # Always unregister, even if notifying the room failed, so the
            # uid does not stay reserved by a dead connection.
            self.peers.pop(uid, None)
            await ws.close()
            print("Disconnected from peer {!r} at {!r}".format(uid, raddr))

    ############### Handler functions ###############

    async def _relay_session_msg(self, uid, msg):
        """Forward *msg* unchanged to the peer *uid* is in a session with."""
        other_id = self.sessions[uid]
        wso, oaddr, status = self.peers[other_id]
        assert(status == 'session')
        print("{} -> {}: {}".format(uid, other_id, msg))
        await wso.send(msg)

    async def _handle_room_msg(self, ws, uid, room_id, msg):
        """Handle a command from *uid*, which is currently inside *room_id*."""
        # ROOM_PEER_MSG peer_id MSG
        if msg.startswith('ROOM_PEER_MSG'):
            parts = msg.split(maxsplit=2)
            if len(parts) != 3:
                await ws.send('ERROR invalid msg {!r}'.format(msg))
                return
            _, other_id, payload = parts
            if other_id not in self.peers:
                await ws.send('ERROR peer {!r} not found'.format(other_id))
                return
            wso, oaddr, status = self.peers[other_id]
            if status != room_id:
                await ws.send('ERROR peer {!r} is not in the room'.format(other_id))
                return
            out = 'ROOM_PEER_MSG {} {}'.format(uid, payload)
            print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, out))
            await wso.send(out)
        elif msg == 'ROOM_PEER_LIST':
            room_peers = ' '.join(pid for pid in self.rooms[room_id] if pid != uid)
            out = 'ROOM_PEER_LIST {}'.format(room_peers)
            print('room {}: -> {}: {}'.format(room_id, uid, out))
            await ws.send(out)
        else:
            await ws.send('ERROR invalid msg, already in room')

    async def _start_session(self, ws, uid, raddr, msg):
        """Handle ``SESSION <callee_id>`` from the idle peer *uid*."""
        print("{!r} command {!r}".format(uid, msg))
        parts = msg.split(maxsplit=1)
        if len(parts) != 2:
            await ws.send('ERROR invalid msg {!r}'.format(msg))
            return
        callee_id = parts[1]
        if callee_id not in self.peers:
            await ws.send('ERROR peer {!r} not found'.format(callee_id))
            return
        # The callee (not the caller, who is idle by construction) must be free.
        if self.peers[callee_id][2] is not None:
            await ws.send('ERROR peer {!r} busy'.format(callee_id))
            return
        await ws.send('SESSION_OK')
        wsc = self.peers[callee_id][0]
        print('Session from {!r} ({!r}) to {!r} ({!r})'
              ''.format(uid, raddr, callee_id, wsc.remote_address))
        # Register session
        self.peers[uid][2] = 'session'
        self.sessions[uid] = callee_id
        self.peers[callee_id][2] = 'session'
        self.sessions[callee_id] = uid

    async def _join_room(self, ws, uid, msg):
        """Handle ``ROOM <room_id>`` from the idle peer *uid*."""
        print('{!r} command {!r}'.format(uid, msg))
        parts = msg.split(maxsplit=1)
        if len(parts) != 2:
            await ws.send('ERROR invalid msg {!r}'.format(msg))
            return
        room_id = parts[1]
        # Room name cannot be 'session', empty, or contain whitespace
        if room_id == 'session' or room_id.split() != [room_id]:
            await ws.send('ERROR invalid room id {!r}'.format(room_id))
            return
        if room_id in self.rooms:
            if uid in self.rooms[room_id]:
                raise AssertionError('How did we accept a ROOM command '
                                     'despite already being in a room?')
        else:
            # Create room if required
            self.rooms[room_id] = set()
        room_peers = ' '.join(self.rooms[room_id])
        await ws.send('ROOM_OK {}'.format(room_peers))
        # Enter room
        self.peers[uid][2] = room_id
        self.rooms[room_id].add(uid)
        # Iterate over a snapshot: the set may change while we await below.
        for pid in list(self.rooms[room_id]):
            if pid == uid:
                continue
            peer = self.peers.get(pid)
            if peer is None:
                continue
            out = 'ROOM_PEER_JOINED {}'.format(uid)
            print('room {}: {} -> {}: {}'.format(room_id, uid, pid, out))
            await peer[0].send(out)

    async def connection_handler(self, ws, uid):
        """Register *uid* and process its messages until the connection ends.

        The loop never returns normally; it ends when receiving from or
        sending to a socket raises (typically because a connection closed).
        """
        raddr = ws.remote_address
        self.peers[uid] = [ws, raddr, None]
        print("Registered peer {!r} at {!r}".format(uid, raddr))
        while True:
            # Receive command, wait forever if necessary
            msg = await self.recv_msg_ping(ws, raddr)
            # Update current status
            peer_status = self.peers[uid][2]
            # We are in a session or a room, messages must be relayed
            if peer_status is not None:
                if peer_status == 'session':
                    # We're in a session, route message to connected peer
                    await self._relay_session_msg(uid, msg)
                elif peer_status:
                    # We're in a room, accept room-specific commands
                    await self._handle_room_msg(ws, uid, peer_status, msg)
                else:
                    raise AssertionError('Unknown peer status {!r}'.format(peer_status))
            # Requested a session with a specific peer
            elif msg.startswith('SESSION'):
                await self._start_session(ws, uid, raddr, msg)
            # Requested joining or creation of a room
            elif msg.startswith('ROOM'):
                await self._join_room(ws, uid, msg)
            else:
                print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))

    async def hello_peer(self, ws):
        '''
        Exchange hello, register peer

        Returns the uid announced by the peer. Closes the connection with
        code 1002 and raises Exception when the hello is malformed or the
        uid is empty, contains whitespace, or is already taken.
        '''
        raddr = ws.remote_address
        hello = await ws.recv()
        parts = hello.split(maxsplit=1)
        # A bare "HELLO" (no uid) is a protocol error too, not a crash.
        if len(parts) != 2 or parts[0] != 'HELLO':
            await ws.close(code=1002, reason='invalid protocol')
            raise Exception("Invalid hello from {!r}".format(raddr))
        uid = parts[1]
        if not uid or uid in self.peers or uid.split() != [uid]:  # no whitespace
            await ws.close(code=1002, reason='invalid peer uid')
            raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
        # Send back a HELLO
        await ws.send('HELLO')
        return uid

    def _make_ssl_context(self):
        """Build the server-side TLS context from the files in ``self.certpath``.

        Exits the process with status 1 when the certificate files are missing.
        """
        print('Using TLS with keys in {!r}'.format(self.certpath))
        if 'letsencrypt' in self.certpath:
            chain_pem = os.path.join(self.certpath, 'fullchain.pem')
            key_pem = os.path.join(self.certpath, 'privkey.pem')
        else:
            chain_pem = os.path.join(self.certpath, 'cert.pem')
            key_pem = os.path.join(self.certpath, 'key.pem')

        # Purpose.CLIENT_AUTH yields a context for server-side sockets; the
        # default (SERVER_AUTH) is a client context and cannot be used to
        # accept connections on recent Python versions.
        sslctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        try:
            sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
        except FileNotFoundError:
            print("Certificates not found, did you run generate_cert.sh?")
            sys.exit(1)
        # FIXME
        sslctx.check_hostname = False
        sslctx.verify_mode = ssl.CERT_NONE
        return sslctx

    def run(self):
        """Create the WebSocket server and return the awaitable that starts it."""
        sslctx = None
        if not self.disable_ssl:
            # Create an SSL context to be used by the websocket server
            sslctx = self._make_ssl_context()

        async def handler(ws, path=None):
            '''
            All incoming messages are handled here. @path is unused and
            optional, so the handler works whether it is called with one
            argument or two.
            '''
            raddr = ws.remote_address
            print("Connected to {!r}".format(raddr))
            peer_id = await self.hello_peer(ws)
            try:
                await self.connection_handler(ws, peer_id)
            except websockets.ConnectionClosed:
                print("Connection to peer {!r} closed, exiting handler".format(raddr))
            finally:
                await self.remove_peer(peer_id)

        scheme = 'http' if self.disable_ssl else 'https'
        print("Listening on {}://{}:{}".format(scheme, self.addr, self.port))
        # Websocket server
        wsd = websockets.serve(
            handler, self.addr, self.port, ssl=sslctx,
            process_request=self.health_check if self.health_path else None,
            # Maximum number of messages that websockets will pop
            # off the asyncio and OS buffers per connection. See:
            # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
            max_queue=16)

        logger = logging.getLogger('websockets.server')
        logger.setLevel(logging.ERROR)
        logger.addHandler(logging.StreamHandler())

        return wsd


def main():
    """Parse the command line, start the signalling server and run until interrupted."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
    parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
    parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
    parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
    parser.add_argument('--cert-path', default=os.path.dirname(__file__))
    parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
    parser.add_argument('--health', default='/health', help='Health check route')

    options = parser.parse_args(sys.argv[1:])

    # Create the loop explicitly: asyncio.get_event_loop() is deprecated
    # when no event loop is running.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    r = WebRTCSimpleServer(options.addr, options.port, options.keepalive_timeout,
                           options.disable_ssl, options.cert_path,
                           health_path=options.health)

    loop.run_until_complete(r.run())
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server.
        pass
    print("Goodbye!")
# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()