mirror of
https://gitee.com/fantix/kloop.git
synced 2024-11-22 02:11:01 +00:00
Almost working PoC
This commit is contained in:
parent
97f02d40bc
commit
9fc44a51c8
7 changed files with 180 additions and 71 deletions
|
@ -31,6 +31,16 @@ cdef extern from "sys/socket.h" nogil:
|
|||
size_t msg_controllen # ancillary data buffer len
|
||||
int msg_flags # flags on received message
|
||||
|
||||
struct cmsghdr:
|
||||
socklen_t cmsg_len # data byte count, including header
|
||||
int cmsg_level # originating protocol
|
||||
int cmsg_type # protocol-specific type
|
||||
|
||||
size_t CMSG_LEN(size_t length)
|
||||
cmsghdr* CMSG_FIRSTHDR(msghdr* msgh)
|
||||
unsigned char* CMSG_DATA(cmsghdr* cmsg)
|
||||
size_t CMSG_SPACE(size_t length)
|
||||
|
||||
|
||||
cdef extern from "arpa/inet.h" nogil:
|
||||
int inet_pton(int af, char* src, void* dst)
|
||||
|
|
|
@ -32,6 +32,8 @@ cdef extern from "linux/tcp.h" nogil:
|
|||
|
||||
|
||||
cdef extern from "linux/tls.h" nogil:
|
||||
int TLS_GET_RECORD_TYPE
|
||||
|
||||
__u16 TLS_CIPHER_AES_GCM_256
|
||||
int TLS_CIPHER_AES_GCM_256_IV_SIZE
|
||||
int TLS_CIPHER_AES_GCM_256_SALT_SIZE
|
||||
|
|
|
@ -23,6 +23,16 @@ cdef extern from "openssl/ssl.h" nogil:
|
|||
SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx)
|
||||
SSL_CTX* SSL_get_SSL_CTX(SSL* ssl)
|
||||
|
||||
ctypedef enum OSSL_HANDSHAKE_STATE:
|
||||
pass
|
||||
|
||||
OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl);
|
||||
|
||||
unsigned int SSL3_RT_CHANGE_CIPHER_SPEC
|
||||
unsigned int SSL3_RT_ALERT
|
||||
unsigned int SSL3_RT_HANDSHAKE
|
||||
unsigned int SSL3_RT_APPLICATION_DATA
|
||||
|
||||
|
||||
cdef extern from "includes/ssl.h" nogil:
|
||||
ctypedef struct PySSLSocket:
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
import socket
|
||||
import hmac
|
||||
import hashlib
|
||||
import struct
|
||||
from ssl import SSLWantReadError
|
||||
|
||||
from cpython cimport PyErr_SetFromErrno
|
||||
|
@ -61,20 +62,14 @@ def do_handshake_capturing_secrets(sslobj):
|
|||
ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb)
|
||||
|
||||
|
||||
def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
|
||||
def hkdf_expand(pseudo_random_key, label, length, hash_method=hashlib.sha384):
|
||||
'''
|
||||
Expand `pseudo_random_key` and `info` into a key of length `bytes` using
|
||||
HKDF's expand function based on HMAC with the provided hash (default
|
||||
SHA-512). See the HKDF draft RFC and paper for usage notes.
|
||||
'''
|
||||
# info_in = info
|
||||
# info = b'\0' + struct.pack("H", len(info)) + info + b'\0'
|
||||
# print(f'hkdf_expand info_in= label={info.hex()}')
|
||||
hash_len = hash().digest_size
|
||||
length = int(length)
|
||||
if length > 255 * hash_len:
|
||||
raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \
|
||||
(hash_len, 255 * hash_len))
|
||||
info = struct.pack("!HB", length, len(label)) + label + b'\0'
|
||||
hash_len = hash_method().digest_size
|
||||
blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil
|
||||
okm = b""
|
||||
output_block = b""
|
||||
|
@ -82,7 +77,7 @@ def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
|
|||
output_block = hmac.new(
|
||||
pseudo_random_key,
|
||||
(output_block + info + bytearray((counter + 1,))),
|
||||
hash,
|
||||
hash_method,
|
||||
).digest()
|
||||
okm += output_block
|
||||
return okm[:length]
|
||||
|
@ -95,6 +90,12 @@ def enable_ulp(sock):
|
|||
return
|
||||
|
||||
|
||||
def get_state(sslobj):
|
||||
cdef:
|
||||
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
|
||||
print(ssl.SSL_get_state(s))
|
||||
|
||||
|
||||
def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
||||
cdef:
|
||||
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
|
||||
|
@ -116,7 +117,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
|||
|
||||
key = hkdf_expand(
|
||||
secret,
|
||||
b'\x00 \ttls13 key\x00',
|
||||
b'tls13 key',
|
||||
linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE,
|
||||
)
|
||||
string.memcpy(
|
||||
|
@ -131,7 +132,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
|||
)
|
||||
iv = hkdf_expand(
|
||||
secret,
|
||||
b'\x00\x0c\x08tls13 iv\x00',
|
||||
b'tls13 iv',
|
||||
linux.TLS_CIPHER_AES_GCM_256_IV_SIZE +
|
||||
linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE,
|
||||
)
|
||||
|
|
|
@ -13,6 +13,7 @@ import asyncio.futures
|
|||
import asyncio.trsock
|
||||
import asyncio.transports
|
||||
import contextvars
|
||||
import errno
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
|
@ -98,11 +99,14 @@ class KLoopSocketTransport(
|
|||
self._recv_buffer_factory = protocol.get_buffer
|
||||
else:
|
||||
self._read_ready_cb = self._read_ready__data_received
|
||||
self._recv_buffer = bytearray(256 * 1024)
|
||||
self._recv_buffer = bytearray(64 * 1024 * 1024)
|
||||
self._recv_buffer_factory = lambda _hint: self._recv_buffer
|
||||
self._protocol = protocol
|
||||
|
||||
def _read(self):
|
||||
# print("RecvMsgWork")
|
||||
if self._read_paused:
|
||||
return
|
||||
self._loop._selector.submit(
|
||||
uring.RecvMsgWork(
|
||||
self._sock.fileno(),
|
||||
|
@ -111,7 +115,7 @@ class KLoopSocketTransport(
|
|||
)
|
||||
)
|
||||
|
||||
def _read_ready__buffer_updated(self, res):
|
||||
def _read_ready__buffer_updated(self, res, app_data):
|
||||
if res < 0:
|
||||
raise IOError
|
||||
elif res == 0:
|
||||
|
@ -124,20 +128,29 @@ class KLoopSocketTransport(
|
|||
if not self._closing:
|
||||
self._read()
|
||||
|
||||
def _read_ready__data_received(self, res):
|
||||
def _read_ready__data_received(self, res, app_data):
|
||||
print("_read_ready__data_received", res)
|
||||
if res < 0:
|
||||
if abs(res) == errno.EAGAIN:
|
||||
print('EAGAIN')
|
||||
self._read()
|
||||
else:
|
||||
raise IOError(f"{res}")
|
||||
elif res == 0:
|
||||
self._protocol.eof_received()
|
||||
else:
|
||||
try:
|
||||
# print(f"data received: {res}")
|
||||
self._protocol.data_received(self._recv_buffer[:res])
|
||||
print(f"data received: {res}")
|
||||
data = bytes(self._recv_buffer[:res])
|
||||
# print(f"data received: {data}")
|
||||
if app_data:
|
||||
self._protocol.data_received(data)
|
||||
finally:
|
||||
if not self._closing:
|
||||
self._read()
|
||||
|
||||
def _write_done(self, res):
|
||||
# print("_write_done")
|
||||
self._current_work = None
|
||||
if res < 0:
|
||||
# TODO: force close transport
|
||||
|
@ -153,6 +166,7 @@ class KLoopSocketTransport(
|
|||
self._sock.fileno(), self._buffers, self._write_done
|
||||
)
|
||||
self._loop._selector.submit(self._current_work)
|
||||
# print("more SendWork")
|
||||
self._buffers = []
|
||||
elif self._closing:
|
||||
self._loop.call_soon(self._call_connection_lost, None)
|
||||
|
@ -168,6 +182,7 @@ class KLoopSocketTransport(
|
|||
self._sock.fileno(), data, self._write_done
|
||||
)
|
||||
self._loop._selector.submit(self._current_work)
|
||||
# print("SendWork")
|
||||
else:
|
||||
self._buffers.append(data)
|
||||
self._maybe_pause_protocol()
|
||||
|
@ -233,45 +248,87 @@ class KLoopSSLHandshakeProtocol(asyncio.Protocol):
|
|||
self._handshake()
|
||||
|
||||
def _handshake(self):
|
||||
ktls.get_state(self._sslobj)
|
||||
success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
|
||||
self._secrets.update(secrets)
|
||||
if success:
|
||||
# print("handshake done")
|
||||
if self._handshaking:
|
||||
print(self._sslobj.cipher())
|
||||
ktls.get_state(self._sslobj)
|
||||
self._handshaking = False
|
||||
try:
|
||||
data = self._sslobj.read(64 * 1024)
|
||||
except ssl.SSLWantReadError:
|
||||
data = None
|
||||
if data:
|
||||
while True:
|
||||
try:
|
||||
data += self._sslobj.read(64 * 1024)
|
||||
except ssl.SSLWantReadError:
|
||||
break
|
||||
print("try read", data)
|
||||
self._transport._upgrade_ktls_read(
|
||||
self._sslobj,
|
||||
self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||
data,
|
||||
)
|
||||
if data := self._outgoing.read():
|
||||
print("last message")
|
||||
self._transport.write(data)
|
||||
self._transport._write_waiter = self._after_last_write
|
||||
# self._transport._write_waiter = lambda: self._transport._loop.call_later(1, self._after_last_write)
|
||||
self._transport.pause_reading()
|
||||
else:
|
||||
self._after_last_write()
|
||||
# else:
|
||||
else:
|
||||
assert False
|
||||
# try:
|
||||
# data = self._sslobj.read(16384)
|
||||
# data = self._sslobj.read(64 * 1024)
|
||||
# except ssl.SSLWantReadError:
|
||||
# data = None
|
||||
# if data:
|
||||
# while True:
|
||||
# try:
|
||||
# data += self._sslobj.read(64 * 1024)
|
||||
# except ssl.SSLWantReadError:
|
||||
# break
|
||||
# print("try read", data)
|
||||
# # ktls.get_state(self._sslobj)
|
||||
# self._after_last_write()
|
||||
# self._transport._upgrade_ktls_read(
|
||||
# self._sslobj,
|
||||
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||
# data,
|
||||
# )
|
||||
else:
|
||||
# print("SSLWantReadError")
|
||||
if data := self._outgoing.read():
|
||||
self._transport.write(data)
|
||||
|
||||
def _after_last_write(self):
|
||||
try:
|
||||
data = self._sslobj.read(16384)
|
||||
except ssl.SSLWantReadError:
|
||||
data = None
|
||||
print("_after_last_write")
|
||||
ktls.get_state(self._sslobj)
|
||||
# try:
|
||||
# data = self._sslobj.read(64 * 1024)
|
||||
# except ssl.SSLWantReadError:
|
||||
# data = None
|
||||
# if data:
|
||||
# while True:
|
||||
# try:
|
||||
# data += self._sslobj.read(64 * 1024)
|
||||
# except ssl.SSLWantReadError:
|
||||
# break
|
||||
# print("try read", data)
|
||||
self._transport._upgrade_ktls_write(
|
||||
self._sslobj,
|
||||
self._secrets["CLIENT_TRAFFIC_SECRET_0"],
|
||||
)
|
||||
self._transport._upgrade_ktls_read(
|
||||
self._sslobj,
|
||||
self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||
data,
|
||||
)
|
||||
# self._transport._upgrade_ktls_read(
|
||||
# self._sslobj,
|
||||
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||
# data,
|
||||
# )
|
||||
self._transport.resume_reading()
|
||||
|
||||
|
||||
|
@ -303,7 +360,9 @@ class KLoopSSLTransport(KLoopSocketTransport):
|
|||
self._waiter = waiter
|
||||
|
||||
def _upgrade_ktls_write(self, sslobj, secret):
|
||||
print("_upgrade_ktls_write")
|
||||
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
|
||||
self.set_protocol(self._app_protocol)
|
||||
self._loop.call_soon(self._app_protocol.connection_made, self)
|
||||
if self._waiter is not None:
|
||||
self._loop.call_soon(
|
||||
|
@ -313,13 +372,14 @@ class KLoopSSLTransport(KLoopSocketTransport):
|
|||
)
|
||||
|
||||
def _upgrade_ktls_read(self, sslobj, secret, data):
|
||||
print("_upgrade_ktls_read")
|
||||
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)
|
||||
self.set_protocol(self._app_protocol)
|
||||
if data is not None:
|
||||
if data:
|
||||
self._app_protocol.data_received(data)
|
||||
else:
|
||||
self._app_protocol.eof_received()
|
||||
# self.set_protocol(self._app_protocol)
|
||||
# if data is not None:
|
||||
# if data:
|
||||
# self._app_protocol.data_received(data)
|
||||
# else:
|
||||
# self._app_protocol.eof_received()
|
||||
|
||||
|
||||
class KLoop(asyncio.BaseEventLoop):
|
||||
|
@ -344,6 +404,7 @@ class KLoop(asyncio.BaseEventLoop):
|
|||
def _make_socket_transport(
|
||||
self, sock, protocol, waiter=None, *, extra=None, server=None
|
||||
):
|
||||
sock.setblocking(True)
|
||||
return KLoopSocketTransport(
|
||||
self, sock, protocol, waiter, extra, server
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@ from cpython cimport PyMem_RawMalloc, PyMem_RawFree
|
|||
from libc cimport errno, string
|
||||
from posix cimport mman
|
||||
|
||||
from .includes cimport barrier, libc, linux
|
||||
from .includes cimport barrier, libc, linux, ssl
|
||||
|
||||
cdef linux.__u32 SIG_SIZE = libc._NSIG // 8
|
||||
|
||||
|
@ -187,6 +187,7 @@ cdef class Ring:
|
|||
|
||||
def submit(self, Work work):
|
||||
cdef linux.io_uring_sqe* sqe = self.sq.next_sqe()
|
||||
# print(f"submit: {work}")
|
||||
work.submit(sqe)
|
||||
|
||||
def select(self, timeout):
|
||||
|
@ -225,8 +226,9 @@ cdef class Ring:
|
|||
if need_enter:
|
||||
arg.sigmask = 0
|
||||
arg.sigmask_sz = SIG_SIZE
|
||||
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
||||
f"flags={flags:b}, timeout={timeout})")
|
||||
# print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
||||
# f"flags={flags:b}, timeout={timeout})")
|
||||
with nogil:
|
||||
ret = libc.syscall(
|
||||
libc.SYS_io_uring_enter,
|
||||
self.enter_fd,
|
||||
|
@ -238,6 +240,8 @@ cdef class Ring:
|
|||
)
|
||||
if ret < 0:
|
||||
if errno.errno != errno.ETIME:
|
||||
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
||||
f"flags={flags:b}, timeout={timeout})")
|
||||
PyErr_SetFromErrno(IOError)
|
||||
return
|
||||
|
||||
|
@ -263,7 +267,7 @@ cdef class Work:
|
|||
int op,
|
||||
linux.io_uring_sqe * sqe,
|
||||
int fd,
|
||||
void * addr,
|
||||
void* addr,
|
||||
unsigned len,
|
||||
linux.__u64 offset,
|
||||
):
|
||||
|
@ -386,6 +390,7 @@ cdef class RecvWork(Work):
|
|||
|
||||
cdef class RecvMsgWork(Work):
|
||||
def __init__(self, int fd, buffers, callback):
|
||||
cdef size_t size = libc.CMSG_SPACE(sizeof(unsigned char))
|
||||
self.fd = fd
|
||||
self.buffers = buffers
|
||||
self.callback = callback
|
||||
|
@ -398,9 +403,9 @@ cdef class RecvMsgWork(Work):
|
|||
for i, buf in enumerate(buffers):
|
||||
self.msg.msg_iov[i].iov_base = <char*>buf
|
||||
self.msg.msg_iov[i].iov_len = len(buf)
|
||||
self.control_msg = bytearray(256)
|
||||
self.control_msg = bytearray(size)
|
||||
self.msg.msg_control = <char*>self.control_msg
|
||||
self.msg.msg_controllen = 256
|
||||
self.msg.msg_controllen = size
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.msg.msg_iov != NULL:
|
||||
|
@ -410,11 +415,24 @@ cdef class RecvMsgWork(Work):
|
|||
self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0)
|
||||
|
||||
def complete(self):
|
||||
if self.res < 0:
|
||||
errno.errno = abs(self.res)
|
||||
PyErr_SetFromErrno(IOError)
|
||||
return
|
||||
# if self.msg.msg_controllen:
|
||||
# print('control_msg:', self.control_msg[:self.msg.msg_controllen])
|
||||
# print('flags:', self.msg.msg_flags)
|
||||
self.callback(self.res)
|
||||
cdef:
|
||||
libc.cmsghdr* cmsg
|
||||
unsigned char* cmsg_data
|
||||
unsigned char record_type
|
||||
# if self.res < 0:
|
||||
# errno.errno = abs(self.res)
|
||||
# PyErr_SetFromErrno(IOError)
|
||||
# return
|
||||
app_data = True
|
||||
if self.msg.msg_controllen:
|
||||
print('msg_controllen:', self.msg.msg_controllen)
|
||||
cmsg = libc.CMSG_FIRSTHDR(&self.msg)
|
||||
if cmsg.cmsg_level == libc.SOL_TLS and cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
|
||||
cmsg_data = libc.CMSG_DATA(cmsg)
|
||||
record_type = (<unsigned char*>cmsg_data)[0]
|
||||
if record_type != ssl.SSL3_RT_APPLICATION_DATA:
|
||||
app_data = False
|
||||
print(f'cmsg.len={cmsg.cmsg_len}, cmsg.level={cmsg.cmsg_level}, cmsg.type={cmsg.cmsg_type}')
|
||||
print(f'record type: {record_type}')
|
||||
print('flags:', self.msg.msg_flags)
|
||||
self.callback(self.res, app_data)
|
||||
|
|
|
@ -38,14 +38,21 @@ class TestLoop(unittest.TestCase):
|
|||
|
||||
def test_connect(self):
|
||||
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
host = "www.google.com"
|
||||
r, w = self.loop.run_until_complete(
|
||||
asyncio.open_connection("www.google.com", 443, ssl=ctx)
|
||||
# asyncio.open_connection("127.0.0.1", 8080, ssl=ctx)
|
||||
asyncio.open_connection(host, 443, ssl=ctx)
|
||||
)
|
||||
w.write(b"GET / HTTP/1.1\r\n"
|
||||
b"Host: www.google.com\r\n"
|
||||
self.loop.run_until_complete(asyncio.sleep(1))
|
||||
print('send request')
|
||||
w.write(b"GET / HTTP/1.1\r\n" +
|
||||
f"Host: {host}\r\n".encode("ISO-8859-1") +
|
||||
b"Connection: close\r\n"
|
||||
b"\r\n")
|
||||
while line := self.loop.run_until_complete(r.readline()):
|
||||
while line := self.loop.run_until_complete(r.read()):
|
||||
print(line)
|
||||
w.close()
|
||||
self.loop.run_until_complete(w.wait_closed())
|
||||
|
|
Loading…
Reference in a new issue