2018-09-10 08:08:15 +00:00
#!/usr/bin/env python3
#
# Example 1-1 call signalling server
#
# Copyright (C) 2017 Centricular Ltd.
#
# Author: Nirbheek Chauhan <nirbheek@centricular.com>
#
import os
import sys
import ssl
import logging
import asyncio
import websockets
import argparse
import http
2020-05-01 08:58:30 +00:00
import concurrent
2018-09-10 08:08:15 +00:00
2020-05-25 18:28:29 +00:00
2018-09-10 08:08:15 +00:00
class WebRTCSimpleServer(object):
    '''
    Example 1-1 (and multi-party room) WebRTC signalling server.

    Text protocol, one command per websocket message:

    * ``HELLO <uid>`` must be the first message; the server replies ``HELLO``.
    * ``SESSION <uid>`` opens a 1-1 session with an idle peer. The server
      replies ``SESSION_OK`` and from then on relays every message from
      either side verbatim to the other one.
    * ``ROOM <room_id>`` joins (creating it if needed) a room. The server
      replies ``ROOM_OK <peers...>`` and tells the other members
      ``ROOM_PEER_JOINED <uid>``. Room members may then send
      ``ROOM_PEER_MSG <uid> <msg>`` and ``ROOM_PEER_LIST``.
    '''

    def __init__(self, loop, options):
        ############### Global data ###############

        # Format: {uid: [Peer WebSocketServerProtocol,
        #                remote_address,
        #                <'session'|room_id|None>]}
        self.peers = dict()
        # Format: {caller_uid: callee_uid,
        #          callee_uid: caller_uid}
        # Bidirectional mapping between the two peers
        self.sessions = dict()
        # Format: {room_id: {peer1_id, peer2_id, peer3_id, ...}}
        # Room dict with a set of peers in each room; empty rooms are removed
        self.rooms = dict()

        # Event loop
        self.loop = loop
        # Websocket Server Instance
        self.server = None
        # Task polling the certificate for changes (created by run())
        self.cert_watcher = None

        # Options
        self.addr = options.addr
        self.port = options.port
        self.keepalive_timeout = options.keepalive_timeout
        self.cert_restart = options.cert_restart
        self.cert_path = options.cert_path
        self.disable_ssl = options.disable_ssl
        self.health_path = options.health

        # Certificate mtime, used to detect when to restart the server
        self.cert_mtime = -1

    ############### Helper functions ###############

    async def health_check(self, path, request_headers):
        '''
        ``process_request`` hook for the websocket server.

        Answers a request for the health route with ``200 OK``; returning
        None lets every other request continue with the websocket handshake.
        '''
        if path == self.health_path:
            return http.HTTPStatus.OK, [], b"OK\n"
        return None

    async def recv_msg_ping(self, ws, raddr):
        '''
        Wait for a message forever, and send a regular ping to prevent bad routers
        from closing the connection.

        Each time ``self.keepalive_timeout`` seconds pass without a message,
        a ping is sent on @ws and the wait starts again.
        '''
        msg = None
        while msg is None:
            try:
                msg = await asyncio.wait_for(ws.recv(), self.keepalive_timeout)
            except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
                # Depending on the Python version these are distinct
                # classes, so catch both.
                print('Sending keepalive ping to {!r} in recv'.format(raddr))
                await ws.ping()
        return msg

    async def cleanup_session(self, uid):
        '''
        Tear down the 1-1 session of @uid, if any.

        Both directions of the session mapping are removed, and the other
        peer's connection is closed (and unregistered) to reset its state.
        '''
        if uid in self.sessions:
            other_id = self.sessions[uid]
            del self.sessions[uid]
            print("Cleaned up {} session".format(uid))
            if other_id in self.sessions:
                del self.sessions[other_id]
                print("Also cleaned up {} session".format(other_id))
                # If there was a session with this peer, also
                # close the connection to reset its state.
                if other_id in self.peers:
                    print("Closing connection to {}".format(other_id))
                    wso, oaddr, _ = self.peers[other_id]
                    del self.peers[other_id]
                    await wso.close()

    async def _broadcast_to_room(self, room_id, sender_uid, msg):
        '''
        Send @msg to every member of @room_id except @sender_uid.

        Members whose connection is already closed are skipped: their own
        handler removes them, and their failure must neither abort delivery
        to the others nor tear down the sender's connection.
        '''
        # Iterate over a snapshot: the set may change while we await send()
        for pid in list(self.rooms.get(room_id, ())):
            if pid == sender_uid or pid not in self.peers:
                continue
            wsp = self.peers[pid][0]
            print('room {}: {} -> {}: {}'.format(room_id, sender_uid, pid, msg))
            try:
                await wsp.send(msg)
            except websockets.ConnectionClosed:
                print('Peer {!r} already disconnected, skipping'.format(pid))

    async def cleanup_room(self, uid, room_id):
        '''
        Remove @uid from room @room_id and notify the remaining members with
        ``ROOM_PEER_LEFT <uid>``. The room is deleted once it becomes empty.
        '''
        room_peers = self.rooms.get(room_id)
        if room_peers is None or uid not in room_peers:
            return
        room_peers.remove(uid)
        if not room_peers:
            # Don't keep empty rooms around forever
            del self.rooms[room_id]
            return
        await self._broadcast_to_room(room_id, uid, 'ROOM_PEER_LEFT {}'.format(uid))

    async def remove_peer(self, uid):
        '''
        Unregister @uid: end its session or leave its room, then close its
        connection. Safe to call for a peer that is already gone.
        '''
        await self.cleanup_session(uid)
        if uid in self.peers:
            ws, raddr, status = self.peers[uid]
            if status and status != 'session':
                await self.cleanup_room(uid, status)
            del self.peers[uid]
            await ws.close()
            print("Disconnected from peer {!r} at {!r}".format(uid, raddr))

    ############### Handler functions ###############

    async def connection_handler(self, ws, uid):
        '''
        Register @uid (already validated by hello_peer()) and serve its
        commands until the connection fails.

        This never returns normally; it exits by raising (typically
        websockets.ConnectionClosed) and the caller must call remove_peer().
        '''
        raddr = ws.remote_address
        peer_status = None
        self.peers[uid] = [ws, raddr, peer_status]
        print("Registered peer {!r} at {!r}".format(uid, raddr))
        while True:
            # Receive command, wait forever if necessary
            msg = await self.recv_msg_ping(ws, raddr)
            # Update current status: another peer may have opened a session
            # with us while we were waiting
            peer_status = self.peers[uid][2]
            # We are in a session or a room, messages must be relayed
            if peer_status is not None:
                if peer_status == 'session':
                    # We're in a session, route message to connected peer
                    other_id = self.sessions[uid]
                    wso, oaddr, status = self.peers[other_id]
                    assert status == 'session'
                    print("{} -> {}: {}".format(uid, other_id, msg))
                    await wso.send(msg)
                elif peer_status:
                    # We're in a room, accept room-specific commands
                    await self._handle_room_command(ws, uid, peer_status, msg)
                else:
                    raise AssertionError('Unknown peer status {!r}'.format(peer_status))
            # Requested a session with a specific peer
            elif msg.startswith('SESSION'):
                await self._handle_session_request(ws, uid, raddr, msg)
            # Requested joining or creation of a room
            elif msg.startswith('ROOM'):
                await self._handle_room_request(ws, uid, msg)
            else:
                print('Ignoring unknown message {!r} from {!r}'.format(msg, uid))

    async def _handle_room_command(self, ws, uid, room_id, msg):
        '''Handle a command from @uid, who is a member of room @room_id.'''
        if msg.startswith('ROOM_PEER_MSG'):
            # ROOM_PEER_MSG peer_id MSG
            parts = msg.split(maxsplit=2)
            if len(parts) != 3:
                await ws.send('ERROR invalid msg {!r}'.format(msg))
                return
            _, other_id, payload = parts
            if other_id not in self.peers:
                await ws.send('ERROR peer {!r} not found'.format(other_id))
                return
            wso, oaddr, status = self.peers[other_id]
            if status != room_id:
                await ws.send('ERROR peer {!r} is not in the room'.format(other_id))
                return
            msg = 'ROOM_PEER_MSG {} {}'.format(uid, payload)
            print('room {}: {} -> {}: {}'.format(room_id, uid, other_id, msg))
            try:
                await wso.send(msg)
            except websockets.ConnectionClosed:
                # The recipient is gone; don't drop the sender's connection
                await ws.send('ERROR peer {!r} not found'.format(other_id))
        elif msg == 'ROOM_PEER_LIST':
            room_peers = ' '.join([pid for pid in self.rooms[room_id] if pid != uid])
            msg = 'ROOM_PEER_LIST {}'.format(room_peers)
            print('room {}: -> {}: {}'.format(room_id, uid, msg))
            await ws.send(msg)
        else:
            await ws.send('ERROR invalid msg, already in room')

    async def _handle_session_request(self, ws, uid, raddr, msg):
        '''Handle ``SESSION <callee_id>`` from @uid, who is currently idle.'''
        print("{!r} command {!r}".format(uid, msg))
        parts = msg.split(maxsplit=1)
        if len(parts) != 2:
            await ws.send('ERROR invalid msg {!r}'.format(msg))
            return
        callee_id = parts[1]
        if callee_id not in self.peers:
            await ws.send('ERROR peer {!r} not found'.format(callee_id))
            return
        # The callee must be idle too, otherwise we'd hijack its ongoing
        # session or room membership
        if self.peers[callee_id][2] is not None:
            await ws.send('ERROR peer {!r} busy'.format(callee_id))
            return
        # Register session before awaiting anything, so that no other peer
        # can grab the callee in between
        self.peers[uid][2] = 'session'
        self.sessions[uid] = callee_id
        self.peers[callee_id][2] = 'session'
        self.sessions[callee_id] = uid
        await ws.send('SESSION_OK')
        wsc = self.peers[callee_id][0]
        print('Session from {!r} ({!r}) to {!r} ({!r})'
              ''.format(uid, raddr, callee_id, wsc.remote_address))

    async def _handle_room_request(self, ws, uid, msg):
        '''Handle ``ROOM <room_id>`` from @uid, who is currently idle.'''
        print('{!r} command {!r}'.format(uid, msg))
        parts = msg.split(maxsplit=1)
        if len(parts) != 2:
            await ws.send('ERROR invalid msg {!r}'.format(msg))
            return
        room_id = parts[1]
        # Room name cannot be 'session', empty, or contain whitespace
        if room_id == 'session' or room_id.split() != [room_id]:
            await ws.send('ERROR invalid room id {!r}'.format(room_id))
            return
        # Create room if required
        room = self.rooms.setdefault(room_id, set())
        if uid in room:
            raise AssertionError('How did we accept a ROOM command '
                                 'despite already being in a room?')
        room_peers = ' '.join(room)
        await ws.send('ROOM_OK {}'.format(room_peers))
        # Enter room. Look the room up again: it may have been emptied and
        # deleted by cleanup_room() while we were awaiting.
        self.peers[uid][2] = room_id
        self.rooms.setdefault(room_id, set()).add(uid)
        await self._broadcast_to_room(room_id, uid, 'ROOM_PEER_JOINED {}'.format(uid))

    async def hello_peer(self, ws):
        '''
        Exchange hello, register peer

        Expects ``HELLO <uid>`` as the first message and returns uid. On a
        bad greeting or an unusable uid (missing, already taken, or
        containing whitespace) the connection is closed with code 1002 and
        an Exception is raised.
        '''
        raddr = ws.remote_address
        hello = await ws.recv()
        # Don't unpack blindly: a bare 'HELLO' or an empty message must be
        # rejected with a proper close, not crash the handler
        parts = hello.split(maxsplit=1)
        if not parts or parts[0] != 'HELLO':
            await ws.close(code=1002, reason='invalid protocol')
            raise Exception("Invalid hello from {!r}".format(raddr))
        uid = parts[1] if len(parts) == 2 else ''
        if not uid or uid in self.peers or uid.split() != [uid]:  # no whitespace
            await ws.close(code=1002, reason='invalid peer uid')
            raise Exception("Invalid uid {!r} from {!r}".format(uid, raddr))
        # Send back a HELLO
        await ws.send('HELLO')
        return uid

    def get_ssl_certs(self):
        '''
        Return (certificate chain path, private key path) inside cert_path,
        using Let's Encrypt file names when the path mentions letsencrypt.
        '''
        if 'letsencrypt' in self.cert_path:
            chain_pem = os.path.join(self.cert_path, 'fullchain.pem')
            key_pem = os.path.join(self.cert_path, 'privkey.pem')
        else:
            chain_pem = os.path.join(self.cert_path, 'cert.pem')
            key_pem = os.path.join(self.cert_path, 'key.pem')
        return chain_pem, key_pem

    def get_ssl_ctx(self):
        '''
        Build the server-side TLS context, or return None when TLS is
        disabled. Exits the process if the certificate files are missing.
        '''
        if self.disable_ssl:
            return None
        # Create an SSL context to be used by the websocket server
        print('Using TLS with keys in {!r}'.format(self.cert_path))
        chain_pem, key_pem = self.get_ssl_certs()
        # CLIENT_AUTH is the purpose for *server* sockets. The default
        # (SERVER_AUTH) creates a client context, which Python 3.10+ refuses
        # to use for server-side handshakes.
        sslctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        try:
            sslctx.load_cert_chain(chain_pem, keyfile=key_pem)
        except FileNotFoundError:
            print("Certificates not found, did you run generate_cert.sh?")
            sys.exit(1)
        # Clients are not asked for a certificate
        sslctx.check_hostname = False
        sslctx.verify_mode = ssl.CERT_NONE
        return sslctx

    def run(self):
        '''
        Start the websocket server on self.loop and return once it listens.

        The caller must then run the loop (main() uses run_forever()). With
        cert_restart enabled, a background task stops the loop once the
        certificate changes, so the caller can call run() again to pick up
        the new certificate.
        '''
        async def handler(ws, path):
            '''
            All incoming messages are handled here. @path is unused.
            '''
            raddr = ws.remote_address
            print("Connected to {!r}".format(raddr))
            peer_id = await self.hello_peer(ws)
            try:
                await self.connection_handler(ws, peer_id)
            except websockets.ConnectionClosed:
                print("Connection to peer {!r} closed, exiting handler".format(raddr))
            finally:
                await self.remove_peer(peer_id)

        sslctx = self.get_ssl_ctx()

        scheme = 'http' if self.disable_ssl else 'https'
        print("Listening on {}://{}:{}".format(scheme, self.addr, self.port))
        # Websocket server
        wsd = websockets.serve(handler, self.addr, self.port, ssl=sslctx,
                               process_request=self.health_check if self.health_path else None,
                               # Maximum number of messages that websockets will pop
                               # off the asyncio and OS buffers per connection. See:
                               # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
                               max_queue=16)

        # Setup logging
        logger = logging.getLogger('websockets')
        logger.setLevel(logging.INFO)
        # run() is called again after every restart: attach a single
        # handler, or every log line would be printed once per restart
        if not logger.handlers:
            logger.addHandler(logging.StreamHandler())

        # Run the server
        self.server = self.loop.run_until_complete(wsd)
        # Stop the server if certificate changes. This must run as a task
        # while the caller runs the loop: awaiting it here would block until
        # the certificate changes and then leave the caller stuck in
        # run_forever() with no server at all.
        self.cert_watcher = self.loop.create_task(self.check_server_needs_restart())

    async def stop(self):
        '''
        Close the websocket server, wait until it has shut down, then stop
        self.loop so that whoever is running it regains control.
        '''
        print('Stopping server... ', end='')
        self.server.close()
        await self.server.wait_closed()
        self.loop.stop()
        print('Stopped.')

    def check_cert_changed(self):
        '''
        Return True if the certificate chain or key was modified since the
        previous call. The first call only records the mtime and returns
        False. Missing files (e.g. in the middle of a renewal) are reported
        and treated as unchanged so that the next poll can try again.
        '''
        chain_pem, key_pem = self.get_ssl_certs()
        try:
            mtime = max(os.stat(key_pem).st_mtime, os.stat(chain_pem).st_mtime)
        except FileNotFoundError as e:
            print('Cannot check certificate: {}'.format(e))
            return False
        if self.cert_mtime < 0:
            self.cert_mtime = mtime
            return False
        if mtime > self.cert_mtime:
            self.cert_mtime = mtime
            return True
        return False

    async def check_server_needs_restart(self):
        "When the certificate changes, we need to restart the server"
        # Without TLS there is no certificate to watch (it may not even exist)
        if not self.cert_restart or self.disable_ssl:
            return
        while True:
            await asyncio.sleep(10)
            if self.check_cert_changed():
                print('Certificate changed, stopping server...')
                await self.stop()
                return
