diff --git a/src/kloop/includes/libc.pxd b/src/kloop/includes/libc.pxd index cdd0698..ae75fea 100644 --- a/src/kloop/includes/libc.pxd +++ b/src/kloop/includes/libc.pxd @@ -31,6 +31,16 @@ cdef extern from "sys/socket.h" nogil: size_t msg_controllen # ancillary data buffer len int msg_flags # flags on received message + struct cmsghdr: + socklen_t cmsg_len # data byte count, including header + int cmsg_level # originating protocol + int cmsg_type # protocol-specific type + + size_t CMSG_LEN(size_t length) + cmsghdr* CMSG_FIRSTHDR(msghdr* msgh) + unsigned char* CMSG_DATA(cmsghdr* cmsg) + size_t CMSG_SPACE(size_t length) + cdef extern from "arpa/inet.h" nogil: int inet_pton(int af, char* src, void* dst) diff --git a/src/kloop/includes/linux.pxd b/src/kloop/includes/linux.pxd index 4c9f821..9b9afc0 100644 --- a/src/kloop/includes/linux.pxd +++ b/src/kloop/includes/linux.pxd @@ -32,6 +32,8 @@ cdef extern from "linux/tcp.h" nogil: cdef extern from "linux/tls.h" nogil: + int TLS_GET_RECORD_TYPE + __u16 TLS_CIPHER_AES_GCM_256 int TLS_CIPHER_AES_GCM_256_IV_SIZE int TLS_CIPHER_AES_GCM_256_SALT_SIZE diff --git a/src/kloop/includes/ssl.pxd b/src/kloop/includes/ssl.pxd index bd77c26..32c3af0 100644 --- a/src/kloop/includes/ssl.pxd +++ b/src/kloop/includes/ssl.pxd @@ -23,6 +23,16 @@ cdef extern from "openssl/ssl.h" nogil: SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx) SSL_CTX* SSL_get_SSL_CTX(SSL* ssl) + ctypedef enum OSSL_HANDSHAKE_STATE: + pass + + OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl); + + unsigned int SSL3_RT_CHANGE_CIPHER_SPEC + unsigned int SSL3_RT_ALERT + unsigned int SSL3_RT_HANDSHAKE + unsigned int SSL3_RT_APPLICATION_DATA + cdef extern from "includes/ssl.h" nogil: ctypedef struct PySSLSocket: diff --git a/src/kloop/ktls.pyx b/src/kloop/ktls.pyx index f978037..ce3f4d3 100644 --- a/src/kloop/ktls.pyx +++ b/src/kloop/ktls.pyx @@ -11,6 +11,7 @@ import socket import hmac import hashlib +import struct from ssl import SSLWantReadError from cpython cimport PyErr_SetFromErrno @@ -61,20 +62,14 @@ def do_handshake_capturing_secrets(sslobj): ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb) -def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384): +def hkdf_expand(pseudo_random_key, label, length, hash_method=hashlib.sha384): ''' Expand `pseudo_random_key` and `info` into a key of length `bytes` using HKDF's expand function based on HMAC with the provided hash (default SHA-512). See the HKDF draft RFC and paper for usage notes. ''' - # info_in = info - # info = b'\0' + struct.pack("H", len(info)) + info + b'\0' - # print(f'hkdf_expand info_in= label={info.hex()}') - hash_len = hash().digest_size - length = int(length) - if length > 255 * hash_len: - raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \ - (hash_len, 255 * hash_len)) + info = struct.pack("!HB", length, len(label)) + label + b'\0' + hash_len = hash_method().digest_size blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil okm = b"" output_block = b"" @@ -82,7 +77,7 @@ def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384): output_block = hmac.new( pseudo_random_key, (output_block + info + bytearray((counter + 1,))), - hash, + hash_method, ).digest() okm += output_block return okm[:length] @@ -95,6 +90,12 @@ def enable_ulp(sock): return +def get_state(sslobj): + cdef: + ssl.SSL* s = (sslobj._sslobj).ssl + print(ssl.SSL_get_state(s)) + + def upgrade_aes_gcm_256(sslobj, sock, secret, sending): cdef: ssl.SSL* s = (sslobj._sslobj).ssl @@ -108,7 +109,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending): # s->rlayer->read_sequence seq = ((s) + 6104) - # print(sslobj.cipher()) + # print(sslobj.cipher()) string.memset(&crypto_info, 0, sizeof(crypto_info)) crypto_info.info.cipher_type = linux.TLS_CIPHER_AES_GCM_256 @@ -116,7 +117,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending): key = hkdf_expand( secret, - b'\x00 \ttls13 key\x00', + b'tls13 key', linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE, ) string.memcpy( @@ -131,7 +132,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending): ) iv = hkdf_expand( secret, - b'\x00\x0c\x08tls13 iv\x00', + b'tls13 iv', linux.TLS_CIPHER_AES_GCM_256_IV_SIZE + linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE, ) diff --git a/src/kloop/loop.py b/src/kloop/loop.py index 8e72fe6..d2941db 100644 --- a/src/kloop/loop.py +++ b/src/kloop/loop.py @@ -13,6 +13,7 @@ import asyncio.futures import asyncio.trsock import asyncio.transports import contextvars +import errno import socket import ssl @@ -98,11 +99,14 @@ class KLoopSocketTransport( self._recv_buffer_factory = protocol.get_buffer else: self._read_ready_cb = self._read_ready__data_received - self._recv_buffer = bytearray(256 * 1024) + self._recv_buffer = bytearray(64 * 1024 * 1024) self._recv_buffer_factory = lambda _hint: self._recv_buffer self._protocol = protocol def _read(self): + # print("RecvMsgWork") + if self._read_paused: + return self._loop._selector.submit( uring.RecvMsgWork( self._sock.fileno(), @@ -111,7 +115,7 @@ class KLoopSocketTransport( ) ) - def _read_ready__buffer_updated(self, res): + def _read_ready__buffer_updated(self, res, app_data): if res < 0: raise IOError elif res == 0: @@ -124,20 +128,29 @@ class KLoopSocketTransport( if not self._closing: self._read() - def _read_ready__data_received(self, res): + def _read_ready__data_received(self, res, app_data): + print("_read_ready__data_received", res) if res < 0: - raise IOError(f"{res}") + if abs(res) == errno.EAGAIN: + print('EAGAIN') + self._read() + else: + raise IOError(f"{res}") elif res == 0: self._protocol.eof_received() else: try: - # print(f"data received: {res}") - self._protocol.data_received(self._recv_buffer[:res]) + print(f"data received: {res}") + data = bytes(self._recv_buffer[:res]) + # print(f"data received: {data}") + if app_data: + self._protocol.data_received(data) finally: if not self._closing: self._read() def _write_done(self, res): + # print("_write_done") self._current_work = None if res < 0: # TODO: force close transport @@ -153,6 +166,7 @@ class KLoopSocketTransport( self._sock.fileno(), self._buffers, self._write_done ) self._loop._selector.submit(self._current_work) + # print("more SendWork") self._buffers = [] elif self._closing: self._loop.call_soon(self._call_connection_lost, None) @@ -168,6 +182,7 @@ class KLoopSocketTransport( self._sock.fileno(), data, self._write_done ) self._loop._selector.submit(self._current_work) + # print("SendWork") else: self._buffers.append(data) self._maybe_pause_protocol() @@ -233,45 +248,87 @@ class KLoopSSLHandshakeProtocol(asyncio.Protocol): self._handshake() def _handshake(self): + ktls.get_state(self._sslobj) success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj) self._secrets.update(secrets) if success: + # print("handshake done") if self._handshaking: + print(self._sslobj.cipher()) + ktls.get_state(self._sslobj) self._handshaking = False + try: + data = self._sslobj.read(64 * 1024) + except ssl.SSLWantReadError: + data = None + if data: + while True: + try: + data += self._sslobj.read(64 * 1024) + except ssl.SSLWantReadError: + break + print("try read", data) + self._transport._upgrade_ktls_read( + self._sslobj, + self._secrets["SERVER_TRAFFIC_SECRET_0"], + data, + ) if data := self._outgoing.read(): + print("last message") self._transport.write(data) self._transport._write_waiter = self._after_last_write + # self._transport._write_waiter = lambda: self._transport._loop.call_later(1, self._after_last_write) self._transport.pause_reading() else: self._after_last_write() - # else: - # try: - # data = self._sslobj.read(16384) - # except ssl.SSLWantReadError: - # data = None - # self._transport._upgrade_ktls_read( - # self._sslobj, - # self._secrets["SERVER_TRAFFIC_SECRET_0"], - # data, - # ) + else: + assert False + # try: + # data = self._sslobj.read(64 * 1024) + # except ssl.SSLWantReadError: + # data = None + # if data: + # while True: + # try: + # data += self._sslobj.read(64 * 1024) + # except ssl.SSLWantReadError: + # break + # print("try read", data) + # # ktls.get_state(self._sslobj) + # self._after_last_write() + # self._transport._upgrade_ktls_read( + # self._sslobj, + # self._secrets["SERVER_TRAFFIC_SECRET_0"], + # data, + # ) else: + # print("SSLWantReadError") if data := self._outgoing.read(): self._transport.write(data) def _after_last_write(self): - try: - data = self._sslobj.read(16384) - except ssl.SSLWantReadError: - data = None + print("_after_last_write") + ktls.get_state(self._sslobj) + # try: + # data = self._sslobj.read(64 * 1024) + # except ssl.SSLWantReadError: + # data = None + # if data: + # while True: + # try: + # data += self._sslobj.read(64 * 1024) + # except ssl.SSLWantReadError: + # break + # print("try read", data) self._transport._upgrade_ktls_write( self._sslobj, self._secrets["CLIENT_TRAFFIC_SECRET_0"], ) - self._transport._upgrade_ktls_read( - self._sslobj, - self._secrets["SERVER_TRAFFIC_SECRET_0"], - data, - ) + # self._transport._upgrade_ktls_read( + # self._sslobj, + # self._secrets["SERVER_TRAFFIC_SECRET_0"], + # data, + # ) self._transport.resume_reading() @@ -303,7 +360,9 @@ class KLoopSSLTransport(KLoopSocketTransport): self._waiter = waiter def _upgrade_ktls_write(self, sslobj, secret): + print("_upgrade_ktls_write") ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True) + self.set_protocol(self._app_protocol) self._loop.call_soon(self._app_protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon( @@ -313,13 +372,14 @@ class KLoopSSLTransport(KLoopSocketTransport): ) def _upgrade_ktls_read(self, sslobj, secret, data): + print("_upgrade_ktls_read") ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False) - self.set_protocol(self._app_protocol) - if data is not None: - if data: - self._app_protocol.data_received(data) - else: - self._app_protocol.eof_received() + # self.set_protocol(self._app_protocol) + # if data is not None: + # if data: + # self._app_protocol.data_received(data) + # else: + # self._app_protocol.eof_received() class KLoop(asyncio.BaseEventLoop): @@ -344,6 +404,7 @@ class KLoop(asyncio.BaseEventLoop): def _make_socket_transport( self, sock, protocol, waiter=None, *, extra=None, server=None ): + sock.setblocking(True) return KLoopSocketTransport( self, sock, protocol, waiter, extra, server ) diff --git a/src/kloop/uring.pyx b/src/kloop/uring.pyx index e4e1e55..7a73190 100644 --- a/src/kloop/uring.pyx +++ b/src/kloop/uring.pyx @@ -16,7 +16,7 @@ from cpython cimport PyMem_RawMalloc, PyMem_RawFree from libc cimport errno, string from posix cimport mman -from .includes cimport barrier, libc, linux +from .includes cimport barrier, libc, linux, ssl cdef linux.__u32 SIG_SIZE = libc._NSIG // 8 @@ -187,6 +187,7 @@ cdef class Ring: def submit(self, Work work): cdef linux.io_uring_sqe* sqe = self.sq.next_sqe() + # print(f"submit: {work}") work.submit(sqe) def select(self, timeout): @@ -225,19 +226,22 @@ cdef class Ring: if need_enter: arg.sigmask = 0 arg.sigmask_sz = SIG_SIZE - print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, " - f"flags={flags:b}, timeout={timeout})") - ret = libc.syscall( - libc.SYS_io_uring_enter, - self.enter_fd, - submit, - wait_nr, - flags, - &arg, - sizeof(arg), - ) + # print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, " + # f"flags={flags:b}, timeout={timeout})") + with nogil: + ret = libc.syscall( + libc.SYS_io_uring_enter, + self.enter_fd, + submit, + wait_nr, + flags, + &arg, + sizeof(arg), + ) if ret < 0: if errno.errno != errno.ETIME: + print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, " + f"flags={flags:b}, timeout={timeout})") PyErr_SetFromErrno(IOError) return @@ -263,7 +267,7 @@ cdef class Work: int op, linux.io_uring_sqe * sqe, int fd, - void * addr, + void* addr, unsigned len, linux.__u64 offset, ): @@ -386,6 +390,7 @@ cdef class RecvWork(Work): cdef class RecvMsgWork(Work): def __init__(self, int fd, buffers, callback): + cdef size_t size = libc.CMSG_SPACE(sizeof(unsigned char)) self.fd = fd self.buffers = buffers self.callback = callback @@ -398,9 +403,9 @@ cdef class RecvMsgWork(Work): for i, buf in enumerate(buffers): self.msg.msg_iov[i].iov_base = buf self.msg.msg_iov[i].iov_len = len(buf) - self.control_msg = bytearray(256) + self.control_msg = bytearray(size) self.msg.msg_control = self.control_msg - self.msg.msg_controllen = 256 + self.msg.msg_controllen = size def __dealloc__(self): if self.msg.msg_iov != NULL: @@ -410,11 +415,24 @@ cdef class RecvMsgWork(Work): self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0) def complete(self): - if self.res < 0: - errno.errno = abs(self.res) - PyErr_SetFromErrno(IOError) - return - # if self.msg.msg_controllen: - # print('control_msg:', self.control_msg[:self.msg.msg_controllen]) - # print('flags:', self.msg.msg_flags) - self.callback(self.res) + cdef: + libc.cmsghdr* cmsg + unsigned char* cmsg_data + unsigned char record_type + # if self.res < 0: + # errno.errno = abs(self.res) + # PyErr_SetFromErrno(IOError) + # return + app_data = True + if self.msg.msg_controllen: + print('msg_controllen:', self.msg.msg_controllen) + cmsg = libc.CMSG_FIRSTHDR(&self.msg) + if cmsg.cmsg_level == libc.SOL_TLS and cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE: + cmsg_data = libc.CMSG_DATA(cmsg) + record_type = (cmsg_data)[0] + if record_type != ssl.SSL3_RT_APPLICATION_DATA: + app_data = False + print(f'cmsg.len={cmsg.cmsg_len}, cmsg.level={cmsg.cmsg_level}, cmsg.type={cmsg.cmsg_type}') + print(f'record type: {record_type}') + print('flags:', self.msg.msg_flags) + self.callback(self.res, app_data) diff --git a/tests/test_loop.py b/tests/test_loop.py index d947891..67f2bac 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -38,14 +38,21 @@ class TestLoop(unittest.TestCase): def test_connect(self): ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + host = "www.google.com" r, w = self.loop.run_until_complete( - asyncio.open_connection("www.google.com", 443, ssl=ctx) + # asyncio.open_connection("127.0.0.1", 8080, ssl=ctx) + asyncio.open_connection(host, 443, ssl=ctx) ) - w.write(b"GET / HTTP/1.1\r\n" - b"Host: www.google.com\r\n" + self.loop.run_until_complete(asyncio.sleep(1)) + print('send request') + w.write(b"GET / HTTP/1.1\r\n" + + f"Host: {host}\r\n".encode("ISO-8859-1") + b"Connection: close\r\n" b"\r\n") - while line := self.loop.run_until_complete(r.readline()): + while line := self.loop.run_until_complete(r.read()): print(line) w.close() self.loop.run_until_complete(w.wait_closed())