1
0
Fork 0
mirror of https://gitee.com/fantix/kloop.git synced 2024-11-22 10:21:25 +00:00

Almost working PoC

This commit is contained in:
Fantix King 2022-04-19 18:32:40 -04:00
parent 97f02d40bc
commit 9fc44a51c8
No known key found for this signature in database
GPG key ID: 95304B04071CCDB4
7 changed files with 180 additions and 71 deletions

View file

@ -31,6 +31,16 @@ cdef extern from "sys/socket.h" nogil:
size_t msg_controllen # ancillary data buffer len size_t msg_controllen # ancillary data buffer len
int msg_flags # flags on received message int msg_flags # flags on received message
struct cmsghdr:
socklen_t cmsg_len # data byte count, including header
int cmsg_level # originating protocol
int cmsg_type # protocol-specific type
size_t CMSG_LEN(size_t length)
cmsghdr* CMSG_FIRSTHDR(msghdr* msgh)
unsigned char* CMSG_DATA(cmsghdr* cmsg)
size_t CMSG_SPACE(size_t length)
cdef extern from "arpa/inet.h" nogil: cdef extern from "arpa/inet.h" nogil:
int inet_pton(int af, char* src, void* dst) int inet_pton(int af, char* src, void* dst)

View file

@ -32,6 +32,8 @@ cdef extern from "linux/tcp.h" nogil:
cdef extern from "linux/tls.h" nogil: cdef extern from "linux/tls.h" nogil:
int TLS_GET_RECORD_TYPE
__u16 TLS_CIPHER_AES_GCM_256 __u16 TLS_CIPHER_AES_GCM_256
int TLS_CIPHER_AES_GCM_256_IV_SIZE int TLS_CIPHER_AES_GCM_256_IV_SIZE
int TLS_CIPHER_AES_GCM_256_SALT_SIZE int TLS_CIPHER_AES_GCM_256_SALT_SIZE

View file

@ -23,6 +23,16 @@ cdef extern from "openssl/ssl.h" nogil:
SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx) SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx)
SSL_CTX* SSL_get_SSL_CTX(SSL* ssl) SSL_CTX* SSL_get_SSL_CTX(SSL* ssl)
ctypedef enum OSSL_HANDSHAKE_STATE:
pass
OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl);
unsigned int SSL3_RT_CHANGE_CIPHER_SPEC
unsigned int SSL3_RT_ALERT
unsigned int SSL3_RT_HANDSHAKE
unsigned int SSL3_RT_APPLICATION_DATA
cdef extern from "includes/ssl.h" nogil: cdef extern from "includes/ssl.h" nogil:
ctypedef struct PySSLSocket: ctypedef struct PySSLSocket:

View file

@ -11,6 +11,7 @@
import socket import socket
import hmac import hmac
import hashlib import hashlib
import struct
from ssl import SSLWantReadError from ssl import SSLWantReadError
from cpython cimport PyErr_SetFromErrno from cpython cimport PyErr_SetFromErrno
@ -61,20 +62,14 @@ def do_handshake_capturing_secrets(sslobj):
ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb) ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb)
def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384): def hkdf_expand(pseudo_random_key, label, length, hash_method=hashlib.sha384):
''' '''
Expand `pseudo_random_key` and `info` into a key of length `bytes` using Expand `pseudo_random_key` and `info` into a key of length `bytes` using
HKDF's expand function based on HMAC with the provided hash (default HKDF's expand function based on HMAC with the provided hash (default
SHA-512). See the HKDF draft RFC and paper for usage notes. SHA-512). See the HKDF draft RFC and paper for usage notes.
''' '''
# info_in = info info = struct.pack("!HB", length, len(label)) + label + b'\0'
# info = b'\0' + struct.pack("H", len(info)) + info + b'\0' hash_len = hash_method().digest_size
# print(f'hkdf_expand info_in= label={info.hex()}')
hash_len = hash().digest_size
length = int(length)
if length > 255 * hash_len:
raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \
(hash_len, 255 * hash_len))
blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil
okm = b"" okm = b""
output_block = b"" output_block = b""
@ -82,7 +77,7 @@ def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
output_block = hmac.new( output_block = hmac.new(
pseudo_random_key, pseudo_random_key,
(output_block + info + bytearray((counter + 1,))), (output_block + info + bytearray((counter + 1,))),
hash, hash_method,
).digest() ).digest()
okm += output_block okm += output_block
return okm[:length] return okm[:length]
@ -95,6 +90,12 @@ def enable_ulp(sock):
return return
def get_state(sslobj):
cdef:
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
print(ssl.SSL_get_state(s))
def upgrade_aes_gcm_256(sslobj, sock, secret, sending): def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
cdef: cdef:
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
@ -116,7 +117,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
key = hkdf_expand( key = hkdf_expand(
secret, secret,
b'\x00 \ttls13 key\x00', b'tls13 key',
linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE, linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE,
) )
string.memcpy( string.memcpy(
@ -131,7 +132,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
) )
iv = hkdf_expand( iv = hkdf_expand(
secret, secret,
b'\x00\x0c\x08tls13 iv\x00', b'tls13 iv',
linux.TLS_CIPHER_AES_GCM_256_IV_SIZE + linux.TLS_CIPHER_AES_GCM_256_IV_SIZE +
linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE, linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE,
) )