2018-09-10 08:08:15 +00:00
def main():
    '''
    Parse the command line and run the signalling server until interrupted.

    Each time the event loop is stopped (WebRTCSimpleServer.stop() does so
    after a certificate change), the server is started again.
    '''
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # See: host, port in https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_server
    parser.add_argument('--addr', default='', help='Address to listen on (default: all interfaces, both ipv4 and ipv6)')
    parser.add_argument('--port', default=8443, type=int, help='Port to listen on')
    parser.add_argument('--keepalive-timeout', dest='keepalive_timeout', default=30, type=int, help='Timeout for keepalive (in seconds)')
    parser.add_argument('--cert-path', default=os.path.dirname(__file__))
    parser.add_argument('--disable-ssl', default=False, help='Disable ssl', action='store_true')
    parser.add_argument('--health', default='/health', help='Health check route')
    parser.add_argument('--restart-on-cert-change', default=False, dest='cert_restart', action='store_true', help='Automatically restart if the SSL certificate changes')

    options = parser.parse_args(sys.argv[1:])

    # asyncio.get_event_loop() is deprecated when there is no current loop,
    # so create one explicitly and make it current for code (such as the
    # websockets library) that looks it up implicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    r = WebRTCSimpleServer(loop, options)

    print('Starting server...')
    try:
        while True:
            r.run()
            # Returns only when something calls loop.stop()
            loop.run_forever()
            print('Restarting server...')
    except KeyboardInterrupt:
        # Ctrl-C: exit cleanly instead of dumping a traceback
        pass
    print("Goodbye!")
2018-09-10 08:08:15 +00:00
# Start the server only when executed as a script; importing this module
# must not open any sockets.
if __name__ == "__main__":
    main()