# Copyright (c) 2022 Fantix King  http://fantix.pro
# kLoop is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

"""asyncio event loop, transports and policy built on the ``uring`` ring.

Plain sockets are served by :class:`KLoopSocketTransport`.  TLS client
connections perform the handshake in user space through
:class:`KLoopSSLHandshakeProtocol` and are then switched over to the
``ktls`` helpers by :class:`KLoopSSLTransport`.
"""

import asyncio.events
import asyncio.futures
import asyncio.trsock
import asyncio.transports
import contextvars
import errno
import logging
import os
import socket
import ssl

from . import uring, ktls

_logger = logging.getLogger(__name__)

# Size of the reusable receive buffer used for non-buffered protocols.
# NOTE(review): 64 MiB is allocated on every ``set_protocol()`` call, while
# the TLS reads below use 64 KiB -- confirm this size is intentional.
_RECV_BUFFER_SIZE = 64 * 1024 * 1024

# Maximum number of plaintext bytes requested per ``SSLObject.read()`` call.
_SSL_READ_SIZE = 64 * 1024

# Queue depths accepted by KLoopPolicy: the powers of two from 1 to 4096.
_VALID_QUEUE_DEPTHS = frozenset(1 << shift for shift in range(13))


class Callback:
    """A callable bundled with its arguments and a ``contextvars`` context."""

    __slots__ = ("_callback", "_context", "_args", "_kwargs")

    def __init__(self, callback, context=None, args=None, kwargs=None):
        """Store *callback* with its call arguments.

        When *context* is ``None`` a copy of the current context is taken.
        """
        if context is None:
            context = contextvars.copy_context()
        self._callback = callback
        self._context = context
        self._args = args or ()
        self._kwargs = kwargs or {}

    def __call__(self):
        """Run the callback inside the stored context; the result is dropped."""
        self._context.run(self._callback, *self._args, **self._kwargs)

    def __repr__(self):
        return f"{self._callback} {self._args} {self._kwargs} {self._context}"