View file

@ -13,6 +13,7 @@ import asyncio.futures
import asyncio.trsock import asyncio.trsock
import asyncio.transports import asyncio.transports
import contextvars import contextvars
import errno
import socket import socket
import ssl import ssl
@ -98,11 +99,14 @@ class KLoopSocketTransport(
self._recv_buffer_factory = protocol.get_buffer self._recv_buffer_factory = protocol.get_buffer
else: else:
self._read_ready_cb = self._read_ready__data_received self._read_ready_cb = self._read_ready__data_received
self._recv_buffer = bytearray(256 * 1024) self._recv_buffer = bytearray(64 * 1024 * 1024)
self._recv_buffer_factory = lambda _hint: self._recv_buffer self._recv_buffer_factory = lambda _hint: self._recv_buffer
self._protocol = protocol self._protocol = protocol
def _read(self): def _read(self):
# print("RecvMsgWork")
if self._read_paused:
return
self._loop._selector.submit( self._loop._selector.submit(
uring.RecvMsgWork( uring.RecvMsgWork(
self._sock.fileno(), self._sock.fileno(),
@ -111,7 +115,7 @@ class KLoopSocketTransport(
) )
) )
def _read_ready__buffer_updated(self, res): def _read_ready__buffer_updated(self, res, app_data):
if res < 0: if res < 0:
raise IOError raise IOError
elif res == 0: elif res == 0:
@ -124,20 +128,29 @@ class KLoopSocketTransport(
if not self._closing: if not self._closing:
self._read() self._read()
def _read_ready__data_received(self, res): def _read_ready__data_received(self, res, app_data):
print("_read_ready__data_received", res)
if res < 0: if res < 0:
if abs(res) == errno.EAGAIN:
print('EAGAIN')
self._read()
else:
raise IOError(f"{res}") raise IOError(f"{res}")
elif res == 0: elif res == 0:
self._protocol.eof_received() self._protocol.eof_received()
else: else:
try: try:
# print(f"data received: {res}") print(f"data received: {res}")
self._protocol.data_received(self._recv_buffer[:res]) data = bytes(self._recv_buffer[:res])
# print(f"data received: {data}")
if app_data:
self._protocol.data_received(data)
finally: finally:
if not self._closing: if not self._closing:
self._read() self._read()
def _write_done(self, res): def _write_done(self, res):
# print("_write_done")
self._current_work = None self._current_work = None
if res < 0: if res < 0:
# TODO: force close transport # TODO: force close transport
@ -153,6 +166,7 @@ class KLoopSocketTransport(
self._sock.fileno(), self._buffers, self._write_done self._sock.fileno(), self._buffers, self._write_done
) )
self._loop._selector.submit(self._current_work) self._loop._selector.submit(self._current_work)
# print("more SendWork")
self._buffers = [] self._buffers = []
elif self._closing: elif self._closing:
self._loop.call_soon(self._call_connection_lost, None) self._loop.call_soon(self._call_connection_lost, None)
@ -168,6 +182,7 @@ class KLoopSocketTransport(
self._sock.fileno(), data, self._write_done self._sock.fileno(), data, self._write_done
) )
self._loop._selector.submit(self._current_work) self._loop._selector.submit(self._current_work)
# print("SendWork")
else: else:
self._buffers.append(data) self._buffers.append(data)
self._maybe_pause_protocol() self._maybe_pause_protocol()
@ -233,45 +248,87 @@ class KLoopSSLHandshakeProtocol(asyncio.Protocol):
self._handshake() self._handshake()
def _handshake(self): def _handshake(self):
ktls.get_state(self._sslobj)
success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj) success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
self._secrets.update(secrets) self._secrets.update(secrets)
if success: if success:
# print("handshake done")
if self._handshaking: if self._handshaking:
print(self._sslobj.cipher())
ktls.get_state(self._sslobj)
self._handshaking = False self._handshaking = False
try:
data = self._sslobj.read(64 * 1024)
except ssl.SSLWantReadError:
data = None
if data:
while True:
try:
data += self._sslobj.read(64 * 1024)
except ssl.SSLWantReadError:
break
print("try read", data)
self._transport._upgrade_ktls_read(
self._sslobj,
self._secrets["SERVER_TRAFFIC_SECRET_0"],
data,
)
if data := self._outgoing.read(): if data := self._outgoing.read():
print("last message")
self._transport.write(data) self._transport.write(data)
self._transport._write_waiter = self._after_last_write self._transport._write_waiter = self._after_last_write
# self._transport._write_waiter = lambda: self._transport._loop.call_later(1, self._after_last_write)
self._transport.pause_reading() self._transport.pause_reading()
else: else:
self._after_last_write() self._after_last_write()
# else: else:
assert False
# try: # try:
# data = self._sslobj.read(16384) # data = self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError: # except ssl.SSLWantReadError:
# data = None # data = None
# if data:
# while True:
# try:
# data += self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# break
# print("try read", data)
# # ktls.get_state(self._sslobj)
# self._after_last_write()
# self._transport._upgrade_ktls_read( # self._transport._upgrade_ktls_read(
# self._sslobj, # self._sslobj,
# self._secrets["SERVER_TRAFFIC_SECRET_0"], # self._secrets["SERVER_TRAFFIC_SECRET_0"],
# data, # data,
# ) # )
else: else:
# print("SSLWantReadError")
if data := self._outgoing.read(): if data := self._outgoing.read():
self._transport.write(data) self._transport.write(data)
def _after_last_write(self): def _after_last_write(self):
try: print("_after_last_write")
data = self._sslobj.read(16384) ktls.get_state(self._sslobj)
except ssl.SSLWantReadError: # try:
data = None # data = self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# data = None
# if data:
# while True:
# try:
# data += self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# break
# print("try read", data)
self._transport._upgrade_ktls_write( self._transport._upgrade_ktls_write(
self._sslobj, self._sslobj,
self._secrets["CLIENT_TRAFFIC_SECRET_0"], self._secrets["CLIENT_TRAFFIC_SECRET_0"],
) )
self._transport._upgrade_ktls_read( # self._transport._upgrade_ktls_read(
self._sslobj, # self._sslobj,
self._secrets["SERVER_TRAFFIC_SECRET_0"], # self._secrets["SERVER_TRAFFIC_SECRET_0"],
data, # data,
) # )
self._transport.resume_reading() self._transport.resume_reading()
@ -303,7 +360,9 @@ class KLoopSSLTransport(KLoopSocketTransport):
self._waiter = waiter self._waiter = waiter
def _upgrade_ktls_write(self, sslobj, secret): def _upgrade_ktls_write(self, sslobj, secret):
print("_upgrade_ktls_write")
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True) ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
self.set_protocol(self._app_protocol)
self._loop.call_soon(self._app_protocol.connection_made, self) self._loop.call_soon(self._app_protocol.connection_made, self)
if self._waiter is not None: if self._waiter is not None:
self._loop.call_soon( self._loop.call_soon(
@ -313,13 +372,14 @@ class KLoopSSLTransport(KLoopSocketTransport):
) )
def _upgrade_ktls_read(self, sslobj, secret, data): def _upgrade_ktls_read(self, sslobj, secret, data):
print("_upgrade_ktls_read")
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False) ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)
self.set_protocol(self._app_protocol) # self.set_protocol(self._app_protocol)
if data is not None: # if data is not None:
if data: # if data:
self._app_protocol.data_received(data) # self._app_protocol.data_received(data)
else: # else:
self._app_protocol.eof_received() # self._app_protocol.eof_received()
class KLoop(asyncio.BaseEventLoop): class KLoop(asyncio.BaseEventLoop):
@ -344,6 +404,7 @@ class KLoop(asyncio.BaseEventLoop):
def _make_socket_transport( def _make_socket_transport(
self, sock, protocol, waiter=None, *, extra=None, server=None self, sock, protocol, waiter=None, *, extra=None, server=None
): ):
sock.setblocking(True)
return KLoopSocketTransport( return KLoopSocketTransport(
self, sock, protocol, waiter, extra, server self, sock, protocol, waiter, extra, server
) )

