1
0
Fork 0
mirror of https://gitee.com/fantix/kloop.git synced 2024-05-08 08:02:41 +00:00

Almost working PoC

This commit is contained in:
Fantix King 2022-04-19 18:32:40 -04:00
parent 97f02d40bc
commit 9fc44a51c8
No known key found for this signature in database
GPG key ID: 95304B04071CCDB4
7 changed files with 180 additions and 71 deletions

View file

@ -31,6 +31,16 @@ cdef extern from "sys/socket.h" nogil:
size_t msg_controllen # ancillary data buffer len
int msg_flags # flags on received message
struct cmsghdr:
socklen_t cmsg_len # data byte count, including header
int cmsg_level # originating protocol
int cmsg_type # protocol-specific type
size_t CMSG_LEN(size_t length)
cmsghdr* CMSG_FIRSTHDR(msghdr* msgh)
unsigned char* CMSG_DATA(cmsghdr* cmsg)
size_t CMSG_SPACE(size_t length)
cdef extern from "arpa/inet.h" nogil:
int inet_pton(int af, char* src, void* dst)

View file

@ -32,6 +32,8 @@ cdef extern from "linux/tcp.h" nogil:
cdef extern from "linux/tls.h" nogil:
int TLS_GET_RECORD_TYPE
__u16 TLS_CIPHER_AES_GCM_256
int TLS_CIPHER_AES_GCM_256_IV_SIZE
int TLS_CIPHER_AES_GCM_256_SALT_SIZE

View file

@ -23,6 +23,16 @@ cdef extern from "openssl/ssl.h" nogil:
SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx)
SSL_CTX* SSL_get_SSL_CTX(SSL* ssl)
ctypedef enum OSSL_HANDSHAKE_STATE:
pass
OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl);
unsigned int SSL3_RT_CHANGE_CIPHER_SPEC
unsigned int SSL3_RT_ALERT
unsigned int SSL3_RT_HANDSHAKE
unsigned int SSL3_RT_APPLICATION_DATA
cdef extern from "includes/ssl.h" nogil:
ctypedef struct PySSLSocket:

View file

@ -11,6 +11,7 @@
import socket
import hmac
import hashlib
import struct
from ssl import SSLWantReadError
from cpython cimport PyErr_SetFromErrno
@ -61,20 +62,14 @@ def do_handshake_capturing_secrets(sslobj):
ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb)
def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
def hkdf_expand(pseudo_random_key, label, length, hash_method=hashlib.sha384):
'''
Expand `pseudo_random_key` and `info` into a key of length `bytes` using
HKDF's expand function based on HMAC with the provided hash (default
SHA-512). See the HKDF draft RFC and paper for usage notes.
'''
# info_in = info
# info = b'\0' + struct.pack("H", len(info)) + info + b'\0'
# print(f'hkdf_expand info_in= label={info.hex()}')
hash_len = hash().digest_size
length = int(length)
if length > 255 * hash_len:
raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \
(hash_len, 255 * hash_len))
info = struct.pack("!HB", length, len(label)) + label + b'\0'
hash_len = hash_method().digest_size
blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil
okm = b""
output_block = b""
@ -82,7 +77,7 @@ def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
output_block = hmac.new(
pseudo_random_key,
(output_block + info + bytearray((counter + 1,))),
hash,
hash_method,
).digest()
okm += output_block
return okm[:length]
@ -95,6 +90,12 @@ def enable_ulp(sock):
return
def get_state(sslobj):
cdef:
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
print(ssl.SSL_get_state(s))
def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
cdef:
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
@ -108,7 +109,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
# s->rlayer->read_sequence
seq = <char*>((<void*>s) + 6104)
# print(sslobj.cipher())
# print(sslobj.cipher())
string.memset(&crypto_info, 0, sizeof(crypto_info))
crypto_info.info.cipher_type = linux.TLS_CIPHER_AES_GCM_256
@ -116,7 +117,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
key = hkdf_expand(
secret,
b'\x00 \ttls13 key\x00',
b'tls13 key',
linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE,
)
string.memcpy(
@ -131,7 +132,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
)
iv = hkdf_expand(
secret,
b'\x00\x0c\x08tls13 iv\x00',
b'tls13 iv',
linux.TLS_CIPHER_AES_GCM_256_IV_SIZE +
linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE,
)

View file