class KLoopSocketTransport(
    asyncio.transports._FlowControlMixin, asyncio.Transport
):
    """Socket transport that submits its reads and writes to the loop's ring.

    At most one send work item is in flight at a time (``_current_work``);
    data written meanwhile is queued in ``_buffers`` and sent when the
    in-flight work completes.
    """

    __slots__ = (
        "_waiter",
        "_sock",
        "_protocol",
        "_closing",
        "_recv_buffer",
        "_recv_buffer_factory",
        "_read_ready_cb",
        "_buffers",
        "_buffer_size",
        "_current_work",
        "_write_waiter",
        "_read_paused",
    )

    def __init__(
        self, loop, sock, protocol, waiter=None, extra=None, server=None
    ):
        """Attach *protocol* to *sock* and schedule the first read.

        ``connection_made`` and the first read are scheduled with
        ``call_soon``; *waiter*, when given, is resolved with ``None``
        afterwards.  *server* is accepted for signature compatibility and is
        not used here.
        """
        super().__init__(extra, loop)
        self._extra["socket"] = asyncio.trsock.TransportSocket(sock)
        try:
            self._extra["sockname"] = sock.getsockname()
        except OSError:
            self._extra["sockname"] = None
        if "peername" not in self._extra:
            try:
                self._extra["peername"] = sock.getpeername()
            except OSError:  # socket.error is an alias of OSError
                self._extra["peername"] = None

        self._buffers = []
        self._buffer_size = 0
        self._current_work = None
        self._sock = sock
        self._closing = False
        self._write_waiter = None
        self._read_paused = False
        self.set_protocol(protocol)
        self._waiter = waiter

        self._loop.call_soon(self._protocol.connection_made, self)
        self._loop.call_soon(self._read)
        if self._waiter is not None:
            self._loop.call_soon(
                asyncio.futures._set_result_unless_cancelled,
                self._waiter,
                None,
            )

    def set_protocol(self, protocol):
        """Install *protocol* and select the matching read-completion path.

        A ``BufferedProtocol`` supplies its own buffers via ``get_buffer``;
        any other protocol reads into a transport-owned bytearray.
        """
        if isinstance(protocol, asyncio.BufferedProtocol):
            self._read_ready_cb = self._read_ready__buffer_updated
            self._recv_buffer = None
            self._recv_buffer_factory = protocol.get_buffer
        else:
            self._read_ready_cb = self._read_ready__data_received
            self._recv_buffer = bytearray(_RECV_BUFFER_SIZE)
            self._recv_buffer_factory = lambda _hint: self._recv_buffer
        self._protocol = protocol

    def _read(self):
        """Submit one receive work item unless reading is paused."""
        if self._read_paused:
            return
        self._loop._selector.submit(
            uring.RecvMsgWork(
                self._sock.fileno(),
                [self._recv_buffer_factory(-1)],
                self._read_ready_cb,
            )
        )

    def _read_failed(self, res):
        """Handle a negative receive result *res* (a negated errno).

        ``EAGAIN`` resubmits the read; any other error is raised as an
        ``OSError`` carrying the errno and its message.
        """
        err = -res
        if err == errno.EAGAIN:
            _logger.debug("read returned EAGAIN, retrying")
            self._read()
            return
        raise OSError(err, os.strerror(err))

    def _eof_received(self):
        """Notify the protocol of EOF and close unless it asks to stay open.

        Per the asyncio protocol contract, a false return value from
        ``eof_received()`` means the transport closes itself.
        """
        keep_open = self._protocol.eof_received()
        if not keep_open:
            self.close()

    def _read_ready__buffer_updated(self, res, app_data):
        """Receive completion for ``BufferedProtocol`` protocols.

        *res* is the byte count (0 for EOF, negative for an error);
        *app_data* is not used on this path.
        """
        if res < 0:
            self._read_failed(res)
        elif res == 0:
            self._eof_received()
        else:
            try:
                self._protocol.buffer_updated(res)
            finally:
                # Keep reading even if the protocol raised.
                if not self._closing:
                    self._read()

    def _read_ready__data_received(self, res, app_data):
        """Receive completion for regular (non-buffered) protocols.

        *res* is the byte count (0 for EOF, negative for an error).  The
        payload is delivered only when *app_data* is true.
        NOTE(review): presumably *app_data* marks TLS application-data
        records as opposed to other record types -- confirm against uring.
        """
        _logger.debug("_read_ready__data_received %s", res)
        if res < 0:
            self._read_failed(res)
        elif res == 0:
            self._eof_received()
        else:
            try:
                if app_data:
                    # Single copy out of the reusable buffer.
                    data = bytes(memoryview(self._recv_buffer)[:res])
                    self._protocol.data_received(data)
            finally:
                # Keep reading even if the protocol raised.
                if not self._closing:
                    self._read()

    def _write_done(self, res):
        """Send completion: flush queued buffers or finish pending actions.

        *res* is the number of bytes sent, or a negated errno on failure.
        """
        self._current_work = None
        if res < 0:
            # TODO: force close transport
            raise OSError(-res, os.strerror(-res))
        self._buffer_size -= res
        if self._buffers:
            # Send everything queued while the previous work was in flight.
            if len(self._buffers) == 1:
                self._current_work = uring.SendWork(
                    self._sock.fileno(), self._buffers[0], self._write_done
                )
            else:
                self._current_work = uring.SendMsgWork(
                    self._sock.fileno(), self._buffers, self._write_done
                )
            self._loop._selector.submit(self._current_work)
            self._buffers = []
        elif self._closing:
            # close() was deferred until the write buffer drained.
            self._loop.call_soon(self._call_connection_lost, None)
        elif self._write_waiter is not None:
            self._write_waiter()
            self._write_waiter = None
        self._maybe_resume_protocol()

    def write(self, data):
        """Send *data* now if no send is in flight, otherwise queue it."""
        self._buffer_size += len(data)
        if self._current_work is None:
            self._current_work = uring.SendWork(
                self._sock.fileno(), data, self._write_done
            )
            self._loop._selector.submit(self._current_work)
        else:
            self._buffers.append(data)
        self._maybe_pause_protocol()

    def close(self):
        """Close the transport; idempotent.

        If a send is in flight, ``connection_lost`` is scheduled from
        ``_write_done`` once the queue has drained.
        """
        if self._closing:
            return
        self._closing = True
        if self._current_work is None:
            self._loop.call_soon(self._call_connection_lost, None)

    def _call_connection_lost(self, exc):
        """Report *exc* to the protocol, then close the socket and drop refs."""
        try:
            if self._protocol is not None:
                self._protocol.connection_lost(exc)
        finally:
            self._sock.close()
            self._sock = None
            self._protocol = None
            self._loop = None

    def get_write_buffer_size(self):
        """Return the number of bytes written but not yet reported as sent."""
        return self._buffer_size

    def pause_reading(self):
        """Stop submitting new reads; an already submitted read still completes."""
        self._read_paused = True

    def resume_reading(self):
        """Resume reading and submit a read if reading was paused."""
        if self._read_paused:
            self._read_paused = False
            self._read()