View file

@ -16,7 +16,7 @@ from cpython cimport PyMem_RawMalloc, PyMem_RawFree
from libc cimport errno, string from libc cimport errno, string
from posix cimport mman from posix cimport mman
from .includes cimport barrier, libc, linux from .includes cimport barrier, libc, linux, ssl
cdef linux.__u32 SIG_SIZE = libc._NSIG // 8 cdef linux.__u32 SIG_SIZE = libc._NSIG // 8
@ -187,6 +187,7 @@ cdef class Ring:
def submit(self, Work work): def submit(self, Work work):
cdef linux.io_uring_sqe* sqe = self.sq.next_sqe() cdef linux.io_uring_sqe* sqe = self.sq.next_sqe()
# print(f"submit: {work}")
work.submit(sqe) work.submit(sqe)
def select(self, timeout): def select(self, timeout):
@ -225,8 +226,9 @@ cdef class Ring:
if need_enter: if need_enter:
arg.sigmask = 0 arg.sigmask = 0
arg.sigmask_sz = SIG_SIZE arg.sigmask_sz = SIG_SIZE
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, " # print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
f"flags={flags:b}, timeout={timeout})") # f"flags={flags:b}, timeout={timeout})")
with nogil:
ret = libc.syscall( ret = libc.syscall(
libc.SYS_io_uring_enter, libc.SYS_io_uring_enter,
self.enter_fd, self.enter_fd,
@ -238,6 +240,8 @@ cdef class Ring:
) )
if ret < 0: if ret < 0:
if errno.errno != errno.ETIME: if errno.errno != errno.ETIME:
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
f"flags={flags:b}, timeout={timeout})")
PyErr_SetFromErrno(IOError) PyErr_SetFromErrno(IOError)
return return
@ -263,7 +267,7 @@ cdef class Work:
int op, int op,
linux.io_uring_sqe * sqe, linux.io_uring_sqe * sqe,
int fd, int fd,
void * addr, void* addr,
unsigned len, unsigned len,
linux.__u64 offset, linux.__u64 offset,
): ):
@ -386,6 +390,7 @@ cdef class RecvWork(Work):
cdef class RecvMsgWork(Work): cdef class RecvMsgWork(Work):
def __init__(self, int fd, buffers, callback): def __init__(self, int fd, buffers, callback):
cdef size_t size = libc.CMSG_SPACE(sizeof(unsigned char))
self.fd = fd self.fd = fd
self.buffers = buffers self.buffers = buffers
self.callback = callback self.callback = callback
@ -398,9 +403,9 @@ cdef class RecvMsgWork(Work):
for i, buf in enumerate(buffers): for i, buf in enumerate(buffers):
self.msg.msg_iov[i].iov_base = <char*>buf self.msg.msg_iov[i].iov_base = <char*>buf
self.msg.msg_iov[i].iov_len = len(buf) self.msg.msg_iov[i].iov_len = len(buf)
self.control_msg = bytearray(256) self.control_msg = bytearray(size)
self.msg.msg_control = <char*>self.control_msg self.msg.msg_control = <char*>self.control_msg
self.msg.msg_controllen = 256 self.msg.msg_controllen = size
def __dealloc__(self): def __dealloc__(self):
if self.msg.msg_iov != NULL: if self.msg.msg_iov != NULL:
@ -410,11 +415,24 @@ cdef class RecvMsgWork(Work):
self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0) self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0)
def complete(self): def complete(self):
if self.res < 0: cdef:
errno.errno = abs(self.res) libc.cmsghdr* cmsg
PyErr_SetFromErrno(IOError) unsigned char* cmsg_data
return unsigned char record_type
# if self.msg.msg_controllen: # if self.res < 0:
# print('control_msg:', self.control_msg[:self.msg.msg_controllen]) # errno.errno = abs(self.res)
# print('flags:', self.msg.msg_flags) # PyErr_SetFromErrno(IOError)
self.callback(self.res) # return
app_data = True
if self.msg.msg_controllen:
print('msg_controllen:', self.msg.msg_controllen)
cmsg = libc.CMSG_FIRSTHDR(&self.msg)
if cmsg.cmsg_level == libc.SOL_TLS and cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
cmsg_data = libc.CMSG_DATA(cmsg)
record_type = (<unsigned char*>cmsg_data)[0]
if record_type != ssl.SSL3_RT_APPLICATION_DATA:
app_data = False
print(f'cmsg.len={cmsg.cmsg_len}, cmsg.level={cmsg.cmsg_level}, cmsg.type={cmsg.cmsg_type}')
print(f'record type: {record_type}')
print('flags:', self.msg.msg_flags)
self.callback(self.res, app_data)

View file

@ -38,14 +38,21 @@ class TestLoop(unittest.TestCase):
def test_connect(self): def test_connect(self):
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
host = "www.google.com"
r, w = self.loop.run_until_complete( r, w = self.loop.run_until_complete(
asyncio.open_connection("www.google.com", 443, ssl=ctx) # asyncio.open_connection("127.0.0.1", 8080, ssl=ctx)
asyncio.open_connection(host, 443, ssl=ctx)
) )
w.write(b"GET / HTTP/1.1\r\n" self.loop.run_until_complete(asyncio.sleep(1))
b"Host: www.google.com\r\n" print('send request')
w.write(b"GET / HTTP/1.1\r\n" +
f"Host: {host}\r\n".encode("ISO-8859-1") +
b"Connection: close\r\n" b"Connection: close\r\n"
b"\r\n") b"\r\n")
while line := self.loop.run_until_complete(r.readline()): while line := self.loop.run_until_complete(r.read()):
print(line) print(line)
w.close() w.close()
self.loop.run_until_complete(w.wait_closed()) self.loop.run_until_complete(w.wait_closed())