@ -13,6 +13,7 @@ import asyncio.futures
import asyncio.trsock
import asyncio.transports
import contextvars
import errno
import socket
import ssl
@ -98,11 +99,14 @@ class KLoopSocketTransport(
self._recv_buffer_factory = protocol.get_buffer
else:
self._read_ready_cb = self._read_ready__data_received
self._recv_buffer = bytearray(256 * 1024)
self._recv_buffer = bytearray(64 * 1024 * 1024)
self._recv_buffer_factory = lambda _hint: self._recv_buffer
self._protocol = protocol
def _read(self):
# print("RecvMsgWork")
if self._read_paused:
return
self._loop._selector.submit(
uring.RecvMsgWork(
self._sock.fileno(),
@ -111,7 +115,7 @@ class KLoopSocketTransport(
)
)
def _read_ready__buffer_updated(self, res):
def _read_ready__buffer_updated(self, res, app_data):
if res < 0:
raise IOError
elif res == 0:
@ -124,20 +128,29 @@ class KLoopSocketTransport(
if not self._closing:
self._read()
def _read_ready__data_received(self, res):
def _read_ready__data_received(self, res, app_data):
print("_read_ready__data_received", res)
if res < 0:
raise IOError(f"{res}")
if abs(res) == errno.EAGAIN:
print('EAGAIN')
self._read()
else:
raise IOError(f"{res}")
elif res == 0:
self._protocol.eof_received()
else:
try:
# print(f"data received: {res}")
self._protocol.data_received(self._recv_buffer[:res])
print(f"data received: {res}")
data = bytes(self._recv_buffer[:res])
# print(f"data received: {data}")
if app_data:
self._protocol.data_received(data)
finally:
if not self._closing:
self._read()
def _write_done(self, res):
# print("_write_done")
self._current_work = None
if res < 0:
# TODO: force close transport
@ -153,6 +166,7 @@ class KLoopSocketTransport(
self._sock.fileno(), self._buffers, self._write_done
)
self._loop._selector.submit(self._current_work)
# print("more SendWork")
self._buffers = []
elif self._closing:
self._loop.call_soon(self._call_connection_lost, None)
@ -168,6 +182,7 @@ class KLoopSocketTransport(
self._sock.fileno(), data, self._write_done
)
self._loop._selector.submit(self._current_work)
# print("SendWork")
else:
self._buffers.append(data)
self._maybe_pause_protocol()
@ -233,45 +248,87 @@ class KLoopSSLHandshakeProtocol(asyncio.Protocol):
self._handshake()
def _handshake(self):
ktls.get_state(self._sslobj)
success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
self._secrets.update(secrets)
if success:
# print("handshake done")
if self._handshaking:
print(self._sslobj.cipher())
ktls.get_state(self._sslobj)
self._handshaking = False
try:
data = self._sslobj.read(64 * 1024)
except ssl.SSLWantReadError:
data = None
if data:
while True:
try:
data += self._sslobj.read(64 * 1024)
except ssl.SSLWantReadError:
break
print("try read", data)
self._transport._upgrade_ktls_read(
self._sslobj,
self._secrets["SERVER_TRAFFIC_SECRET_0"],
data,
)
if data := self._outgoing.read():
print("last message")
self._transport.write(data)
self._transport._write_waiter = self._after_last_write
# self._transport._write_waiter = lambda: self._transport._loop.call_later(1, self._after_last_write)
self._transport.pause_reading()
else:
self._after_last_write()
# else:
# try:
# data = self._sslobj.read(16384)
# except ssl.SSLWantReadError:
# data = None
# self._transport._upgrade_ktls_read(
# self._sslobj,
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
# data,
# )
else:
assert False
# try:
# data = self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# data = None
# if data:
# while True:
# try:
# data += self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# break
# print("try read", data)
# # ktls.get_state(self._sslobj)
# self._after_last_write()
# self._transport._upgrade_ktls_read(
# self._sslobj,
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
# data,
# )
else:
# print("SSLWantReadError")
if data := self._outgoing.read():
self._transport.write(data)
def _after_last_write(self):
try:
data = self._sslobj.read(16384)
except ssl.SSLWantReadError:
data = None
print("_after_last_write")
ktls.get_state(self._sslobj)
# try:
# data = self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# data = None
# if data:
# while True:
# try:
# data += self._sslobj.read(64 * 1024)
# except ssl.SSLWantReadError:
# break
# print("try read", data)
self._transport._upgrade_ktls_write(
self._sslobj,
self._secrets["CLIENT_TRAFFIC_SECRET_0"],
)
self._transport._upgrade_ktls_read(
self._sslobj,
self._secrets["SERVER_TRAFFIC_SECRET_0"],
data,
)
# self._transport._upgrade_ktls_read(
# self._sslobj,
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
# data,
# )
self._transport.resume_reading()
@ -303,7 +360,9 @@ class KLoopSSLTransport(KLoopSocketTransport):
self._waiter = waiter
def _upgrade_ktls_write(self, sslobj, secret):
print("_upgrade_ktls_write")
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
self.set_protocol(self._app_protocol)
self._loop.call_soon(self._app_protocol.connection_made, self)
if self._waiter is not None:
self._loop.call_soon(
@ -313,13 +372,14 @@ class KLoopSSLTransport(KLoopSocketTransport):
)
def _upgrade_ktls_read(self, sslobj, secret, data):
print("_upgrade_ktls_read")
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)
self.set_protocol(self._app_protocol)
if data is not None:
if data:
self._app_protocol.data_received(data)
else:
self._app_protocol.eof_received()
# self.set_protocol(self._app_protocol)
# if data is not None:
# if data:
# self._app_protocol.data_received(data)
# else:
# self._app_protocol.eof_received()
class KLoop(asyncio.BaseEventLoop):
@ -344,6 +404,7 @@ class KLoop(asyncio.BaseEventLoop):
def _make_socket_transport(
self, sock, protocol, waiter=None, *, extra=None, server=None
):
sock.setblocking(True)
return KLoopSocketTransport(
self, sock, protocol, waiter, extra, server
)

View file

@ -16,7 +16,7 @@ from cpython cimport PyMem_RawMalloc, PyMem_RawFree
from libc cimport errno, string
from posix cimport mman
from .includes cimport barrier, libc, linux
from .includes cimport barrier, libc, linux, ssl
cdef linux.__u32 SIG_SIZE = libc._NSIG // 8
@ -187,6 +187,7 @@ cdef class Ring:
def submit(self, Work work):
cdef linux.io_uring_sqe* sqe = self.sq.next_sqe()
# print(f"submit: {work}")
work.submit(sqe)
def select(self, timeout):
@ -225,19 +226,22 @@ cdef class Ring:
if need_enter:
arg.sigmask = 0
arg.sigmask_sz = SIG_SIZE
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
f"flags={flags:b}, timeout={timeout})")
ret = libc.syscall(
libc.SYS_io_uring_enter,
self.enter_fd,
submit,
wait_nr,
flags,
&arg,
sizeof(arg),
)
# print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
# f"flags={flags:b}, timeout={timeout})")
with nogil:
ret = libc.syscall(
libc.SYS_io_uring_enter,
self.enter_fd,
submit,
wait_nr,
flags,
&arg,
sizeof(arg),
)
if ret < 0:
if errno.errno != errno.ETIME:
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
f"flags={flags:b}, timeout={timeout})")
PyErr_SetFromErrno(IOError)
return
@ -263,7 +267,7 @@ cdef class Work:
int op,
linux.io_uring_sqe * sqe,
int fd,
void * addr,
void* addr,
unsigned len,
linux.__u64 offset,
):
@ -386,6 +390,7 @@ cdef class RecvWork(Work):
cdef class RecvMsgWork(Work):
def __init__(self, int fd, buffers, callback):
cdef size_t size = libc.CMSG_SPACE(sizeof(unsigned char))
self.fd = fd
self.buffers = buffers
self.callback = callback
@ -398,9 +403,9 @@ cdef class RecvMsgWork(Work):
for i, buf in enumerate(buffers):
self.msg.msg_iov[i].iov_base = <char*>buf
self.msg.msg_iov[i].iov_len = len(buf)
self.control_msg = bytearray(256)
self.control_msg = bytearray(size)
self.msg.msg_control = <char*>self.control_msg
self.msg.msg_controllen = 256
self.msg.msg_controllen = size
def __dealloc__(self):
if self.msg.msg_iov != NULL:
@ -410,11 +415,24 @@ cdef class RecvMsgWork(Work):
self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0)
def complete(self):
if self.res < 0:
errno.errno = abs(self.res)
PyErr_SetFromErrno(IOError)
return
# if self.msg.msg_controllen:
# print('control_msg:', self.control_msg[:self.msg.msg_controllen])
# print('flags:', self.msg.msg_flags)
self.callback(self.res)
cdef:
libc.cmsghdr* cmsg
unsigned char* cmsg_data
unsigned char record_type
# if self.res < 0:
# errno.errno = abs(self.res)
# PyErr_SetFromErrno(IOError)
# return
app_data = True
if self.msg.msg_controllen:
print('msg_controllen:', self.msg.msg_controllen)
cmsg = libc.CMSG_FIRSTHDR(&self.msg)
if cmsg.cmsg_level == libc.SOL_TLS and cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
cmsg_data = libc.CMSG_DATA(cmsg)
record_type = (<unsigned char*>cmsg_data)[0]
if record_type != ssl.SSL3_RT_APPLICATION_DATA:
app_data = False
print(f'cmsg.len={cmsg.cmsg_len}, cmsg.level={cmsg.cmsg_level}, cmsg.type={cmsg.cmsg_type}')
print(f'record type: {record_type}')
print('flags:', self.msg.msg_flags)
self.callback(self.res, app_data)

View file

@ -38,14 +38,21 @@ class TestLoop(unittest.TestCase):
def test_connect(self):
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
host = "www.google.com"
r, w = self.loop.run_until_complete(
asyncio.open_connection("www.google.com", 443, ssl=ctx)
# asyncio.open_connection("127.0.0.1", 8080, ssl=ctx)
asyncio.open_connection(host, 443, ssl=ctx)
)
w.write(b"GET / HTTP/1.1\r\n"
b"Host: www.google.com\r\n"
self.loop.run_until_complete(asyncio.sleep(1))
print('send request')
w.write(b"GET / HTTP/1.1\r\n" +
f"Host: {host}\r\n".encode("ISO-8859-1") +
b"Connection: close\r\n"
b"\r\n")
while line := self.loop.run_until_complete(r.readline()):
while line := self.loop.run_until_complete(r.read()):
print(line)
w.close()
self.loop.run_until_complete(w.wait_closed())