From 761f741d5d8f49cd6ec6b51bf6d169015a99db42 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sat, 2 Jul 2022 14:37:19 -0400 Subject: [PATCH] TLS: fix buffer usage and reset message Also improved debug --- src/kloop/includes/openssl/ssl.pxd | 5 + src/kloop/tls.pxd | 2 + src/kloop/tls.pyx | 142 ++++++++++++++++------------- 3 files changed, 84 insertions(+), 65 deletions(-) diff --git a/src/kloop/includes/openssl/ssl.pxd b/src/kloop/includes/openssl/ssl.pxd index bc87e69..7372a01 100644 --- a/src/kloop/includes/openssl/ssl.pxd +++ b/src/kloop/includes/openssl/ssl.pxd @@ -20,6 +20,11 @@ cdef extern from "openssl/ssl.h" nogil: int OP_ENABLE_KTLS "SSL_OP_ENABLE_KTLS" int set_options "SSL_set_options" (SSL* ssl, int options) + long SSL_MODE_RELEASE_BUFFERS + long SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER + long clear_mode "SSL_clear_mode" (SSL* ssl, long mode) + int free_buffers "SSL_free_buffers" (SSL *ssl) + cdef extern from *: """ diff --git a/src/kloop/tls.pxd b/src/kloop/tls.pxd index 3bfe8af..2a7163c 100644 --- a/src/kloop/tls.pxd +++ b/src/kloop/tls.pxd @@ -25,6 +25,7 @@ cdef struct Proxy: RingCallback recv_callback unsigned char flags char* read_buffer + void* cmsg Loop* loop int fd @@ -51,6 +52,7 @@ cdef class TLSTransport: object write_buffer bint sending + cdef do_handshake(self) cdef do_read(self) cdef do_read_ktls(self) diff --git a/src/kloop/tls.pyx b/src/kloop/tls.pyx index 66c0ae0..8da3bae 100644 --- a/src/kloop/tls.pyx +++ b/src/kloop/tls.pyx @@ -15,7 +15,7 @@ import ssl from cpython cimport PyMem_RawMalloc, PyMem_RawFree from libc cimport errno, string -from .includes.openssl cimport bio, err, ssl as ssl_h +from .includes.openssl cimport err, ssl as ssl_h from .includes cimport pyssl, linux from .loop cimport ring_sq_submit_sendmsg, ring_sq_submit_recvmsg @@ -48,11 +48,16 @@ cdef size_t CMSG_SIZE = libc.CMSG_SPACE(sizeof(unsigned char)) DEF DEBUG = 0 -cdef inline void reset_msg(libc.msghdr* msg, size_t controllen) nogil: +cdef inline void reset_msg(libc.msghdr* msg, void* cmsg) nogil: msg.msg_name = NULL msg.msg_namelen = 0 msg.msg_flags = 0 - msg.msg_controllen = controllen + if cmsg == NULL: + msg.msg_control = NULL + msg.msg_controllen = 0 + else: + msg.msg_control = cmsg + msg.msg_controllen = CMSG_SIZE cdef object fromOpenSSLError(object err_type): @@ -72,18 +77,18 @@ cdef int bio_write_ex( Proxy* proxy = bio.get_data(b) int res - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex(data=%x, datal=%d)" % (data, datal)) if proxy.flags & FLAGS_PROXY_SEND_SUBMITTED: if proxy.send_vec.iov_base != data: - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex() error: concurrent call") return 0 if proxy.send_vec.iov_len > datal: - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex() error: short rewrite") return 0 @@ -93,40 +98,42 @@ cdef int bio_write_ex( proxy.flags &= ~FLAGS_PROXY_SEND_ALL res = proxy.send_callback.res if res < 0: - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex() error:", -res) errno.errno = -res return 0 written[0] = res - if DEBUG: + IF DEBUG: with gil: - print('bio_write_ex() written:', res) + print("bio_write_ex() written:", res) + print(">>> ", end="") for i in range(res): print( "%02x " % data[i], - end="" if (i + 1) % 16 and i < res - 1 else "\n", + end="" if (i + 1) % 16 or i == res - 1 else "\n>>> ", ) + print() else: written[0] = 0 bio.set_retry_write(b) if not proxy.flags & FLAGS_PROXY_SEND_SUBMITTED: - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex() submit") proxy.send_vec.iov_base = data proxy.send_vec.iov_len = datal - reset_msg(&proxy.send_msg, 0) + reset_msg(&proxy.send_msg, NULL) if not ring_sq_submit_sendmsg( &proxy.loop.ring.sq, proxy.fd, &proxy.send_msg, &proxy.send_callback, ): - if DEBUG: + IF DEBUG: with gil: print("bio_write_ex() error: SQ full") return 0 @@ -144,18 +151,18 @@ cdef int bio_read_ex( int res int is_ktls = bio.test_flags(b, FLAGS_KTLS_RX) - if DEBUG: + IF DEBUG: with gil: print('bio_read_ex(data=%x, datal=%d)' % (data, datal)) if proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: if proxy.recv_vec.iov_base != (data + 5 if is_ktls else data): - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error: concurrent call") return 0 if proxy.recv_vec.iov_len > (datal - 21 if is_ktls else datal): - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error: short reread") return 0 @@ -164,7 +171,7 @@ cdef int bio_read_ex( if proxy.flags & FLAGS_PROXY_RECV_KTLS: res = proxy.recv_callback.res if datal < res + 5: - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error: datal too short") errno.errno = errno.EINVAL @@ -184,24 +191,22 @@ cdef int bio_read_ex( string.memcpy(data, proxy.read_buffer, res) readbytes[0] = res - if DEBUG: + IF DEBUG: with gil: - print( - "bio_read_ex() read:", - res, - "(forwarded TLS record)" - ) + print("bio_read_ex() read:", res, "(forwarded TLS record)") + print("<<< ", end="") for i in range(res): print( "%02x " % data[i], - end="" if (i + 1) % 16 and i < res - 1 else "\n", + end="" if (i + 1) % 16 or i == res - 1 else "\n<<< ", ) + print() elif proxy.flags & FLAGS_PROXY_RECV_COMPLETED: proxy.flags &= ~FLAGS_PROXY_RECV_ALL res = proxy.recv_callback.res if res < 0: - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error:", -res) errno.errno = -res @@ -223,25 +228,25 @@ cdef int bio_read_ex( bio.set_flags(b, bio.FLAGS_IN_EOF) readbytes[0] = res - if DEBUG: + IF DEBUG: with gil: print( - "bio_read_ex() read:", - res, - "(TLS record)" if cmsg else "" + "bio_read_ex() read:", res, "(TLS record)" if cmsg else "" ) + print("<<< ", end="") for i in range(res): print( "%02x " % data[i], - end="" if (i + 1) % 16 and i < res - 1 else "\n", + end="" if (i + 1) % 16 or i == res - 1 else "\n<<< ", ) + print() else: bio.set_retry_read(b) readbytes[0] = 0 if not proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: if is_ktls: if datal < 21: - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error: datal too short") errno.errno = errno.EINVAL @@ -249,7 +254,8 @@ cdef int bio_read_ex( proxy.recv_vec.iov_base = data + 5 proxy.recv_vec.iov_len = datal - 21 - if DEBUG: + reset_msg(&proxy.recv_msg, proxy.cmsg) + IF DEBUG: with gil: print("bio_read_ex() submit(%x, %d)" % ( proxy.recv_vec.iov_base, @@ -258,10 +264,10 @@ cdef int bio_read_ex( else: proxy.recv_vec.iov_base = data proxy.recv_vec.iov_len = datal - if DEBUG: + reset_msg(&proxy.recv_msg, NULL) + IF DEBUG: with gil: print("bio_read_ex() submit") - reset_msg(&proxy.recv_msg, CMSG_SIZE) if not ring_sq_submit_recvmsg( &proxy.loop.ring.sq, @@ -269,7 +275,7 @@ cdef int bio_read_ex( &proxy.recv_msg, &proxy.recv_callback, ): - if DEBUG: + IF DEBUG: with gil: print("bio_read_ex() error: SQ full") return 0 @@ -283,24 +289,24 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil: ssl_h.ktls_crypto_info_t* crypto_info long ret = 0 if cmd == bio.BIO_CTRL_EOF: - if DEBUG: + IF DEBUG: with gil: print("BIO_CTRL_EOF", ret) elif cmd == bio.BIO_CTRL_PUSH: - if DEBUG: + IF DEBUG: with gil: print("BIO_CTRL_PUSH", ret) elif cmd == bio.BIO_CTRL_POP: - if DEBUG: + IF DEBUG: with gil: print("BIO_CTRL_POP", ret) elif cmd == bio.BIO_CTRL_FLUSH: ret = 1 - if DEBUG: + IF DEBUG: with gil: print('BIO_CTRL_FLUSH', ret) elif cmd == BIO_CTRL_SET_KTLS: - if DEBUG: + IF DEBUG: with gil: print("BIO_CTRL_SET_KTLS", "TX end" if num else "RX end") crypto_info = ptr @@ -313,7 +319,7 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil: ) == 0: bio.set_flags(b, FLAGS_KTLS_TX if num else FLAGS_KTLS_RX) else: - if DEBUG: + IF DEBUG: with gil: print( "BIO_CTRL_SET_KTLS", @@ -325,7 +331,7 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil: elif cmd == BIO_CTRL_GET_KTLS_RECV: return bio.test_flags(b, FLAGS_KTLS_RX) != 0 else: - if DEBUG: + IF DEBUG: with gil: print('bio_ctrl', cmd, num) return ret @@ -372,6 +378,7 @@ cdef class TLSTransport: cdef: TLSTransport rv = TLSTransport.__new__(TLSTransport) pyssl.PySSLMemoryBIO* c_bio + pyssl.SSL* s libc.setsockopt(fd, socket.SOL_TCP, linux.TCP_ULP, b"tls", 3) @@ -386,8 +393,12 @@ cdef class TLSTransport: c_bio.bio, rv.bio = rv.bio, c_bio.bio del py_bio - ssl_h.set_options( - (rv.sslobj._sslobj).ssl, ssl_h.OP_ENABLE_KTLS + s = (rv.sslobj._sslobj).ssl + ssl_h.set_options(s, ssl_h.OP_ENABLE_KTLS) + ssl_h.clear_mode( + s, + ssl_h.SSL_MODE_RELEASE_BUFFERS | + ssl_h.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER ) rv.fd = fd rv.protocol = protocol @@ -409,10 +420,9 @@ cdef class TLSTransport: self.proxy.send_msg.msg_iovlen = 1 self.proxy.send_callback.data = &self.proxy self.proxy.send_callback.callback = tls_send_cb - self.proxy.recv_msg.msg_control = PyMem_RawMalloc(CMSG_SIZE) - if self.proxy.recv_msg.msg_control == NULL: + self.proxy.cmsg = PyMem_RawMalloc(CMSG_SIZE) + if self.proxy.cmsg == NULL: raise MemoryError - self.proxy.recv_msg.msg_controllen = CMSG_SIZE self.proxy.recv_msg.msg_iov = &self.proxy.recv_vec self.proxy.recv_msg.msg_iovlen = 1 self.proxy.recv_callback.data = &self.proxy @@ -423,7 +433,7 @@ cdef class TLSTransport: self.sslobj = None bio.free(self.bio) PyMem_RawFree(self.proxy.read_buffer) - PyMem_RawFree(self.proxy.recv_msg.msg_control) + PyMem_RawFree(self.proxy.cmsg) cdef do_handshake(self): if self.state == UNWRAPPED: @@ -431,20 +441,23 @@ cdef class TLSTransport: elif self.state != HANDSHAKING: raise RuntimeError("Cannot do handshake now") + try: + IF DEBUG: + print("do_handshake()") self.sslobj.do_handshake() except ssl.SSLWantReadError: - if DEBUG: + IF DEBUG: print("do_handshake() SSLWantReadError") except ssl.SSLWantWriteError: - if DEBUG: + IF DEBUG: print("do_handshake() SSLWantWriteError") except Exception as ex: - if DEBUG: + IF DEBUG: print('do_handshake() error:', ex) raise else: - if DEBUG: + IF DEBUG: print('do_handshake() done') self.state = WRAPPED @@ -473,14 +486,14 @@ cdef class TLSTransport: self.proxy.flags &= ~FLAGS_PROXY_RECV_ALL res = self.proxy.recv_callback.res if res < 0: - if DEBUG: + IF DEBUG: print("do_read_ktls() error:", -res) self.loop.call_soon( self.protocol.connection_lost, IOError(-res, string.strerror(-res)) ) elif res == 0: - if DEBUG: + IF DEBUG: print("do_read_ktls() EOF") self.loop.call_soon(self.protocol.eof_received) self.loop.call_soon(self.protocol.connection_lost, None) @@ -490,10 +503,10 @@ cdef class TLSTransport: if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE: record_type = (libc.CMSG_DATA(cmsg))[0] if record_type != ssl_h.SSL3_RT_APPLICATION_DATA: - if DEBUG: + IF DEBUG: print("do_read_ktls() forward CMSG") return self.do_read() - if DEBUG: + IF DEBUG: print("do_read_ktls() received", res, "bytes") self.loop.call_soon( self.protocol.data_received, @@ -502,10 +515,9 @@ cdef class TLSTransport: self.loop.call_soon(self.do_read_ktls, self) elif not self.proxy.flags & FLAGS_PROXY_RECV_SUBMITTED: - if DEBUG: + IF DEBUG: print("do_read_ktls() submit") - self.proxy.recv_msg.msg_controllen = CMSG_SIZE - reset_msg(&self.proxy.recv_msg, CMSG_SIZE) + reset_msg(&self.proxy.recv_msg, self.proxy.cmsg) if not ring_sq_submit_recvmsg( &self.proxy.loop.ring.sq, self.fd, @@ -519,18 +531,18 @@ cdef class TLSTransport: try: data = self.sslobj.read(65536) except ssl.SSLWantReadError: - if DEBUG: + IF DEBUG: print("do_read() SSLWantReadError") except ssl.SSLWantWriteError: - if DEBUG: + IF DEBUG: print("do_read() SSLWantWriteError") except Exception as ex: - if DEBUG: + IF DEBUG: print("do_read() error:", ex) self.loop.call_soon(self.protocol.connection_lost, ex) else: if data: - if DEBUG: + IF DEBUG: print("do_read() received", len(data), bytes) print(data) self.loop.call_soon(self.protocol.data_received, data) @@ -539,7 +551,7 @@ cdef class TLSTransport: else: self.loop.call_soon(self.do_read, self) else: - if DEBUG: + IF DEBUG: print("do_read() EOF") self.loop.call_soon(self.protocol.eof_received) self.loop.call_soon(self.protocol.connection_lost, None) @@ -564,7 +576,7 @@ cdef class TLSTransport: try: self.sslobj.write(data) except ssl.SSLWantWriteError: - if DEBUG: + IF DEBUG: print("write() SSLWantWriteError")