class KLoopSSLHandshakeProtocol(asyncio.Protocol):
    """Client-side TLS handshake driven through in-memory BIOs.

    Once the handshake completes, the owning transport is asked to upgrade
    its read and write directions using the captured traffic secrets.
    """

    __slots__ = (
        "_incoming",
        "_outgoing",
        "_handshaking",
        "_secrets",
        "_app_protocol",
        "_transport",
        "_sslobj",
    )

    def __init__(self, sslcontext, server_hostname):
        """Create the client-side SSL object wrapping two memory BIOs."""
        self._handshaking = True
        self._secrets = {}
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._sslobj = sslcontext.wrap_bio(
            self._incoming,
            self._outgoing,
            server_side=False,
            server_hostname=server_hostname,
        )

    def connection_made(self, transport):
        """Remember *transport* and start the handshake."""
        self._transport = transport
        self._handshake()

    def data_received(self, data):
        """Feed received ciphertext to the SSL object and continue."""
        self._incoming.write(data)
        self._handshake()

    def _read_early_data(self):
        """Drain plaintext already decrypted by the SSL object.

        Returns ``None`` when the first read wants more input, otherwise the
        bytes read (possibly empty).
        """
        try:
            data = self._sslobj.read(_SSL_READ_SIZE)
        except ssl.SSLWantReadError:
            return None
        if data:
            while True:
                try:
                    chunk = self._sslobj.read(_SSL_READ_SIZE)
                except ssl.SSLWantReadError:
                    break
                if not chunk:
                    # An empty read would otherwise spin this loop forever.
                    break
                data += chunk
        return data

    def _handshake(self):
        """Advance the handshake; on completion start the kTLS upgrade."""
        ktls.get_state(self._sslobj)
        success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
        self._secrets.update(secrets)
        if not success:
            # Handshake still in progress: flush whatever it produced.
            if data := self._outgoing.read():
                self._transport.write(data)
            return

        if not self._handshaking:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError("handshake completed more than once")

        _logger.debug("negotiated cipher: %s", self._sslobj.cipher())
        ktls.get_state(self._sslobj)
        self._handshaking = False

        data = self._read_early_data()
        # Log only the size: the payload is decrypted application data.
        _logger.debug(
            "early data after handshake: %s bytes",
            0 if data is None else len(data),
        )
        self._transport._upgrade_ktls_read(
            self._sslobj,
            self._secrets["SERVER_TRAFFIC_SECRET_0"],
            data,
        )
        if data := self._outgoing.read():
            # The last handshake message must be sent before the write
            # direction is upgraded; stop reading until that has happened.
            _logger.debug("sending last handshake message")
            self._transport.write(data)
            self._transport._write_waiter = self._after_last_write
            self._transport.pause_reading()
        else:
            self._after_last_write()

    def _after_last_write(self):
        """Upgrade the write direction and resume reading."""
        _logger.debug("_after_last_write")
        ktls.get_state(self._sslobj)
        self._transport._upgrade_ktls_write(
            self._sslobj,
            self._secrets["CLIENT_TRAFFIC_SECRET_0"],
        )
        self._transport.resume_reading()


