mirror of
https://gitee.com/fantix/kloop.git
synced 2024-11-22 10:21:25 +00:00
Almost working PoC
This commit is contained in:
parent
97f02d40bc
commit
9fc44a51c8
7 changed files with 180 additions and 71 deletions
|
@ -31,6 +31,16 @@ cdef extern from "sys/socket.h" nogil:
|
||||||
size_t msg_controllen # ancillary data buffer len
|
size_t msg_controllen # ancillary data buffer len
|
||||||
int msg_flags # flags on received message
|
int msg_flags # flags on received message
|
||||||
|
|
||||||
|
struct cmsghdr:
|
||||||
|
socklen_t cmsg_len # data byte count, including header
|
||||||
|
int cmsg_level # originating protocol
|
||||||
|
int cmsg_type # protocol-specific type
|
||||||
|
|
||||||
|
size_t CMSG_LEN(size_t length)
|
||||||
|
cmsghdr* CMSG_FIRSTHDR(msghdr* msgh)
|
||||||
|
unsigned char* CMSG_DATA(cmsghdr* cmsg)
|
||||||
|
size_t CMSG_SPACE(size_t length)
|
||||||
|
|
||||||
|
|
||||||
cdef extern from "arpa/inet.h" nogil:
|
cdef extern from "arpa/inet.h" nogil:
|
||||||
int inet_pton(int af, char* src, void* dst)
|
int inet_pton(int af, char* src, void* dst)
|
||||||
|
|
|
@ -32,6 +32,8 @@ cdef extern from "linux/tcp.h" nogil:
|
||||||
|
|
||||||
|
|
||||||
cdef extern from "linux/tls.h" nogil:
|
cdef extern from "linux/tls.h" nogil:
|
||||||
|
int TLS_GET_RECORD_TYPE
|
||||||
|
|
||||||
__u16 TLS_CIPHER_AES_GCM_256
|
__u16 TLS_CIPHER_AES_GCM_256
|
||||||
int TLS_CIPHER_AES_GCM_256_IV_SIZE
|
int TLS_CIPHER_AES_GCM_256_IV_SIZE
|
||||||
int TLS_CIPHER_AES_GCM_256_SALT_SIZE
|
int TLS_CIPHER_AES_GCM_256_SALT_SIZE
|
||||||
|
|
|
@ -23,6 +23,16 @@ cdef extern from "openssl/ssl.h" nogil:
|
||||||
SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx)
|
SSL_CTX_keylog_cb_func SSL_CTX_get_keylog_callback(SSL_CTX* ctx)
|
||||||
SSL_CTX* SSL_get_SSL_CTX(SSL* ssl)
|
SSL_CTX* SSL_get_SSL_CTX(SSL* ssl)
|
||||||
|
|
||||||
|
ctypedef enum OSSL_HANDSHAKE_STATE:
|
||||||
|
pass
|
||||||
|
|
||||||
|
OSSL_HANDSHAKE_STATE SSL_get_state(const SSL *ssl);
|
||||||
|
|
||||||
|
unsigned int SSL3_RT_CHANGE_CIPHER_SPEC
|
||||||
|
unsigned int SSL3_RT_ALERT
|
||||||
|
unsigned int SSL3_RT_HANDSHAKE
|
||||||
|
unsigned int SSL3_RT_APPLICATION_DATA
|
||||||
|
|
||||||
|
|
||||||
cdef extern from "includes/ssl.h" nogil:
|
cdef extern from "includes/ssl.h" nogil:
|
||||||
ctypedef struct PySSLSocket:
|
ctypedef struct PySSLSocket:
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
import socket
|
import socket
|
||||||
import hmac
|
import hmac
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import struct
|
||||||
from ssl import SSLWantReadError
|
from ssl import SSLWantReadError
|
||||||
|
|
||||||
from cpython cimport PyErr_SetFromErrno
|
from cpython cimport PyErr_SetFromErrno
|
||||||
|
@ -61,20 +62,14 @@ def do_handshake_capturing_secrets(sslobj):
|
||||||
ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb)
|
ssl.SSL_CTX_set_keylog_callback(ctx, orig_cb)
|
||||||
|
|
||||||
|
|
||||||
def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
|
def hkdf_expand(pseudo_random_key, label, length, hash_method=hashlib.sha384):
|
||||||
'''
|
'''
|
||||||
Expand `pseudo_random_key` and `info` into a key of length `bytes` using
|
Expand `pseudo_random_key` and `info` into a key of length `bytes` using
|
||||||
HKDF's expand function based on HMAC with the provided hash (default
|
HKDF's expand function based on HMAC with the provided hash (default
|
||||||
SHA-512). See the HKDF draft RFC and paper for usage notes.
|
SHA-512). See the HKDF draft RFC and paper for usage notes.
|
||||||
'''
|
'''
|
||||||
# info_in = info
|
info = struct.pack("!HB", length, len(label)) + label + b'\0'
|
||||||
# info = b'\0' + struct.pack("H", len(info)) + info + b'\0'
|
hash_len = hash_method().digest_size
|
||||||
# print(f'hkdf_expand info_in= label={info.hex()}')
|
|
||||||
hash_len = hash().digest_size
|
|
||||||
length = int(length)
|
|
||||||
if length > 255 * hash_len:
|
|
||||||
raise Exception("Cannot expand to more than 255 * %d = %d bytes using the specified hash function" % \
|
|
||||||
(hash_len, 255 * hash_len))
|
|
||||||
blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil
|
blocks_needed = length // hash_len + (0 if length % hash_len == 0 else 1) # ceil
|
||||||
okm = b""
|
okm = b""
|
||||||
output_block = b""
|
output_block = b""
|
||||||
|
@ -82,7 +77,7 @@ def hkdf_expand(pseudo_random_key, info=b"", length=32, hash=hashlib.sha384):
|
||||||
output_block = hmac.new(
|
output_block = hmac.new(
|
||||||
pseudo_random_key,
|
pseudo_random_key,
|
||||||
(output_block + info + bytearray((counter + 1,))),
|
(output_block + info + bytearray((counter + 1,))),
|
||||||
hash,
|
hash_method,
|
||||||
).digest()
|
).digest()
|
||||||
okm += output_block
|
okm += output_block
|
||||||
return okm[:length]
|
return okm[:length]
|
||||||
|
@ -95,6 +90,12 @@ def enable_ulp(sock):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def get_state(sslobj):
|
||||||
|
cdef:
|
||||||
|
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
|
||||||
|
print(ssl.SSL_get_state(s))
|
||||||
|
|
||||||
|
|
||||||
def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
||||||
cdef:
|
cdef:
|
||||||
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
|
ssl.SSL* s = (<ssl.PySSLSocket*>sslobj._sslobj).ssl
|
||||||
|
@ -116,7 +117,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
||||||
|
|
||||||
key = hkdf_expand(
|
key = hkdf_expand(
|
||||||
secret,
|
secret,
|
||||||
b'\x00 \ttls13 key\x00',
|
b'tls13 key',
|
||||||
linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE,
|
linux.TLS_CIPHER_AES_GCM_256_KEY_SIZE,
|
||||||
)
|
)
|
||||||
string.memcpy(
|
string.memcpy(
|
||||||
|
@ -131,7 +132,7 @@ def upgrade_aes_gcm_256(sslobj, sock, secret, sending):
|
||||||
)
|
)
|
||||||
iv = hkdf_expand(
|
iv = hkdf_expand(
|
||||||
secret,
|
secret,
|
||||||
b'\x00\x0c\x08tls13 iv\x00',
|
b'tls13 iv',
|
||||||
linux.TLS_CIPHER_AES_GCM_256_IV_SIZE +
|
linux.TLS_CIPHER_AES_GCM_256_IV_SIZE +
|
||||||
linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE,
|
linux.TLS_CIPHER_AES_GCM_256_SALT_SIZE,
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,6 +13,7 @@ import asyncio.futures
|
||||||
import asyncio.trsock
|
import asyncio.trsock
|
||||||
import asyncio.transports
|
import asyncio.transports
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import errno
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
|
@ -98,11 +99,14 @@ class KLoopSocketTransport(
|
||||||
self._recv_buffer_factory = protocol.get_buffer
|
self._recv_buffer_factory = protocol.get_buffer
|
||||||
else:
|
else:
|
||||||
self._read_ready_cb = self._read_ready__data_received
|
self._read_ready_cb = self._read_ready__data_received
|
||||||
self._recv_buffer = bytearray(256 * 1024)
|
self._recv_buffer = bytearray(64 * 1024 * 1024)
|
||||||
self._recv_buffer_factory = lambda _hint: self._recv_buffer
|
self._recv_buffer_factory = lambda _hint: self._recv_buffer
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
|
|
||||||
def _read(self):
|
def _read(self):
|
||||||
|
# print("RecvMsgWork")
|
||||||
|
if self._read_paused:
|
||||||
|
return
|
||||||
self._loop._selector.submit(
|
self._loop._selector.submit(
|
||||||
uring.RecvMsgWork(
|
uring.RecvMsgWork(
|
||||||
self._sock.fileno(),
|
self._sock.fileno(),
|
||||||
|
@ -111,7 +115,7 @@ class KLoopSocketTransport(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _read_ready__buffer_updated(self, res):
|
def _read_ready__buffer_updated(self, res, app_data):
|
||||||
if res < 0:
|
if res < 0:
|
||||||
raise IOError
|
raise IOError
|
||||||
elif res == 0:
|
elif res == 0:
|
||||||
|
@ -124,20 +128,29 @@ class KLoopSocketTransport(
|
||||||
if not self._closing:
|
if not self._closing:
|
||||||
self._read()
|
self._read()
|
||||||
|
|
||||||
def _read_ready__data_received(self, res):
|
def _read_ready__data_received(self, res, app_data):
|
||||||
|
print("_read_ready__data_received", res)
|
||||||
if res < 0:
|
if res < 0:
|
||||||
|
if abs(res) == errno.EAGAIN:
|
||||||
|
print('EAGAIN')
|
||||||
|
self._read()
|
||||||
|
else:
|
||||||
raise IOError(f"{res}")
|
raise IOError(f"{res}")
|
||||||
elif res == 0:
|
elif res == 0:
|
||||||
self._protocol.eof_received()
|
self._protocol.eof_received()
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# print(f"data received: {res}")
|
print(f"data received: {res}")
|
||||||
self._protocol.data_received(self._recv_buffer[:res])
|
data = bytes(self._recv_buffer[:res])
|
||||||
|
# print(f"data received: {data}")
|
||||||
|
if app_data:
|
||||||
|
self._protocol.data_received(data)
|
||||||
finally:
|
finally:
|
||||||
if not self._closing:
|
if not self._closing:
|
||||||
self._read()
|
self._read()
|
||||||
|
|
||||||
def _write_done(self, res):
|
def _write_done(self, res):
|
||||||
|
# print("_write_done")
|
||||||
self._current_work = None
|
self._current_work = None
|
||||||
if res < 0:
|
if res < 0:
|
||||||
# TODO: force close transport
|
# TODO: force close transport
|
||||||
|
@ -153,6 +166,7 @@ class KLoopSocketTransport(
|
||||||
self._sock.fileno(), self._buffers, self._write_done
|
self._sock.fileno(), self._buffers, self._write_done
|
||||||
)
|
)
|
||||||
self._loop._selector.submit(self._current_work)
|
self._loop._selector.submit(self._current_work)
|
||||||
|
# print("more SendWork")
|
||||||
self._buffers = []
|
self._buffers = []
|
||||||
elif self._closing:
|
elif self._closing:
|
||||||
self._loop.call_soon(self._call_connection_lost, None)
|
self._loop.call_soon(self._call_connection_lost, None)
|
||||||
|
@ -168,6 +182,7 @@ class KLoopSocketTransport(
|
||||||
self._sock.fileno(), data, self._write_done
|
self._sock.fileno(), data, self._write_done
|
||||||
)
|
)
|
||||||
self._loop._selector.submit(self._current_work)
|
self._loop._selector.submit(self._current_work)
|
||||||
|
# print("SendWork")
|
||||||
else:
|
else:
|
||||||
self._buffers.append(data)
|
self._buffers.append(data)
|
||||||
self._maybe_pause_protocol()
|
self._maybe_pause_protocol()
|
||||||
|
@ -233,45 +248,87 @@ class KLoopSSLHandshakeProtocol(asyncio.Protocol):
|
||||||
self._handshake()
|
self._handshake()
|
||||||
|
|
||||||
def _handshake(self):
|
def _handshake(self):
|
||||||
|
ktls.get_state(self._sslobj)
|
||||||
success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
|
success, secrets = ktls.do_handshake_capturing_secrets(self._sslobj)
|
||||||
self._secrets.update(secrets)
|
self._secrets.update(secrets)
|
||||||
if success:
|
if success:
|
||||||
|
# print("handshake done")
|
||||||
if self._handshaking:
|
if self._handshaking:
|
||||||
|
print(self._sslobj.cipher())
|
||||||
|
ktls.get_state(self._sslobj)
|
||||||
self._handshaking = False
|
self._handshaking = False
|
||||||
|
try:
|
||||||
|
data = self._sslobj.read(64 * 1024)
|
||||||
|
except ssl.SSLWantReadError:
|
||||||
|
data = None
|
||||||
|
if data:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data += self._sslobj.read(64 * 1024)
|
||||||
|
except ssl.SSLWantReadError:
|
||||||
|
break
|
||||||
|
print("try read", data)
|
||||||
|
self._transport._upgrade_ktls_read(
|
||||||
|
self._sslobj,
|
||||||
|
self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||||
|
data,
|
||||||
|
)
|
||||||
if data := self._outgoing.read():
|
if data := self._outgoing.read():
|
||||||
|
print("last message")
|
||||||
self._transport.write(data)
|
self._transport.write(data)
|
||||||
self._transport._write_waiter = self._after_last_write
|
self._transport._write_waiter = self._after_last_write
|
||||||
|
# self._transport._write_waiter = lambda: self._transport._loop.call_later(1, self._after_last_write)
|
||||||
self._transport.pause_reading()
|
self._transport.pause_reading()
|
||||||
else:
|
else:
|
||||||
self._after_last_write()
|
self._after_last_write()
|
||||||
# else:
|
else:
|
||||||
|
assert False
|
||||||
# try:
|
# try:
|
||||||
# data = self._sslobj.read(16384)
|
# data = self._sslobj.read(64 * 1024)
|
||||||
# except ssl.SSLWantReadError:
|
# except ssl.SSLWantReadError:
|
||||||
# data = None
|
# data = None
|
||||||
|
# if data:
|
||||||
|
# while True:
|
||||||
|
# try:
|
||||||
|
# data += self._sslobj.read(64 * 1024)
|
||||||
|
# except ssl.SSLWantReadError:
|
||||||
|
# break
|
||||||
|
# print("try read", data)
|
||||||
|
# # ktls.get_state(self._sslobj)
|
||||||
|
# self._after_last_write()
|
||||||
# self._transport._upgrade_ktls_read(
|
# self._transport._upgrade_ktls_read(
|
||||||
# self._sslobj,
|
# self._sslobj,
|
||||||
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||||
# data,
|
# data,
|
||||||
# )
|
# )
|
||||||
else:
|
else:
|
||||||
|
# print("SSLWantReadError")
|
||||||
if data := self._outgoing.read():
|
if data := self._outgoing.read():
|
||||||
self._transport.write(data)
|
self._transport.write(data)
|
||||||
|
|
||||||
def _after_last_write(self):
|
def _after_last_write(self):
|
||||||
try:
|
print("_after_last_write")
|
||||||
data = self._sslobj.read(16384)
|
ktls.get_state(self._sslobj)
|
||||||
except ssl.SSLWantReadError:
|
# try:
|
||||||
data = None
|
# data = self._sslobj.read(64 * 1024)
|
||||||
|
# except ssl.SSLWantReadError:
|
||||||
|
# data = None
|
||||||
|
# if data:
|
||||||
|
# while True:
|
||||||
|
# try:
|
||||||
|
# data += self._sslobj.read(64 * 1024)
|
||||||
|
# except ssl.SSLWantReadError:
|
||||||
|
# break
|
||||||
|
# print("try read", data)
|
||||||
self._transport._upgrade_ktls_write(
|
self._transport._upgrade_ktls_write(
|
||||||
self._sslobj,
|
self._sslobj,
|
||||||
self._secrets["CLIENT_TRAFFIC_SECRET_0"],
|
self._secrets["CLIENT_TRAFFIC_SECRET_0"],
|
||||||
)
|
)
|
||||||
self._transport._upgrade_ktls_read(
|
# self._transport._upgrade_ktls_read(
|
||||||
self._sslobj,
|
# self._sslobj,
|
||||||
self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
# self._secrets["SERVER_TRAFFIC_SECRET_0"],
|
||||||
data,
|
# data,
|
||||||
)
|
# )
|
||||||
self._transport.resume_reading()
|
self._transport.resume_reading()
|
||||||
|
|
||||||
|
|
||||||
|
@ -303,7 +360,9 @@ class KLoopSSLTransport(KLoopSocketTransport):
|
||||||
self._waiter = waiter
|
self._waiter = waiter
|
||||||
|
|
||||||
def _upgrade_ktls_write(self, sslobj, secret):
|
def _upgrade_ktls_write(self, sslobj, secret):
|
||||||
|
print("_upgrade_ktls_write")
|
||||||
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
|
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, True)
|
||||||
|
self.set_protocol(self._app_protocol)
|
||||||
self._loop.call_soon(self._app_protocol.connection_made, self)
|
self._loop.call_soon(self._app_protocol.connection_made, self)
|
||||||
if self._waiter is not None:
|
if self._waiter is not None:
|
||||||
self._loop.call_soon(
|
self._loop.call_soon(
|
||||||
|
@ -313,13 +372,14 @@ class KLoopSSLTransport(KLoopSocketTransport):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _upgrade_ktls_read(self, sslobj, secret, data):
|
def _upgrade_ktls_read(self, sslobj, secret, data):
|
||||||
|
print("_upgrade_ktls_read")
|
||||||
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)
|
ktls.upgrade_aes_gcm_256(sslobj, self._sock, secret, False)
|
||||||
self.set_protocol(self._app_protocol)
|
# self.set_protocol(self._app_protocol)
|
||||||
if data is not None:
|
# if data is not None:
|
||||||
if data:
|
# if data:
|
||||||
self._app_protocol.data_received(data)
|
# self._app_protocol.data_received(data)
|
||||||
else:
|
# else:
|
||||||
self._app_protocol.eof_received()
|
# self._app_protocol.eof_received()
|
||||||
|
|
||||||
|
|
||||||
class KLoop(asyncio.BaseEventLoop):
|
class KLoop(asyncio.BaseEventLoop):
|
||||||
|
@ -344,6 +404,7 @@ class KLoop(asyncio.BaseEventLoop):
|
||||||
def _make_socket_transport(
|
def _make_socket_transport(
|
||||||
self, sock, protocol, waiter=None, *, extra=None, server=None
|
self, sock, protocol, waiter=None, *, extra=None, server=None
|
||||||
):
|
):
|
||||||
|
sock.setblocking(True)
|
||||||
return KLoopSocketTransport(
|
return KLoopSocketTransport(
|
||||||
self, sock, protocol, waiter, extra, server
|
self, sock, protocol, waiter, extra, server
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from cpython cimport PyMem_RawMalloc, PyMem_RawFree
|
||||||
from libc cimport errno, string
|
from libc cimport errno, string
|
||||||
from posix cimport mman
|
from posix cimport mman
|
||||||
|
|
||||||
from .includes cimport barrier, libc, linux
|
from .includes cimport barrier, libc, linux, ssl
|
||||||
|
|
||||||
cdef linux.__u32 SIG_SIZE = libc._NSIG // 8
|
cdef linux.__u32 SIG_SIZE = libc._NSIG // 8
|
||||||
|
|
||||||
|
@ -187,6 +187,7 @@ cdef class Ring:
|
||||||
|
|
||||||
def submit(self, Work work):
|
def submit(self, Work work):
|
||||||
cdef linux.io_uring_sqe* sqe = self.sq.next_sqe()
|
cdef linux.io_uring_sqe* sqe = self.sq.next_sqe()
|
||||||
|
# print(f"submit: {work}")
|
||||||
work.submit(sqe)
|
work.submit(sqe)
|
||||||
|
|
||||||
def select(self, timeout):
|
def select(self, timeout):
|
||||||
|
@ -225,8 +226,9 @@ cdef class Ring:
|
||||||
if need_enter:
|
if need_enter:
|
||||||
arg.sigmask = 0
|
arg.sigmask = 0
|
||||||
arg.sigmask_sz = SIG_SIZE
|
arg.sigmask_sz = SIG_SIZE
|
||||||
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
# print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
||||||
f"flags={flags:b}, timeout={timeout})")
|
# f"flags={flags:b}, timeout={timeout})")
|
||||||
|
with nogil:
|
||||||
ret = libc.syscall(
|
ret = libc.syscall(
|
||||||
libc.SYS_io_uring_enter,
|
libc.SYS_io_uring_enter,
|
||||||
self.enter_fd,
|
self.enter_fd,
|
||||||
|
@ -238,6 +240,8 @@ cdef class Ring:
|
||||||
)
|
)
|
||||||
if ret < 0:
|
if ret < 0:
|
||||||
if errno.errno != errno.ETIME:
|
if errno.errno != errno.ETIME:
|
||||||
|
print(f"SYS_io_uring_enter(submit={submit}, wait_nr={wait_nr}, "
|
||||||
|
f"flags={flags:b}, timeout={timeout})")
|
||||||
PyErr_SetFromErrno(IOError)
|
PyErr_SetFromErrno(IOError)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -263,7 +267,7 @@ cdef class Work:
|
||||||
int op,
|
int op,
|
||||||
linux.io_uring_sqe * sqe,
|
linux.io_uring_sqe * sqe,
|
||||||
int fd,
|
int fd,
|
||||||
void * addr,
|
void* addr,
|
||||||
unsigned len,
|
unsigned len,
|
||||||
linux.__u64 offset,
|
linux.__u64 offset,
|
||||||
):
|
):
|
||||||
|
@ -386,6 +390,7 @@ cdef class RecvWork(Work):
|
||||||
|
|
||||||
cdef class RecvMsgWork(Work):
|
cdef class RecvMsgWork(Work):
|
||||||
def __init__(self, int fd, buffers, callback):
|
def __init__(self, int fd, buffers, callback):
|
||||||
|
cdef size_t size = libc.CMSG_SPACE(sizeof(unsigned char))
|
||||||
self.fd = fd
|
self.fd = fd
|
||||||
self.buffers = buffers
|
self.buffers = buffers
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
@ -398,9 +403,9 @@ cdef class RecvMsgWork(Work):
|
||||||
for i, buf in enumerate(buffers):
|
for i, buf in enumerate(buffers):
|
||||||
self.msg.msg_iov[i].iov_base = <char*>buf
|
self.msg.msg_iov[i].iov_base = <char*>buf
|
||||||
self.msg.msg_iov[i].iov_len = len(buf)
|
self.msg.msg_iov[i].iov_len = len(buf)
|
||||||
self.control_msg = bytearray(256)
|
self.control_msg = bytearray(size)
|
||||||
self.msg.msg_control = <char*>self.control_msg
|
self.msg.msg_control = <char*>self.control_msg
|
||||||
self.msg.msg_controllen = 256
|
self.msg.msg_controllen = size
|
||||||
|
|
||||||
def __dealloc__(self):
|
def __dealloc__(self):
|
||||||
if self.msg.msg_iov != NULL:
|
if self.msg.msg_iov != NULL:
|
||||||
|
@ -410,11 +415,24 @@ cdef class RecvMsgWork(Work):
|
||||||
self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0)
|
self._submit(linux.IORING_OP_RECVMSG, sqe, self.fd, &self.msg, 1, 0)
|
||||||
|
|
||||||
def complete(self):
|
def complete(self):
|
||||||
if self.res < 0:
|
cdef:
|
||||||
errno.errno = abs(self.res)
|
libc.cmsghdr* cmsg
|
||||||
PyErr_SetFromErrno(IOError)
|
unsigned char* cmsg_data
|
||||||
return
|
unsigned char record_type
|
||||||
# if self.msg.msg_controllen:
|
# if self.res < 0:
|
||||||
# print('control_msg:', self.control_msg[:self.msg.msg_controllen])
|
# errno.errno = abs(self.res)
|
||||||
# print('flags:', self.msg.msg_flags)
|
# PyErr_SetFromErrno(IOError)
|
||||||
self.callback(self.res)
|
# return
|
||||||
|
app_data = True
|
||||||
|
if self.msg.msg_controllen:
|
||||||
|
print('msg_controllen:', self.msg.msg_controllen)
|
||||||
|
cmsg = libc.CMSG_FIRSTHDR(&self.msg)
|
||||||
|
if cmsg.cmsg_level == libc.SOL_TLS and cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
|
||||||
|
cmsg_data = libc.CMSG_DATA(cmsg)
|
||||||
|
record_type = (<unsigned char*>cmsg_data)[0]
|
||||||
|
if record_type != ssl.SSL3_RT_APPLICATION_DATA:
|
||||||
|
app_data = False
|
||||||
|
print(f'cmsg.len={cmsg.cmsg_len}, cmsg.level={cmsg.cmsg_level}, cmsg.type={cmsg.cmsg_type}')
|
||||||
|
print(f'record type: {record_type}')
|
||||||
|
print('flags:', self.msg.msg_flags)
|
||||||
|
self.callback(self.res, app_data)
|
||||||
|
|
|
@ -38,14 +38,21 @@ class TestLoop(unittest.TestCase):
|
||||||
|
|
||||||
def test_connect(self):
|
def test_connect(self):
|
||||||
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||||
|
ctx.check_hostname = False
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
ctx.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||||
|
host = "www.google.com"
|
||||||
r, w = self.loop.run_until_complete(
|
r, w = self.loop.run_until_complete(
|
||||||
asyncio.open_connection("www.google.com", 443, ssl=ctx)
|
# asyncio.open_connection("127.0.0.1", 8080, ssl=ctx)
|
||||||
|
asyncio.open_connection(host, 443, ssl=ctx)
|
||||||
)
|
)
|
||||||
w.write(b"GET / HTTP/1.1\r\n"
|
self.loop.run_until_complete(asyncio.sleep(1))
|
||||||
b"Host: www.google.com\r\n"
|
print('send request')
|
||||||
|
w.write(b"GET / HTTP/1.1\r\n" +
|
||||||
|
f"Host: {host}\r\n".encode("ISO-8859-1") +
|
||||||
b"Connection: close\r\n"
|
b"Connection: close\r\n"
|
||||||
b"\r\n")
|
b"\r\n")
|
||||||
while line := self.loop.run_until_complete(r.readline()):
|
while line := self.loop.run_until_complete(r.read()):
|
||||||
print(line)
|
print(line)
|
||||||
w.close()
|
w.close()
|
||||||
self.loop.run_until_complete(w.wait_closed())
|
self.loop.run_until_complete(w.wait_closed())
|
||||||
|
|
Loading…
Reference in a new issue