From 97f02d40bc9a92f81e84099f8a5f33745198b372 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sun, 20 Mar 2022 18:57:16 -0400 Subject: [PATCH] PoC uring and ktls --- MANIFEST.in | 2 +- pyproject.toml | 4 + setup.py | 6 +- src/kloop/__init__.py | 1 + src/kloop/includes/__init__.py | 9 + src/kloop/includes/barrier.h | 20 ++ src/kloop/includes/barrier.pxd | 15 ++ src/kloop/includes/libc.pxd | 42 ++++ src/kloop/includes/linux.pxd | 128 ++++++++++ src/kloop/includes/ssl.h | 20 ++ src/kloop/includes/ssl.pxd | 29 +++ src/kloop/ktls.pyx | 153 ++++++++++++ src/kloop/loop.py | 412 +++++++++++++++++++++++++++++++++ src/kloop/uring.pxd | 105 +++++++++ src/kloop/uring.pyx | 410 ++++++++++++++++++++++++++++++++ tests/test_loop.py | 50 +++- 16 files changed, 1400 insertions(+), 6 deletions(-) create mode 100644 src/kloop/includes/__init__.py create mode 100644 src/kloop/includes/barrier.h create mode 100644 src/kloop/includes/barrier.pxd create mode 100644 src/kloop/includes/libc.pxd create mode 100644 src/kloop/includes/linux.pxd create mode 100644 src/kloop/includes/ssl.h create mode 100644 src/kloop/includes/ssl.pxd create mode 100644 src/kloop/loop.py diff --git a/MANIFEST.in b/MANIFEST.in index 865a748..89a1f77 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ include Makefile -recursive-include src *.pyx *.pxd +recursive-include src *.pyx *.pxd *.h graft tests global-exclude *.py[cod] *.c diff --git a/pyproject.toml b/pyproject.toml index aea2aa7..dc8925d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[tool.black] +line-length = 79 +target-version = ["py310"] + [build-system] requires = ["setuptools>=42", "Cython>=0.29"] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 611b2b9..94143bb 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,11 @@ setup( ext_modules=cythonize( [ Extension("kloop.uring", ["src/kloop/uring.pyx"]), - Extension("kloop.ktls", ["src/kloop/ktls.pyx"]), + Extension( + "kloop.ktls", + ["src/kloop/ktls.pyx"], + libraries=["ssl", "crypto"], + ), ], language_level="3", ) diff --git a/src/kloop/__init__.py b/src/kloop/__init__.py index bf73ec7..c385634 100644 --- a/src/kloop/__init__.py +++ b/src/kloop/__init__.py @@ -8,3 +8,4 @@ # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. # See the Mulan PSL v2 for more details. +from .loop import KLoop, KLoopPolicy diff --git a/src/kloop/includes/__init__.py b/src/kloop/includes/__init__.py new file mode 100644 index 0000000..e0326fc --- /dev/null +++ b/src/kloop/includes/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. diff --git a/src/kloop/includes/barrier.h b/src/kloop/includes/barrier.h new file mode 100644 index 0000000..cb4ec0a --- /dev/null +++ b/src/kloop/includes/barrier.h @@ -0,0 +1,20 @@ +/* Copied from liburing: src/include/liburing/barrier.h */ + +#include + +#define IO_URING_WRITE_ONCE(var, val) \ + atomic_store_explicit((_Atomic __typeof__(var) *)&(var), \ + (val), memory_order_relaxed) +#define IO_URING_READ_ONCE(var) \ + atomic_load_explicit((_Atomic __typeof__(var) *)&(var), \ + memory_order_relaxed) + +#define io_uring_smp_store_release(p, v) \ + atomic_store_explicit((_Atomic __typeof__(*(p)) *)(p), (v), \ + memory_order_release) +#define io_uring_smp_load_acquire(p) \ + atomic_load_explicit((_Atomic __typeof__(*(p)) *)(p), \ + memory_order_acquire) + +#define io_uring_smp_mb() \ + atomic_thread_fence(memory_order_seq_cst) diff --git a/src/kloop/includes/barrier.pxd b/src/kloop/includes/barrier.pxd new file mode 100644 index 0000000..f06fc13 --- /dev/null +++ b/src/kloop/includes/barrier.pxd @@ -0,0 +1,15 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +cdef extern from "includes/barrier.h" nogil: + unsigned IO_URING_READ_ONCE(unsigned var) + void io_uring_smp_store_release(void* p, unsigned v) + unsigned int io_uring_smp_load_acquire(void* p) + void io_uring_smp_mb() diff --git a/src/kloop/includes/libc.pxd b/src/kloop/includes/libc.pxd new file mode 100644 index 0000000..cdd0698 --- /dev/null +++ b/src/kloop/includes/libc.pxd @@ -0,0 +1,42 @@ +cdef extern from "sys/syscall.h" nogil: + int SYS_io_uring_setup + int SYS_io_uring_enter + int SYS_io_uring_register + +cdef extern from "unistd.h" nogil: + int syscall(int number, ...) + +cdef extern from "signal.h" nogil: + int _NSIG + +cdef extern from "sys/socket.h" nogil: + ctypedef int socklen_t + int SOL_TLS + + int setsockopt(int socket, int level, int option_name, + const void *option_value, socklen_t option_len); + + struct in_addr: + pass + + struct sockaddr_in: + int sin_family + int sin_port + in_addr sin_addr + + struct msghdr: + iovec* msg_iov # Scatter/gather array + size_t msg_iovlen # Number of elements in msg_iov + void* msg_control # ancillary data, see below + size_t msg_controllen # ancillary data buffer len + int msg_flags # flags on received message + + +cdef extern from "arpa/inet.h" nogil: + int inet_pton(int af, char* src, void* dst) + int htons(short p) + +cdef extern from "sys/uio.h" nogil: + struct iovec: + void* iov_base + size_t iov_len diff --git a/src/kloop/includes/linux.pxd b/src/kloop/includes/linux.pxd new file mode 100644 index 0000000..4c9f821 --- /dev/null +++ b/src/kloop/includes/linux.pxd @@ -0,0 +1,128 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +cdef extern from "linux/fs.h" nogil: + ctypedef int __kernel_rwf_t + + +cdef extern from "linux/types.h" nogil: + ctypedef int __u8 + ctypedef int __u16 + ctypedef int __u64 + ctypedef int __u32 + ctypedef int __s32 + ctypedef int __kernel_time64_t + + +cdef extern from "linux/time_types.h" nogil: + struct __kernel_timespec: + __kernel_time64_t tv_sec + long long tv_nsec + + +cdef extern from "linux/tcp.h" nogil: + int TCP_ULP + + +cdef extern from "linux/tls.h" nogil: + __u16 TLS_CIPHER_AES_GCM_256 + int TLS_CIPHER_AES_GCM_256_IV_SIZE + int TLS_CIPHER_AES_GCM_256_SALT_SIZE + int TLS_CIPHER_AES_GCM_256_KEY_SIZE + int TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE + int TLS_TX + int TLS_RX + + struct tls_crypto_info: + __u16 version + __u16 cipher_type + + struct tls12_crypto_info_aes_gcm_256: + tls_crypto_info info + unsigned char* iv + unsigned char* key + unsigned char* salt + unsigned char* rec_seq + + +cdef extern from "linux/io_uring.h" nogil: + unsigned IORING_SETUP_SQPOLL + unsigned IORING_SETUP_SQ_AFF + + unsigned IORING_ENTER_GETEVENTS + unsigned IORING_ENTER_SQ_WAKEUP + unsigned IORING_ENTER_EXT_ARG + + unsigned IORING_SQ_NEED_WAKEUP + unsigned IORING_SQ_CQ_OVERFLOW + + unsigned long long IORING_OFF_SQ_RING + unsigned long long IORING_OFF_SQES + + unsigned IORING_TIMEOUT_ABS + + unsigned IOSQE_IO_LINK + + enum Operation: + IORING_OP_NOP + IORING_OP_CONNECT + IORING_OP_SEND + IORING_OP_SENDMSG + IORING_OP_RECV + IORING_OP_RECVMSG + + struct io_sqring_offsets: + __u32 head + __u32 tail + __u32 ring_mask + __u32 ring_entries + __u32 flags + __u32 dropped + __u32 array + + struct io_cqring_offsets: + __u32 head + __u32 tail + __u32 ring_mask + __u32 ring_entries + __u32 overflow + __u32 cqes + __u32 flags + + struct io_uring_params: + __u32 flags + __u32 sq_thread_cpu + __u32 sq_thread_idle + + # written by the kernel: + __u32 sq_entries + __u32 cq_entries + __u32 features + __u32 resv[4] + io_sqring_offsets sq_off + io_cqring_offsets cq_off + + struct io_uring_sqe: + __u8 opcode # type of operation for this sqe + __s32 fd # file descriptor to do IO on + __u64 off # offset into file + __u64 addr # pointer to buffer or iovecs + __u32 len # buffer size or number of iovecs + __u64 user_data # data to be passed back at completion time + __u8 flags # IOSQE_ flags + + struct io_uring_cqe: + __u64 user_data # data to be passed back at completion time + __s32 res # result code for this event + + struct io_uring_getevents_arg: + __u64 sigmask + __u32 sigmask_sz + __u64 ts diff --git a/src/kloop/includes/ssl.h b/src/kloop/includes/ssl.h new file mode 100644 index 0000000..fc98f7a --- /dev/null +++ b/src/kloop/includes/ssl.h @@ -0,0 +1,20 @@ +/* +Copyright (c) 2022 Fantix King http://fantix.pro +kLoop is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. +*/ + +#include "Python.h" +#include "openssl/ssl.h" + +typedef struct { + PyObject_HEAD + PyObject *Socket; /* weakref to socket on which we're layered */ + SSL *ssl; +} PySSLSocket; diff --git a/src/kloop/includes/ssl.pxd b/src/kloop/includes/ssl.pxd new file mode 100644 index 0000000..bd77c26 --- /dev/null +++ b/src/kloop/includes/ssl.pxd @@ -0,0 +1,29 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +cdef extern from "openssl/ssl.h" nogil: + int EVP_GCM_TLS_FIXED_IV_LEN + + ctypedef struct SSL: + pass + + ctypedef struct SSL_CTX: + pass + + int SSL_version(const SSL *s) + ctypedef void(*SSL_CTX_keylog_cb_func)(SSL *ssl, char *line) + void SSL_CTX_set_keylog_callback(SSL_CTX* ctx, SSL_CTX_keylog_cb_func cb) + SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx) + SSL_CTX* SSL_get_SSL_CTX(SSL* ssl) + + +cdef extern from "includes/ssl.h" nogil: + ctypedef struct PySSLSocket: + SSL *ssl diff --git a/src/kloop/ktls.pyx b/src/kloop/ktls.pyx index bf73ec7..f978037 100644 --- a/src/kloop/ktls.pyx +++ b/src/kloop/ktls.pyx @@ -8,3 +8,156 @@ # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. # See the Mulan PSL v2 for more details. +import socket +import hmac +import hashlib +from ssl import SSLWantReadError + +from cpython cimport PyErr_SetFromErrno +from libc cimport string + +from .includes cimport libc, linux, ssl + + +cdef ssl.SSL_CTX_keylog_cb_func orig_cb +cdef secrets = {} + + +cdef void _capture_secrets(const ssl.SSL* s, const char* line) nogil: + if line != NULL: + try: + with gil: + global secrets + parts = line.decode("ISO-8859-1").split() + secrets[parts[0]] = bytes.fromhex(parts[-1]) + finally: + if orig_cb != NULL: + orig_cb(s, line) + + +def do_handshake_capturing_secrets(sslobj): + cdef: + ssl.SSL* s = ( sslobj._sslobj).ssl + ssl.SSL_CTX* ctx = ssl.SSL_get_SSL_CTX(s) + global orig_cb + orig_cb = ssl.SSL_CTX_get_keylog_callback(ctx) + ssl.SSL_CTX_set_keylog_callback( + ctx, _capture_secrets + ) + try: + try: + sslobj.do_handshake() + except SSLWantReadError: + success = False + else: + success = True + if secrets: + rv = dict(secrets) + secrets.clear() + else: + rv = {} + return success, rv + finally: + ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb) + + +def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384): + ''' + Expand `pseudo_random_key` and `info` into a key of length `bytes` using + HKDF's expand function based on HMAC with the provided hash (default + SHA-512). See the HKDF draft RFC and paper for usage notes. + ''' + # info_in = info + # info = b'\0' + struct.pack("H", len(info)) + info + b'\0' + # print(f'hkdf_expand info_in= label={info.hex()}') + hash_len = hash().digest_size + length = int(length) + if length > 255 * hash_len: + raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \ + (hash_len, 255 * hash_len)) + blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil + okm = b"" + output_block = b"" + for counter in range(blocks_needed): + output_block = hmac.new( + pseudo_random_key, + (output_block + info + bytearray((counter + 1,))), + hash, + ).digest() + okm += output_block + return okm[:length] + + +def enable_ulp(sock): + cdef char *tls = b"tls" + if libc.setsockopt(sock.fileno(), socket.SOL_TCP, linux.TCP_ULP, tls, 4): + PyErr_SetFromErrno(IOError) + return + + +def upgrade_aes_gcm_256(sslobj, sock, secret, sending): + cdef: + ssl.SSL* s = (sslobj._sslobj).ssl + linux.tls12_crypto_info_aes_gcm_256 crypto_info + char* seq + + if sending: + # s->rlayer->write_sequence + seq = ((s) + 6112) + else: + # s->rlayer->read_sequence + seq = ((s) + 6104) + + # print(sslobj.cipher()) + + string.memset(&crypto_info, 0, sizeof(crypto_info)) + crypto_info.info.cipher_type = linux.TLS_CIPHER_AES_GCM_256 + crypto_info.info.version = ssl.SSL_version(s) + + key = hkdf_expand( + secret, + b'\x00 \ttls13 key\x00', + linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE, + ) + string.memcpy( + crypto_info.key, + key, + linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE, + ) + string.memcpy( + crypto_info.rec_seq, + seq, + linux.TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE, + ) + iv = hkdf_expand( + secret, + b'\x00\x0c\x08tls13 iv\x00', + linux.TLS_CIPHER_AES_GCM_256_IV_SIZE + + linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE, + ) + string.memcpy( + crypto_info.iv, + iv+ ssl.EVP_GCM_TLS_FIXED_IV_LEN, + linux.TLS_CIPHER_AES_GCM_256_IV_SIZE, + ) + string.memcpy( + crypto_info.salt, + iv, + linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE, + ) + if libc.setsockopt( + sock.fileno(), + libc.SOL_TLS, + linux.TLS_TX if sending else linux.TLS_RX, + &crypto_info, + sizeof(crypto_info), + ): + PyErr_SetFromErrno(IOError) + return + # print( + # sending, + # "iv", crypto_info.iv[:linux.TLS_CIPHER_AES_GCM_256_IV_SIZE].hex(), + # "key", crypto_info.key[:linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE].hex(), + # "salt", crypto_info.salt[:linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE].hex(), + # "rec_seq", crypto_info.rec_seq[:linux.TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE].hex(), + # ) diff --git a/src/kloop/loop.py b/src/kloop/loop.py new file mode 100644 index 0000000..8e72fe6 --- /dev/null +++ b/src/kloop/loop.py @@ -0,0 +1,412 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import asyncio.events +import asyncio.futures +import asyncio.trsock +import asyncio.transports +import contextvars +import socket +import ssl + +from . import uring, ktls + + +class Callback: + __slots__ = ("_callback", "_context", "_args", "_kwargs") + + def __init__(self, callback, context=None, args=None, kwargs=None): + if context is None: + context = contextvars.copy_context() + self._callback = callback + self._context = context + self._args = args or () + self._kwargs = kwargs or {} + + def __call__(self): + self._context.run(self._callback, *self._args, **self._kwargs) + + def __repr__(self): + return f"{self._callback} {self._args} {self._kwargs} {self._context}" + + +class KLoopSocketTransport( + asyncio.transports._FlowControlMixin, asyncio.Transport +): + __slots__ = ( + "_waiter", + "_sock", + "_protocol", + "_closing", + "_recv_buffer", + "_recv_buffer_factory", + "_read_ready_cb", + "_buffers", + "_buffer_size", + "_current_work", + "_write_waiter", + "_read_paused", + ) + + def __init__( + self, loop, sock, protocol, waiter=None, extra=None, server=None + ): + super().__init__(extra, loop) + self._extra["socket"] = asyncio.trsock.TransportSocket(sock) + try: + self._extra["sockname"] = sock.getsockname() + except OSError: + self._extra["sockname"] = None + if "peername" not in self._extra: + try: + self._extra["peername"] = sock.getpeername() + except socket.error: + self._extra["peername"] = None + + self._buffers = [] + self._buffer_size = 0 + self._current_work = None + self._sock = sock + self._closing = False + self._write_waiter = None + self._read_paused = False + + self.set_protocol(protocol) + self._waiter = waiter + + self._loop.call_soon(self._protocol.connection_made, self) + self._loop.call_soon(self._read) + + if self._waiter is not None: + self._loop.call_soon( + asyncio.futures._set_result_unless_cancelled, + self._waiter, + None, + ) + + def set_protocol(self, protocol): + if isinstance(protocol, asyncio.BufferedProtocol): + self._read_ready_cb = self._read_ready__buffer_updated + self._recv_buffer = None + self._recv_buffer_factory = protocol.get_buffer + else: + self._read_ready_cb = self._read_ready__data_received + self._recv_buffer = bytearray(256 * 1024) + self._recv_buffer_factory = lambda _hint: self._recv_buffer + self._protocol = protocol + + def _read(self): + self._loop._selector.submit( + uring.RecvMsgWork( + self._sock.fileno(), + [self._recv_buffer_factory(-1)], + self._read_ready_cb, + ) + ) + + def _read_ready__buffer_updated(self, res): + if res < 0: + raise IOError + elif res == 0: + self._protocol.eof_received() + else: + try: + # print(f"buffer updated: {res}") + self._protocol.buffer_updated(res) + finally: + if not self._closing: + self._read() + + def _read_ready__data_received(self, res): + if res < 0: + raise IOError(f"{res}") + elif res == 0: + self._protocol.eof_received() + else: + try: + # print(f"data received: {res}") + self._protocol.data_received(self._recv_buffer[:res]) + finally: + if not self._closing: + self._read() + + def _write_done(self, res): + self._current_work = None + if res < 0: + # TODO: force close transport + raise IOError() + self._buffer_size -= res + if self._buffers: + if len(self._buffers) == 1: + self._current_work = uring.SendWork( + self._sock.fileno(), self._buffers[0], self._write_done + ) + else: + self._current_work = uring.SendMsgWork( + self._sock.fileno(), self._buffers, self._write_done + ) + self._loop._selector.submit(self._current_work) + self._buffers = [] + elif self._closing: + self._loop.call_soon(self._call_connection_lost, None) + elif self._write_waiter is not None: + self._write_waiter() + self._write_waiter = None + self._maybe_resume_protocol() + + def write(self, data): + self._buffer_size += len(data) + if self._current_work is None: + self._current_work = uring.SendWork( + self._sock.fileno(), data, self._write_done + ) + self._loop._selector.submit(self._current_work) + else: + self._buffers.append(data) + self._maybe_pause_protocol() + + def close(self): + if self._closing: + return + self._closing = True + if self._current_work is None: + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + if self._protocol is not None: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + + def get_write_buffer_size(self): + return self._buffer_size + + def pause_reading(self): + self._read_paused = True + + def resume_reading(self): + if self._read_paused: + self._read_paused = False + self._read() + + +class KLoopSSLHandshakeProtocol(asyncio.Protocol): + __slots__ = ( + "_incoming", + "_outgoing", + "_handshaking", + "_secrets", + "_app_protocol", + "_transport", + "_sslobj", + ) + + def __init__(self, sslcontext, server_hostname): + self._handshaking = True + self._secrets = {} + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = sslcontext.wrap_bio( + self._incoming, + self._outgoing, + server_side=False, + server_hostname=server_hostname, + ) + + def connection_made(self, transport): + self._transport = transport + self._handshake() + + def data_received(self, data): + self._incoming.write(data) + self._handshake() + + def _handshake(self): + success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj) + self._secrets.update(secrets) + if success: + if self._handshaking: + self._handshaking = False + if data := self._outgoing.read(): + self._transport.write(data) + self._transport._write_waiter = self._after_last_write + self._transport.pause_reading() + else: + self._after_last_write() + # else: + # try: + # data = self._sslobj.read(16384) + # except ssl.SSLWantReadError: + # data = None + # self._transport._upgrade_ktls_read( + # self._sslobj, + # self._secrets["SERVER_TRAFFIC_SECRET_0"], + # data, + # ) + else: + if data := self._outgoing.read(): + self._transport.write(data) + + def _after_last_write(self): + try: + data = self._sslobj.read(16384) + except ssl.SSLWantReadError: + data = None + self._transport._upgrade_ktls_write( + self._sslobj, + self._secrets["CLIENT_TRAFFIC_SECRET_0"], + ) + self._transport._upgrade_ktls_read( + self._sslobj, + self._secrets["SERVER_TRAFFIC_SECRET_0"], + data, + ) + self._transport.resume_reading() + + +class KLoopSSLTransport(KLoopSocketTransport): + __slots__ = ("_app_protocol",) + + def __init__( + self, + loop, + sock, + protocol, + waiter=None, + extra=None, + server=None, + *, + sslcontext, + server_hostname, + ): + ktls.enable_ulp(sock) + self._app_protocol = protocol + super().__init__( + loop, + sock, + KLoopSSLHandshakeProtocol(sslcontext, server_hostname), + None, + extra, + server, + ) + self._waiter = waiter + + def _upgrade_ktls_write(self, sslobj, secret): + ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True) + self._loop.call_soon(self._app_protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon( + asyncio.futures._set_result_unless_cancelled, + self._waiter, + None, + ) + + def _upgrade_ktls_read(self, sslobj, secret, data): + ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False) + self.set_protocol(self._app_protocol) + if data is not None: + if data: + self._app_protocol.data_received(data) + else: + self._app_protocol.eof_received() + + +class KLoop(asyncio.BaseEventLoop): + def __init__(self, args): + super().__init__() + self._selector = uring.Ring(*args) + + def _process_events(self, works): + for work in works: + work.complete() + + async def sock_connect(self, sock, address): + fut = self.create_future() + self._selector.submit(uring.ConnectWork(sock.fileno(), address, fut)) + return await fut + + async def getaddrinfo( + self, host, port, *, family=0, type=0, proto=0, flags=0 + ): + return socket.getaddrinfo(host, port, family, type, proto, flags) + + def _make_socket_transport( + self, sock, protocol, waiter=None, *, extra=None, server=None + ): + return KLoopSocketTransport( + self, sock, protocol, waiter, extra, server + ) + + def _make_ssl_transport( + self, + rawsock, + protocol, + sslcontext, + waiter=None, + *, + server_side=False, + server_hostname=None, + extra=None, + server=None, + ssl_handshake_timeout=None, + call_connection_made=True, + ): + if sslcontext is None: + sslcontext = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + return KLoopSSLTransport( + self, + rawsock, + protocol, + waiter, + extra, + server, + sslcontext=sslcontext, + server_hostname=server_hostname, + ) + + +class KLoopPolicy(asyncio.events.BaseDefaultEventLoopPolicy): + __slots__ = ("_selector_args",) + + def __init__( + self, queue_depth=128, sq_thread_idle=2000, sq_thread_cpu=None + ): + super().__init__() + assert queue_depth in { + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + } + self._selector_args = (queue_depth, sq_thread_idle, sq_thread_cpu) + + def _loop_factory(self): + return KLoop(self._selector_args) + + # Child processes handling (Unix only). + + def get_child_watcher(self): + raise NotImplementedError + + def set_child_watcher(self, watcher): + raise NotImplementedError diff --git a/src/kloop/uring.pxd b/src/kloop/uring.pxd index bf73ec7..0cd9bd6 100644 --- a/src/kloop/uring.pxd +++ b/src/kloop/uring.pxd @@ -8,3 +8,108 @@ # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. # See the Mulan PSL v2 for more details. +from .includes cimport linux, libc + + +cdef class RingQueue: + cdef: + unsigned* head + unsigned* tail + unsigned* ring_mask + unsigned* ring_entries + unsigned* flags + + size_t ring_size + void* ring_ptr + + +cdef class SubmissionQueue(RingQueue): + cdef: + unsigned* dropped + unsigned* array + linux.io_uring_sqe* sqes + unsigned sqe_head + unsigned sqe_tail + + cdef init(self, linux.io_sqring_offsets sq_off) + cdef linux.io_uring_sqe * next_sqe(self) + cdef unsigned flush(self) + + +cdef class CompletionQueue(RingQueue): + cdef: + unsigned* overflow + linux.io_uring_cqe* cqes + + cdef init(self, linux.io_cqring_offsets cq_off) + cdef unsigned ready(self) + cdef inline object pop_works(self, unsigned ready) + + +cdef class Ring: + cdef: + SubmissionQueue sq + CompletionQueue cq + unsigned features + int fd + int enter_fd + + +cdef class Work: + cdef: + readonly object fut + public bint link + int res + + cdef void submit(self, linux.io_uring_sqe* sqe) + + cdef inline void _submit( + self, + int op, + linux.io_uring_sqe * sqe, + int fd, + void * addr, + unsigned len, + linux.__u64 offset, + ) + + +cdef class ConnectWork(Work): + cdef: + int fd + libc.sockaddr_in addr + object host_bytes + + +cdef class SendWork(Work): + cdef: + int fd + object data + char* data_ptr + linux.__u32 size + object callback + + +cdef class SendMsgWork(Work): + cdef: + int fd + list buffers + libc.msghdr msg + object callback + + +cdef class RecvWork(Work): + cdef: + int fd + object buffer + object callback + char* buffer_ptr + + +cdef class RecvMsgWork(Work): + cdef: + int fd + list buffers + libc.msghdr msg + object callback + object control_msg diff --git a/src/kloop/uring.pyx b/src/kloop/uring.pyx index bf73ec7..e4e1e55 100644 --- a/src/kloop/uring.pyx +++ b/src/kloop/uring.pyx @@ -8,3 +8,413 @@ # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. # See the Mulan PSL v2 for more details. +import os +import socket + +from cpython cimport Py_INCREF, Py_DECREF, PyErr_SetFromErrno +from cpython cimport PyMem_RawMalloc, PyMem_RawFree +from libc cimport errno, string +from posix cimport mman + +from .includes cimport barrier, libc, linux + +cdef linux.__u32 SIG_SIZE = libc._NSIG // 8 + + +class SubmissionQueueFull(Exception): + pass + + +cdef class RingQueue: + def __cinit__(self, size_t ring_size): + self.ring_size = ring_size + + +cdef class SubmissionQueue(RingQueue): + cdef init(self, linux.io_sqring_offsets sq_off): + self.head = (self.ring_ptr + sq_off.head) + self.tail = (self.ring_ptr + sq_off.tail) + self.ring_mask = (self.ring_ptr + sq_off.ring_mask) + self.ring_entries = (self.ring_ptr + sq_off.ring_entries) + self.flags = (self.ring_ptr + sq_off.flags) + self.dropped = (self.ring_ptr + sq_off.dropped) + self.array = (self.ring_ptr + sq_off.array) + + cdef linux.io_uring_sqe* next_sqe(self): + cdef: + unsigned int head, next + linux.io_uring_sqe* rv + head = barrier.io_uring_smp_load_acquire(self.head) + next = self.sqe_tail + 1 + if next - head <= self.ring_entries[0]: + rv = &self.sqes[self.sqe_tail & self.ring_mask[0]] + self.sqe_tail = next + return rv + else: + # TODO: IORING_ENTER_SQ_WAIT and retry + raise SubmissionQueueFull() + + cdef unsigned flush(self): + cdef: + unsigned mask = self.ring_mask[0] + unsigned tail = self.tail[0] + unsigned to_submit = self.sqe_tail - self.sqe_head + + if to_submit: + while to_submit: + self.array[tail & mask] = self.sqe_head & mask + tail += 1 + self.sqe_head += 1 + to_submit -= 1 + barrier.io_uring_smp_store_release(self.tail, tail) + return tail - self.head[0] + + +cdef class CompletionQueue(RingQueue): + cdef init(self, linux.io_cqring_offsets cq_off): + self.head = (self.ring_ptr + cq_off.head) + self.tail = (self.ring_ptr + cq_off.tail) + self.ring_mask = (self.ring_ptr + cq_off.ring_mask) + self.ring_entries = (self.ring_ptr + cq_off.ring_entries) + self.overflow = (self.ring_ptr + cq_off.overflow) + self.cqes = (self.ring_ptr + cq_off.cqes) + if cq_off.flags: + self.flags = (self.ring_ptr + cq_off.flags) + + cdef unsigned ready(self): + return barrier.io_uring_smp_load_acquire(self.tail) - self.head[0] + + cdef inline object pop_works(self, unsigned ready): + cdef: + object rv = [] + Work work + unsigned head, mask, last + linux.io_uring_cqe* cqe + head = self.head[0] + mask = self.ring_mask[0] + last = head + ready + while head != last: + cqe = self.cqes + (head & mask) + work = cqe.user_data + work.res = cqe.res + rv.append(work) + Py_DECREF(work) + head += 1 + barrier.io_uring_smp_store_release(self.head, self.head[0] + ready) + return rv + + +cdef class Ring: + def __cinit__( + self, + linux.__u32 queue_depth, + linux.__u32 sq_thread_idle, + object sq_thread_cpu, + ): + cdef: + linux.io_uring_params params + int fd + size_t size + void* ptr + + # Prepare io_uring_params + string.memset(¶ms, 0, sizeof(params)) + params.flags = linux.IORING_SETUP_SQPOLL + if sq_thread_cpu is not None: + params.flags |= linux.IORING_SETUP_SQ_AFF + params.sq_thread_cpu = sq_thread_cpu + params.sq_thread_idle = sq_thread_idle + + # SYSCALL: SYS_io_uring_setup + fd = libc.syscall(libc.SYS_io_uring_setup, queue_depth, ¶ms) + if fd < 0: + PyErr_SetFromErrno(IOError) + return + self.fd = self.enter_fd = fd + + # Initialize 2 RingQueue and mmap the ring_ptr + size = max( + params.sq_off.array + params.sq_entries * sizeof(unsigned), + params.cq_off.cqes + params.cq_entries * sizeof(linux.io_uring_cqe) + ) + self.sq = SubmissionQueue(size) + self.cq = CompletionQueue(size) + ptr = mman.mmap( + NULL, + size, + mman.PROT_READ | mman.PROT_WRITE, + mman.MAP_SHARED | mman.MAP_POPULATE, + fd, + linux.IORING_OFF_SQ_RING, + ) + if ptr == mman.MAP_FAILED: + PyErr_SetFromErrno(IOError) + return + self.sq.ring_ptr = self.cq.ring_ptr = ptr + + # Initialize the SubmissionQueue + self.sq.init(params.sq_off) + size = params.sq_entries * sizeof(linux.io_uring_sqe) + ptr = mman.mmap( + NULL, + size, + mman.PROT_READ | mman.PROT_WRITE, + mman.MAP_SHARED | mman.MAP_POPULATE, + fd, + linux.IORING_OFF_SQES, + ) + if ptr == mman.MAP_FAILED: + mman.munmap(self.sq.ring_ptr, self.sq.ring_size) + PyErr_SetFromErrno(IOError) + return + self.sq.sqes = ptr + + # Initialize the CompletionQueue + self.cq.init(params.cq_off) + + self.features = params.features + + def __dealloc__(self): + if self.sq is not None: + if self.sq.sqes != NULL: + mman.munmap( + self.sq.sqes, self.sq.ring_entries[0] * sizeof(linux.io_uring_sqe) + ) + if self.sq.ring_ptr != NULL: + mman.munmap(self.sq.ring_ptr, self.sq.ring_size) + if self.fd: + os.close(self.fd) + + def submit(self, Work work): + cdef linux.io_uring_sqe* sqe = self.sq.next_sqe() + work.submit(sqe) + + def select(self, timeout): + cdef: + int flags = linux.IORING_ENTER_EXT_ARG, ret + bint need_enter = False + unsigned submit, ready + unsigned wait_nr = 0 + linux.io_uring_getevents_arg arg + linux.__kernel_timespec ts + + # Call enter if we have no CQE ready and timeout is not 0, or else we + # handle the ready CQEs first. + ready = self.cq.ready() + if not ready and timeout is not 0: + flags |= linux.IORING_ENTER_GETEVENTS + if timeout is not None: + ts.tv_sec = int(timeout) + ts.tv_nsec = int(round((timeout - ts.tv_sec) * 1_000_000_000)) + arg.ts = &ts + wait_nr = 1 + need_enter = True + + # Flush the submission queue, and only wakeup the SQ polling thread if + # there is something for the kernel to handle. + submit = self.sq.flush() + if submit: + barrier.io_uring_smp_mb() + if barrier.IO_URING_READ_ONCE( + self.sq.flags[0] + ) & linux.IORING_SQ_NEED_WAKEUP: + arg.ts = 0 + flags |= linux.IORING_ENTER_SQ_WAKEUP + need_enter = True + + if need_enter: + arg.sigmask = 0 + arg.sigmask_sz = SIG_SIZE + print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, " + f"flags={flags:b}, timeout={timeout})") + ret = libc.syscall( + libc.SYS_io_uring_enter, + self.enter_fd, + submit, + wait_nr, + flags, + &arg, + sizeof(arg), + ) + if ret < 0: + if errno.errno != errno.ETIME: + PyErr_SetFromErrno(IOError) + return + + ready = self.cq.ready() + + if ready: + return self.cq.pop_works(ready) + else: + return [] + + +cdef class Work: + def __init__(self, fut): + self.fut = fut + self.link = False + self.res = -1 + + cdef void submit(self, linux.io_uring_sqe* sqe): + raise NotImplementedError + + cdef inline void _submit( + self, + int op, + linux.io_uring_sqe * sqe, + int fd, + void * addr, + unsigned len, + linux.__u64 offset, + ): + string.memset(sqe, 0, sizeof(linux.io_uring_sqe)) + sqe.opcode = op + sqe.fd = fd + sqe.off = offset + sqe.addr = addr + sqe.len = len + if self.link: + sqe.flags = linux.IOSQE_IO_LINK + else: + sqe.flags = 0 + sqe.user_data = self + Py_INCREF(self) + + def complete(self): + if self.res == 0: + self.fut.set_result(None) + else: + def _raise(): + errno.errno = abs(self.res) + PyErr_SetFromErrno(IOError) + try: + _raise() + except IOError as ex: + self.fut.set_exception(ex) + + +cdef class ConnectWork(Work): + def __init__(self, int fd, sockaddr, fut): + cdef char* host + super().__init__(fut) + self.fd = fd + host_str, port = sockaddr + self.host_bytes = host_str.encode() + host = self.host_bytes + string.memset(&self.addr, 0, sizeof(self.addr)) + self.addr.sin_family = socket.AF_INET + if not libc.inet_pton(socket.AF_INET, host, &self.addr.sin_addr): + PyErr_SetFromErrno(IOError) + return + self.addr.sin_port = libc.htons(port) + + cdef void submit(self, linux.io_uring_sqe* sqe): + self._submit( + linux.IORING_OP_CONNECT, + sqe, + self.fd, + &self.addr, + 0, + sizeof(self.addr), + ) + + +cdef class SendWork(Work): + def __init__(self, int fd, data, callback): + self.fd = fd + self.data = data + self.data_ptr = data + self.size = len(data) + self.callback = callback + + cdef void submit(self, linux.io_uring_sqe* sqe): + self._submit(linux.IORING_OP_SEND, sqe, self.fd, self.data_ptr, self.size, 0) + + def complete(self): + self.callback(self.res) + + +cdef class SendMsgWork(Work): + def __init__(self, int fd, buffers, callback): + self.fd = fd + self.buffers = buffers + self.callback = callback + self.msg.msg_iov = PyMem_RawMalloc( + sizeof(libc.iovec) * len(buffers) + ) + if self.msg.msg_iov == NULL: + raise MemoryError + self.msg.msg_iovlen = len(buffers) + for i, buf in enumerate(buffers): + self.msg.msg_iov[i].iov_base = buf + self.msg.msg_iov[i].iov_len = len(buf) + + def __dealloc__(self): + if self.msg.msg_iov != NULL: + PyMem_RawFree(self.msg.msg_iov) + + cdef void submit(self, linux.io_uring_sqe* sqe): + self._submit(linux.IORING_OP_SENDMSG, sqe, self.fd, &self.msg, 1, 0) + + def complete(self): + if self.res < 0: + errno.errno = abs(self.res) + PyErr_SetFromErrno(IOError) + return + self.callback(self.res) + + +cdef class RecvWork(Work): + def __init__(self, int fd, buffer, callback): + self.fd = fd + self.buffer = buffer + self.callback = callback + self.buffer_ptr = buffer + + cdef void submit(self, linux.io_uring_sqe* sqe): + self._submit( + linux.IORING_OP_RECV, sqe, self.fd, self.buffer_ptr, len(self.buffer), 0 + ) + + def complete(self): + if self.res < 0: + errno.errno = abs(self.res) + PyErr_SetFromErrno(IOError) + return + self.callback(self.res) + + +cdef class RecvMsgWork(Work): + def __init__(self, int fd, buffers, callback): + self.fd = fd + self.buffers = buffers + self.callback = callback + self.msg.msg_iov = PyMem_RawMalloc( + sizeof(libc.iovec) * len(buffers) + ) + if self.msg.msg_iov == NULL: + raise MemoryError + self.msg.msg_iovlen = len(buffers) + for i, buf in enumerate(buffers): + self.msg.msg_iov[i].iov_base = buf + self.msg.msg_iov[i].iov_len = len(buf) + self.control_msg = bytearray(256) + self.msg.msg_control = self.control_msg + self.msg.msg_controllen = 256 + + def __dealloc__(self): + if self.msg.msg_iov != NULL: + PyMem_RawFree(self.msg.msg_iov) + + cdef void submit(self, linux.io_uring_sqe* sqe): + self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0) + + def complete(self): + if self.res < 0: + errno.errno = abs(self.res) + PyErr_SetFromErrno(IOError) + return + # if self.msg.msg_controllen: + # print('control_msg:', self.control_msg[:self.msg.msg_controllen]) + # print('flags:', self.msg.msg_flags) + self.callback(self.res) diff --git a/tests/test_loop.py b/tests/test_loop.py index 918a8b1..d947891 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -1,9 +1,51 @@ +# Copyright (c) 2022 Fantix King http://fantix.pro +# kLoop is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import asyncio +import ssl +import time import unittest -from kloop import uring, ktls +import kloop class TestLoop(unittest.TestCase): - def test_loop(self): - self.assertIsNotNone(uring) - self.assertIsNotNone(ktls) + def setUp(self): + asyncio.set_event_loop_policy(kloop.KLoopPolicy()) + self.loop = asyncio.new_event_loop() + + def tearDown(self): + self.loop.close() + asyncio.set_event_loop_policy(None) + + def test_call_soon(self): + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + + def test_call_later(self): + secs = 0.1 + self.loop.call_later(secs, self.loop.stop) + start = time.monotonic() + self.loop.run_forever() + self.assertGreaterEqual(time.monotonic() - start, secs) + + def test_connect(self): + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + r, w = self.loop.run_until_complete( + asyncio.open_connection("www.google.com", 443, ssl=ctx) + ) + w.write(b"GET / HTTP/1.1\r\n" + b"Host: www.google.com\r\n" + b"Connection: close\r\n" + b"\r\n") + while line := self.loop.run_until_complete(r.readline()): + print(line) + w.close() + self.loop.run_until_complete(w.wait_closed())