class KLoopSSLTransport(KLoopSocketTransport):
    """Socket transport that handshakes TLS and then upgrades via ``ktls``.

    The handshake protocol is installed first; the application protocol
    takes over in :meth:`_upgrade_ktls_write`.
    """

    __slots__ = ("_app_protocol",)

    def __init__(
        self,
        loop,
        sock,
        protocol,
        waiter=None,
        extra=None,
        server=None,
        *,
        sslcontext,
        server_hostname,
    ):
        """Start a TLS handshake on *sock* on behalf of *protocol*.

        *waiter* is withheld from the base class so that it is resolved only
        after the upgrade, not when the raw connection is made.
        """
        ktls.enable_ulp(sock)
        self._app_protocol = protocol
        super().__init__(
            loop,
            sock,
            KLoopSSLHandshakeProtocol(sslcontext, server_hostname),
            None,
            extra,
            server,
        )
        self._waiter = waiter

    def _upgrade_ktls_write(self, sslobj, secret):
        """Upgrade the send direction and hand over to the app protocol."""
        _logger.debug("_upgrade_ktls_write")
        ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
        self.set_protocol(self._app_protocol)
        self._loop.call_soon(self._app_protocol.connection_made, self)
        if self._waiter is not None:
            self._loop.call_soon(
                asyncio.futures._set_result_unless_cancelled,
                self._waiter,
                None,
            )

    def _upgrade_ktls_read(self, sslobj, secret, data):
        """Upgrade the receive direction.

        NOTE(review): *data* (plaintext drained from the SSL object during
        the handshake) is not forwarded to the application protocol, so it
        is dropped here -- confirm whether that is intended.
        """
        _logger.debug("_upgrade_ktls_read")
        ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)


class KLoop(asyncio.BaseEventLoop):
    """Event loop whose selector is a ``uring.Ring``."""

    def __init__(self, args):
        """Create the loop; *args* is unpacked into ``uring.Ring``."""
        super().__init__()
        self._selector = uring.Ring(*args)

    def _process_events(self, works):
        """Complete every work item in *works*."""
        for work in works:
            work.complete()

    async def sock_connect(self, sock, address):
        """Connect *sock* to *address* through the ring and await the result."""
        fut = self.create_future()
        self._selector.submit(uring.ConnectWork(sock.fileno(), address, fut))
        return await fut

    async def getaddrinfo(
        self, host, port, *, family=0, type=0, proto=0, flags=0
    ):
        """Resolve *host*/*port* with ``socket.getaddrinfo``.

        NOTE(review): the lookup runs synchronously inside the coroutine, so
        it blocks the loop while resolving -- confirm this is acceptable.
        """
        return socket.getaddrinfo(host, port, family, type, proto, flags)

    def _make_socket_transport(
        self, sock, protocol, waiter=None, *, extra=None, server=None
    ):
        """Return a :class:`KLoopSocketTransport`; *sock* is set to blocking."""
        sock.setblocking(True)
        return KLoopSocketTransport(
            self, sock, protocol, waiter, extra, server
        )

    def _make_ssl_transport(
        self,
        rawsock,
        protocol,
        sslcontext,
        waiter=None,
        *,
        server_side=False,
        server_hostname=None,
        extra=None,
        server=None,
        ssl_handshake_timeout=None,
        call_connection_made=True,
    ):
        """Return a :class:`KLoopSSLTransport` for a client connection.

        A default ``SERVER_AUTH`` context is created when *sslcontext* is
        ``None``.  *server_side*, *ssl_handshake_timeout* and
        *call_connection_made* are accepted but not used.
        """
        if sslcontext is None:
            sslcontext = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        return KLoopSSLTransport(
            self,
            rawsock,
            protocol,
            waiter,
            extra,
            server,
            sslcontext=sslcontext,
            server_hostname=server_hostname,
        )


class KLoopPolicy(asyncio.events.BaseDefaultEventLoopPolicy):
    """Event loop policy that creates :class:`KLoop` instances."""

    __slots__ = ("_selector_args",)

    def __init__(
        self, queue_depth=128, sq_thread_idle=2000, sq_thread_cpu=None
    ):
        """Store the ring arguments passed to every new loop.

        Raises:
            ValueError: if *queue_depth* is not a power of two in 1..4096.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if queue_depth not in _VALID_QUEUE_DEPTHS:
            raise ValueError(
                "queue_depth must be a power of two between 1 and 4096, "
                f"got {queue_depth!r}"
            )
        self._selector_args = (queue_depth, sq_thread_idle, sq_thread_cpu)

    def _loop_factory(self):
        """Create a new :class:`KLoop` with the stored ring arguments."""
        return KLoop(self._selector_args)

    # Child processes handling (Unix only).

    def get_child_watcher(self):
        """Child watchers are not supported."""
        raise NotImplementedError

    def set_child_watcher(self, watcher):
        """Child watchers are not supported."""
        raise NotImplementedError