From 74f0062154070c21a8b4d5ed35a3ccb38b702b74 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sat, 25 Jun 2022 19:46:19 -0400 Subject: [PATCH] Barely working TLS client! Refs #I5ANZH --- src/kloop/includes/openssl/bio.pxd | 10 +- src/kloop/includes/openssl/ssl.pxd | 21 ++ src/kloop/loop.pyx | 4 +- src/kloop/tls.pxd | 37 ++- src/kloop/tls.pyx | 500 +++++++++++++++++++++++++++-- src/kloop/uring.pxd | 16 + 6 files changed, 558 insertions(+), 30 deletions(-) diff --git a/src/kloop/includes/openssl/bio.pxd b/src/kloop/includes/openssl/bio.pxd index 608c489..7693a4c 100644 --- a/src/kloop/includes/openssl/bio.pxd +++ b/src/kloop/includes/openssl/bio.pxd @@ -84,5 +84,11 @@ cdef extern from "openssl/bio.h" nogil: void set_init "BIO_set_init" (BIO* a, int init) void set_shutdown "BIO_set_shutdown" (BIO* a, int shut) - void set_retry_read "BIO_set_retry_read" (BIO *b) - void set_retry_write "BIO_set_retry_write" (BIO *b) + void set_retry_read "BIO_set_retry_read" (BIO* b) + void set_retry_write "BIO_set_retry_write" (BIO* b) + void clear_retry_flags "BIO_clear_retry_flags" (BIO *b) + + cdef int FLAGS_IN_EOF "BIO_FLAGS_IN_EOF" + + int test_flags "BIO_test_flags" (BIO* b, int flags) + void set_flags "BIO_set_flags" (BIO *b, int flags) diff --git a/src/kloop/includes/openssl/ssl.pxd b/src/kloop/includes/openssl/ssl.pxd index 72003e4..bc87e69 100644 --- a/src/kloop/includes/openssl/ssl.pxd +++ b/src/kloop/includes/openssl/ssl.pxd @@ -9,9 +9,30 @@ # See the Mulan PSL v2 for more details. +from .. cimport linux + + cdef extern from "openssl/ssl.h" nogil: ctypedef struct SSL: pass + int SSL3_RT_APPLICATION_DATA int OP_ENABLE_KTLS "SSL_OP_ENABLE_KTLS" int set_options "SSL_set_options" (SSL* ssl, int options) + + +cdef extern from *: + """ + typedef struct { + union { + struct tls12_crypto_info_aes_gcm_128 gcm128; + struct tls12_crypto_info_aes_gcm_256 gcm256; + struct tls12_crypto_info_aes_ccm_128 ccm128; + struct tls12_crypto_info_chacha20_poly1305 chacha20poly1305; + }; + size_t tls_crypto_info_len; + } ktls_crypto_info_t; + """ + + ctypedef struct ktls_crypto_info_t: + size_t tls_crypto_info_len diff --git a/src/kloop/loop.pyx b/src/kloop/loop.pyx index 3dc7c92..9572023 100644 --- a/src/kloop/loop.pyx +++ b/src/kloop/loop.pyx @@ -526,7 +526,9 @@ cdef class KLoopImpl: fd = await tcp_connect(self, host, port) protocol = protocol_factory() if ssl is not None: - transport = tls.TLSTransport.new(fd, protocol, self, ssl) + waiter = self.create_future() + transport = tls.TLSTransport.new(fd, protocol, self, ssl, waiter=waiter) + await waiter else: transport = TCPTransport.new(fd, protocol, self) return transport, protocol diff --git a/src/kloop/tls.pxd b/src/kloop/tls.pxd index daa796f..3bfe8af 100644 --- a/src/kloop/tls.pxd +++ b/src/kloop/tls.pxd @@ -9,8 +9,32 @@ # See the Mulan PSL v2 for more details. +from cpython cimport PyObject +from .includes cimport libc from .includes.openssl cimport bio -from .loop cimport KLoopImpl +from .loop cimport KLoopImpl, Loop, RingCallback + + +cdef struct Proxy: + PyObject* transport + libc.iovec send_vec + libc.msghdr send_msg + RingCallback send_callback + libc.iovec recv_vec + libc.msghdr recv_msg + RingCallback recv_callback + unsigned char flags + char* read_buffer + + Loop* loop + int fd + + +cdef enum State: + UNWRAPPED + HANDSHAKING + WRAPPED + WRAPPED_KTLS cdef class TLSTransport: @@ -21,3 +45,14 @@ cdef class TLSTransport: object protocol object sslctx object sslobj + object waiter + Proxy proxy + State state + object write_buffer + bint sending + + cdef do_handshake(self) + cdef do_read(self) + cdef do_read_ktls(self) + cdef write_cb(self, int res) + cdef read_cb(self, int res) diff --git a/src/kloop/tls.pyx b/src/kloop/tls.pyx index 0db3184..66c0ae0 100644 --- a/src/kloop/tls.pyx +++ b/src/kloop/tls.pyx @@ -9,12 +9,50 @@ # See the Mulan PSL v2 for more details. +import collections +import socket import ssl from cpython cimport PyMem_RawMalloc, PyMem_RawFree -from libc cimport string +from libc cimport errno, string from .includes.openssl cimport bio, err, ssl as ssl_h -from .includes cimport pyssl +from .includes cimport pyssl, linux +from .loop cimport ring_sq_submit_sendmsg, ring_sq_submit_recvmsg + + +cdef int BIO_CTRL_SET_KTLS = 72 +cdef int BIO_CTRL_GET_KTLS_SEND = 73 +cdef int BIO_CTRL_GET_KTLS_RECV = 76 + +cdef int FLAGS_KTLS_TX_CTRL_MSG = 0x1000 +cdef int FLAGS_KTLS_RX = 0x2000 +cdef int FLAGS_KTLS_TX = 0x4000 + +cdef unsigned char FLAGS_PROXY_SEND_SUBMITTED = 1 << 0 +cdef unsigned char FLAGS_PROXY_SEND_COMPLETED = 1 << 1 +cdef unsigned char FLAGS_PROXY_SEND_IN_PROXY = 1 << 2 +cdef unsigned char FLAGS_PROXY_SEND_ALL = ( + FLAGS_PROXY_SEND_SUBMITTED | + FLAGS_PROXY_SEND_COMPLETED | + FLAGS_PROXY_SEND_IN_PROXY +) +cdef unsigned char FLAGS_PROXY_RECV_SUBMITTED = 1 << 4 +cdef unsigned char FLAGS_PROXY_RECV_COMPLETED = 1 << 5 +cdef unsigned char FLAGS_PROXY_RECV_KTLS = 1 << 6 +cdef unsigned char FLAGS_PROXY_RECV_ALL = ( + FLAGS_PROXY_RECV_SUBMITTED | + FLAGS_PROXY_RECV_COMPLETED +) + +cdef size_t CMSG_SIZE = libc.CMSG_SPACE(sizeof(unsigned char)) +DEF DEBUG = 0 + + +cdef inline void reset_msg(libc.msghdr* msg, size_t controllen) nogil: + msg.msg_name = NULL + msg.msg_namelen = 0 + msg.msg_flags = 0 + msg.msg_controllen = controllen cdef object fromOpenSSLError(object err_type): @@ -30,35 +68,266 @@ cdef object fromOpenSSLError(object err_type): cdef int bio_write_ex( bio.BIO* b, const char* data, size_t datal, size_t* written ) nogil: - with gil: - print('bio_write', data[:datal], int(data)) - bio.set_retry_write(b) - written[0] = 0 + cdef: + Proxy* proxy = bio.get_data(b) + int res + + if DEBUG: + with gil: + print("bio_write_ex(data=%x, datal=%d)" % (data, datal)) + + if proxy.flags & FLAGS_PROXY_SEND_SUBMITTED: + if proxy.send_vec.iov_base != data: + if DEBUG: + with gil: + print("bio_write_ex() error: concurrent call") + return 0 + if proxy.send_vec.iov_len > datal: + if DEBUG: + with gil: + print("bio_write_ex() error: short rewrite") + return 0 + + bio.clear_retry_flags(b) + if proxy.flags & FLAGS_PROXY_SEND_COMPLETED: + proxy.flags &= ~FLAGS_PROXY_SEND_ALL + res = proxy.send_callback.res + if res < 0: + if DEBUG: + with gil: + print("bio_write_ex() error:", -res) + errno.errno = -res + return 0 + + written[0] = res + if DEBUG: + with gil: + print('bio_write_ex() written:', res) + for i in range(res): + print( + "%02x " % data[i], + end="" if (i + 1) % 16 and i < res - 1 else "\n", + ) + + else: + written[0] = 0 + bio.set_retry_write(b) + + if not proxy.flags & FLAGS_PROXY_SEND_SUBMITTED: + if DEBUG: + with gil: + print("bio_write_ex() submit") + proxy.send_vec.iov_base = data + proxy.send_vec.iov_len = datal + reset_msg(&proxy.send_msg, 0) + if not ring_sq_submit_sendmsg( + &proxy.loop.ring.sq, + proxy.fd, + &proxy.send_msg, + &proxy.send_callback, + ): + if DEBUG: + with gil: + print("bio_write_ex() error: SQ full") + return 0 + proxy.flags |= FLAGS_PROXY_SEND_SUBMITTED + return 1 cdef int bio_read_ex( bio.BIO* b, char* data, size_t datal, size_t* readbytes ) nogil: - with gil: - print('bio_read', datal, int(data)) - bio.set_retry_read(b) - readbytes[0] = 0 + cdef: + Proxy* proxy = bio.get_data(b) + libc.cmsghdr* cmsg = NULL + int res + int is_ktls = bio.test_flags(b, FLAGS_KTLS_RX) + + if DEBUG: + with gil: + print('bio_read_ex(data=%x, datal=%d)' % (data, datal)) + + if proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: + if proxy.recv_vec.iov_base != (data + 5 if is_ktls else data): + if DEBUG: + with gil: + print("bio_read_ex() error: concurrent call") + return 0 + if proxy.recv_vec.iov_len > (datal - 21 if is_ktls else datal): + if DEBUG: + with gil: + print("bio_read_ex() error: short reread") + return 0 + + bio.clear_retry_flags(b) + if proxy.flags & FLAGS_PROXY_RECV_KTLS: + res = proxy.recv_callback.res + if datal < res + 5: + if DEBUG: + with gil: + print("bio_read_ex() error: datal too short") + errno.errno = errno.EINVAL + return 0 + + cmsg = libc.CMSG_FIRSTHDR(&proxy.recv_msg) + if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE: + data[0] = ( libc.CMSG_DATA(cmsg))[0] + data[1] = 0x03 # TLS1_2_VERSION_MAJOR + data[2] = 0x03 # TLS1_2_VERSION_MINOR + # returned length is limited to msg_iov.iov_len above + data[3] = (res >> 8) & 0xff + data[4] = res & 0xff + string.memcpy(data + 5, proxy.read_buffer, res) + res += 5 + else: + string.memcpy(data, proxy.read_buffer, res) + readbytes[0] = res + + if DEBUG: + with gil: + print( + "bio_read_ex() read:", + res, + "(forwarded TLS record)" + ) + for i in range(res): + print( + "%02x " % data[i], + end="" if (i + 1) % 16 and i < res - 1 else "\n", + ) + + elif proxy.flags & FLAGS_PROXY_RECV_COMPLETED: + proxy.flags &= ~FLAGS_PROXY_RECV_ALL + res = proxy.recv_callback.res + if res < 0: + if DEBUG: + with gil: + print("bio_read_ex() error:", -res) + errno.errno = -res + return 0 + + if is_ktls: + if proxy.recv_msg.msg_controllen: + cmsg = libc.CMSG_FIRSTHDR(&proxy.recv_msg) + if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE: + data[0] = (libc.CMSG_DATA(cmsg))[0] + data[1] = 0x03 # TLS1_2_VERSION_MAJOR + data[2] = 0x03 # TLS1_2_VERSION_MINOR + # returned length is limited to msg_iov.iov_len above + data[3] = (res >> 8) & 0xff + data[4] = res & 0xff + res += 5 + + if res == 0: + bio.set_flags(b, bio.FLAGS_IN_EOF) + readbytes[0] = res + + if DEBUG: + with gil: + print( + "bio_read_ex() read:", + res, + "(TLS record)" if cmsg else "" + ) + for i in range(res): + print( + "%02x " % data[i], + end="" if (i + 1) % 16 and i < res - 1 else "\n", + ) + else: + bio.set_retry_read(b) + readbytes[0] = 0 + if not proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: + if is_ktls: + if datal < 21: + if DEBUG: + with gil: + print("bio_read_ex() error: datal too short") + errno.errno = errno.EINVAL + return 0 + + proxy.recv_vec.iov_base = data + 5 + proxy.recv_vec.iov_len = datal - 21 + if DEBUG: + with gil: + print("bio_read_ex() submit(%x, %d)" % ( + proxy.recv_vec.iov_base, + proxy.recv_vec.iov_len, + )) + else: + proxy.recv_vec.iov_base = data + proxy.recv_vec.iov_len = datal + if DEBUG: + with gil: + print("bio_read_ex() submit") + reset_msg(&proxy.recv_msg, CMSG_SIZE) + + if not ring_sq_submit_recvmsg( + &proxy.loop.ring.sq, + proxy.fd, + &proxy.recv_msg, + &proxy.recv_callback, + ): + if DEBUG: + with gil: + print("bio_read_ex() error: SQ full") + return 0 + proxy.flags |= FLAGS_PROXY_RECV_SUBMITTED + return 1 cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil: - cdef long ret = 0 - with gil: - if cmd == bio.BIO_CTRL_EOF: - print("BIO_CTRL_EOF", ret) - elif cmd == bio.BIO_CTRL_PUSH: - print("BIO_CTRL_PUSH", ret) - elif cmd == bio.BIO_CTRL_FLUSH: - ret = 1 - print('BIO_CTRL_FLUSH', ret) + cdef: + ssl_h.ktls_crypto_info_t* crypto_info + long ret = 0 + if cmd == bio.BIO_CTRL_EOF: + if DEBUG: + with gil: + print("BIO_CTRL_EOF", ret) + elif cmd == bio.BIO_CTRL_PUSH: + if DEBUG: + with gil: + print("BIO_CTRL_PUSH", ret) + elif cmd == bio.BIO_CTRL_POP: + if DEBUG: + with gil: + print("BIO_CTRL_POP", ret) + elif cmd == bio.BIO_CTRL_FLUSH: + ret = 1 + if DEBUG: + with gil: + print('BIO_CTRL_FLUSH', ret) + elif cmd == BIO_CTRL_SET_KTLS: + if DEBUG: + with gil: + print("BIO_CTRL_SET_KTLS", "TX end" if num else "RX end") + crypto_info = ptr + if libc.setsockopt( + (bio.get_data(b)).fd, + libc.SOL_TLS, + linux.TLS_TX if num else linux.TLS_RX, + crypto_info, + crypto_info.tls_crypto_info_len, + ) == 0: + bio.set_flags(b, FLAGS_KTLS_TX if num else FLAGS_KTLS_RX) else: - print('bio_ctrl', cmd, num) + if DEBUG: + with gil: + print( + "BIO_CTRL_SET_KTLS", + "TX end" if num else "RX end", + "failed", + ) + elif cmd == BIO_CTRL_GET_KTLS_SEND: + return bio.test_flags(b, FLAGS_KTLS_TX) != 0 + elif cmd == BIO_CTRL_GET_KTLS_RECV: + return bio.test_flags(b, FLAGS_KTLS_RX) != 0 + else: + if DEBUG: + with gil: + print('bio_ctrl', cmd, num) return ret @@ -72,6 +341,22 @@ cdef int bio_destroy(bio.BIO* b) nogil: return 1 +cdef int tls_send_cb(RingCallback* cb) nogil except 0: + cdef Proxy* proxy = cb.data + proxy.flags |= FLAGS_PROXY_SEND_COMPLETED + with gil: + (proxy.transport).write_cb(cb.res) + return 1 + + +cdef int tls_recv_cb(RingCallback* cb) nogil except 0: + cdef Proxy* proxy = cb.data + proxy.flags |= FLAGS_PROXY_RECV_COMPLETED + with gil: + (proxy.transport).read_cb(cb.res) + return 1 + + cdef class TLSTransport: @staticmethod def new( @@ -82,11 +367,14 @@ cdef class TLSTransport: server_side=False, server_hostname=None, session=None, + waiter=None, ): cdef: TLSTransport rv = TLSTransport.__new__(TLSTransport) pyssl.PySSLMemoryBIO* c_bio + libc.setsockopt(fd, socket.SOL_TCP, linux.TCP_ULP, b"tls", 3) + py_bio = ssl.MemoryBIO() c_bio = py_bio c_bio.bio, rv.bio = rv.bio, c_bio.bio @@ -99,25 +387,185 @@ cdef class TLSTransport: del py_bio ssl_h.set_options( - (rv.sslobj).ssl, ssl_h.OP_ENABLE_KTLS + (rv.sslobj._sslobj).ssl, ssl_h.OP_ENABLE_KTLS ) rv.fd = fd rv.protocol = protocol rv.loop = loop rv.sslctx = sslctx + rv.proxy.loop = &loop.loop + rv.proxy.fd = fd + rv.waiter = waiter + rv.write_buffer = collections.deque() - try: - rv.sslobj.do_handshake() - except (ssl.SSLWantReadError, ssl.SSLWantWriteError): - pass + rv.do_handshake() return rv def __cinit__(self): + self.state = UNWRAPPED self.bio = bio.new(KTLS_BIO_METHOD) - bio.set_data(self.bio, self) + self.proxy.transport = self + self.proxy.send_msg.msg_iov = &self.proxy.send_vec + self.proxy.send_msg.msg_iovlen = 1 + self.proxy.send_callback.data = &self.proxy + self.proxy.send_callback.callback = tls_send_cb + self.proxy.recv_msg.msg_control = PyMem_RawMalloc(CMSG_SIZE) + if self.proxy.recv_msg.msg_control == NULL: + raise MemoryError + self.proxy.recv_msg.msg_controllen = CMSG_SIZE + self.proxy.recv_msg.msg_iov = &self.proxy.recv_vec + self.proxy.recv_msg.msg_iovlen = 1 + self.proxy.recv_callback.data = &self.proxy + self.proxy.recv_callback.callback = tls_recv_cb + bio.set_data(self.bio, &self.proxy) def __dealloc__(self): + self.sslobj = None bio.free(self.bio) + PyMem_RawFree(self.proxy.read_buffer) + PyMem_RawFree(self.proxy.recv_msg.msg_control) + + cdef do_handshake(self): + if self.state == UNWRAPPED: + self.state = HANDSHAKING + elif self.state != HANDSHAKING: + raise RuntimeError("Cannot do handshake now") + + try: + self.sslobj.do_handshake() + except ssl.SSLWantReadError: + if DEBUG: + print("do_handshake() SSLWantReadError") + except ssl.SSLWantWriteError: + if DEBUG: + print("do_handshake() SSLWantWriteError") + except Exception as ex: + if DEBUG: + print('do_handshake() error:', ex) + raise + else: + if DEBUG: + print('do_handshake() done') + + self.state = WRAPPED + if self.waiter: + self.waiter.set_result(self) + self.waiter = None + + if bio.test_flags(self.bio, FLAGS_KTLS_RX): + self.proxy.read_buffer = PyMem_RawMalloc(65536) + if self.proxy.read_buffer == NULL: + raise MemoryError + self.proxy.flags |= FLAGS_PROXY_RECV_KTLS + self.proxy.recv_vec.iov_base = self.proxy.read_buffer + self.proxy.recv_vec.iov_len = 65536 + self.do_read_ktls() + else: + self.do_read() + + cdef do_read_ktls(self): + cdef: + int res + libc.cmsghdr* cmsg + unsigned char record_type + + if self.proxy.flags & FLAGS_PROXY_RECV_COMPLETED: + self.proxy.flags &= ~FLAGS_PROXY_RECV_ALL + res = self.proxy.recv_callback.res + if res < 0: + if DEBUG: + print("do_read_ktls() error:", -res) + self.loop.call_soon( + self.protocol.connection_lost, + IOError(-res, string.strerror(-res)) + ) + elif res == 0: + if DEBUG: + print("do_read_ktls() EOF") + self.loop.call_soon(self.protocol.eof_received) + self.loop.call_soon(self.protocol.connection_lost, None) + else: + if self.proxy.recv_msg.msg_controllen: + cmsg = libc.CMSG_FIRSTHDR(&self.proxy.recv_msg) + if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE: + record_type = (libc.CMSG_DATA(cmsg))[0] + if record_type != ssl_h.SSL3_RT_APPLICATION_DATA: + if DEBUG: + print("do_read_ktls() forward CMSG") + return self.do_read() + if DEBUG: + print("do_read_ktls() received", res, "bytes") + self.loop.call_soon( + self.protocol.data_received, + bytes(self.proxy.read_buffer[:res]), + ) + self.loop.call_soon(self.do_read_ktls, self) + + elif not self.proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: + if DEBUG: + print("do_read_ktls() submit") + self.proxy.recv_msg.msg_controllen = CMSG_SIZE + reset_msg(&self.proxy.recv_msg, CMSG_SIZE) + if not ring_sq_submit_recvmsg( + &self.proxy.loop.ring.sq, + self.fd, + &self.proxy.recv_msg, + &self.proxy.recv_callback, + ): + raise RuntimeError("SQ full") + self.proxy.flags |= FLAGS_PROXY_RECV_SUBMITTED + + cdef do_read(self): + try: + data = self.sslobj.read(65536) + except ssl.SSLWantReadError: + if DEBUG: + print("do_read() SSLWantReadError") + except ssl.SSLWantWriteError: + if DEBUG: + print("do_read() SSLWantWriteError") + except Exception as ex: + if DEBUG: + print("do_read() error:", ex) + self.loop.call_soon(self.protocol.connection_lost, ex) + else: + if data: + if DEBUG: + print("do_read() received", len(data), bytes) + print(data) + self.loop.call_soon(self.protocol.data_received, data) + if self.proxy.flags & FLAGS_PROXY_RECV_KTLS: + self.loop.call_soon(self.do_read_ktls, self) + else: + self.loop.call_soon(self.do_read, self) + else: + if DEBUG: + print("do_read() EOF") + self.loop.call_soon(self.protocol.eof_received) + self.loop.call_soon(self.protocol.connection_lost, None) + + cdef write_cb(self, int res): + if self.state == HANDSHAKING: + self.do_handshake() + + cdef read_cb(self, int res): + if self.state == HANDSHAKING: + self.do_handshake() + elif self.state == WRAPPED: + if self.proxy.flags & FLAGS_PROXY_RECV_KTLS: + self.do_read_ktls() + else: + self.do_read() + + def write(self, data): + if self.sending: + self.write_buffer.append(data) + else: + try: + self.sslobj.write(data) + except ssl.SSLWantWriteError: + if DEBUG: + print("write() SSLWantWriteError") cdef bio.Method* KTLS_BIO_METHOD = bio.meth_new( diff --git a/src/kloop/uring.pxd b/src/kloop/uring.pxd index 5477fe3..b7a0a05 100644 --- a/src/kloop/uring.pxd +++ b/src/kloop/uring.pxd @@ -56,3 +56,19 @@ cdef struct RingCallback: void* data int res int (*callback)(RingCallback* cb) nogil except 0 + + +cdef int ring_sq_submit_sendmsg( + SubmissionQueue* sq, + int fd, + const libc.msghdr *msg, + RingCallback* callback, +) nogil + + +cdef int ring_sq_submit_recvmsg( + SubmissionQueue* sq, + int fd, + const libc.msghdr *msg, + RingCallback* callback, +) nogil