1
0
Fork 0
mirror of https://gitee.com/fantix/kloop.git synced 2024-11-21 18:01:00 +00:00

Barely working TLS client!

Refs #I5ANZH
This commit is contained in:
Fantix King 2022-06-25 19:46:19 -04:00
parent a1094281ec
commit 74f0062154
No known key found for this signature in database
GPG key ID: 95304B04071CCDB4
6 changed files with 558 additions and 30 deletions

View file

@ -84,5 +84,11 @@ cdef extern from "openssl/bio.h" nogil:
void set_init "BIO_set_init" (BIO* a, int init)
void set_shutdown "BIO_set_shutdown" (BIO* a, int shut)
void set_retry_read "BIO_set_retry_read" (BIO *b)
void set_retry_write "BIO_set_retry_write" (BIO *b)
void set_retry_read "BIO_set_retry_read" (BIO* b)
void set_retry_write "BIO_set_retry_write" (BIO* b)
void clear_retry_flags "BIO_clear_retry_flags" (BIO *b)
cdef int FLAGS_IN_EOF "BIO_FLAGS_IN_EOF"
int test_flags "BIO_test_flags" (BIO* b, int flags)
void set_flags "BIO_set_flags" (BIO *b, int flags)

View file

@ -9,9 +9,30 @@
# See the Mulan PSL v2 for more details.
from .. cimport linux
cdef extern from "openssl/ssl.h" nogil:
ctypedef struct SSL:
pass
int SSL3_RT_APPLICATION_DATA
int OP_ENABLE_KTLS "SSL_OP_ENABLE_KTLS"
int set_options "SSL_set_options" (SSL* ssl, int options)
cdef extern from *:
"""
typedef struct {
union {
struct tls12_crypto_info_aes_gcm_128 gcm128;
struct tls12_crypto_info_aes_gcm_256 gcm256;
struct tls12_crypto_info_aes_ccm_128 ccm128;
struct tls12_crypto_info_chacha20_poly1305 chacha20poly1305;
};
size_t tls_crypto_info_len;
} ktls_crypto_info_t;
"""
ctypedef struct ktls_crypto_info_t:
size_t tls_crypto_info_len

View file

@ -526,7 +526,9 @@ cdef class KLoopImpl:
fd = await tcp_connect(self, host, port)
protocol = protocol_factory()
if ssl is not None:
transport = tls.TLSTransport.new(fd, protocol, self, ssl)
waiter = self.create_future()
transport = tls.TLSTransport.new(fd, protocol, self, ssl, waiter=waiter)
await waiter
else:
transport = TCPTransport.new(fd, protocol, self)
return transport, protocol

View file

@ -9,8 +9,32 @@
# See the Mulan PSL v2 for more details.
from cpython cimport PyObject
from .includes cimport libc
from .includes.openssl cimport bio
from .loop cimport KLoopImpl
from .loop cimport KLoopImpl, Loop, RingCallback
cdef struct Proxy:
PyObject* transport
libc.iovec send_vec
libc.msghdr send_msg
RingCallback send_callback
libc.iovec recv_vec
libc.msghdr recv_msg
RingCallback recv_callback
unsigned char flags
char* read_buffer
Loop* loop
int fd
cdef enum State:
UNWRAPPED
HANDSHAKING
WRAPPED
WRAPPED_KTLS
cdef class TLSTransport:
@ -21,3 +45,14 @@ cdef class TLSTransport:
object protocol
object sslctx
object sslobj
object waiter
Proxy proxy
State state
object write_buffer
bint sending
cdef do_handshake(self)
cdef do_read(self)
cdef do_read_ktls(self)
cdef write_cb(self, int res)
cdef read_cb(self, int res)

View file

@ -9,12 +9,50 @@
# See the Mulan PSL v2 for more details.
import collections
import socket
import ssl
from cpython cimport PyMem_RawMalloc, PyMem_RawFree
from libc cimport string
from libc cimport errno, string
from .includes.openssl cimport bio, err, ssl as ssl_h
from .includes cimport pyssl
from .includes cimport pyssl, linux
from .loop cimport ring_sq_submit_sendmsg, ring_sq_submit_recvmsg
cdef int BIO_CTRL_SET_KTLS = 72
cdef int BIO_CTRL_GET_KTLS_SEND = 73
cdef int BIO_CTRL_GET_KTLS_RECV = 76
cdef int FLAGS_KTLS_TX_CTRL_MSG = 0x1000
cdef int FLAGS_KTLS_RX = 0x2000
cdef int FLAGS_KTLS_TX = 0x4000
cdef unsigned char FLAGS_PROXY_SEND_SUBMITTED = 1 << 0
cdef unsigned char FLAGS_PROXY_SEND_COMPLETED = 1 << 1
cdef unsigned char FLAGS_PROXY_SEND_IN_PROXY = 1 << 2
cdef unsigned char FLAGS_PROXY_SEND_ALL = (
FLAGS_PROXY_SEND_SUBMITTED |
FLAGS_PROXY_SEND_COMPLETED |
FLAGS_PROXY_SEND_IN_PROXY
)
cdef unsigned char FLAGS_PROXY_RECV_SUBMITTED = 1 << 4
cdef unsigned char FLAGS_PROXY_RECV_COMPLETED = 1 << 5
cdef unsigned char FLAGS_PROXY_RECV_KTLS = 1 << 6
cdef unsigned char FLAGS_PROXY_RECV_ALL = (
FLAGS_PROXY_RECV_SUBMITTED |
FLAGS_PROXY_RECV_COMPLETED
)
cdef size_t CMSG_SIZE = libc.CMSG_SPACE(sizeof(unsigned char))
DEF DEBUG = 0
cdef inline void reset_msg(libc.msghdr* msg, size_t controllen) nogil:
msg.msg_name = NULL
msg.msg_namelen = 0
msg.msg_flags = 0
msg.msg_controllen = controllen
cdef object fromOpenSSLError(object err_type):
@ -30,35 +68,266 @@ cdef object fromOpenSSLError(object err_type):
cdef int bio_write_ex(
bio.BIO* b, const char* data, size_t datal, size_t* written
) nogil:
with gil:
print('bio_write', data[:datal], int(<int>data))
bio.set_retry_write(b)
written[0] = 0
cdef:
Proxy* proxy = <Proxy*>bio.get_data(b)
int res
if DEBUG:
with gil:
print("bio_write_ex(data=%x, datal=%d)" % (<long long>data, datal))
if proxy.flags & FLAGS_PROXY_SEND_SUBMITTED:
if proxy.send_vec.iov_base != data:
if DEBUG:
with gil:
print("bio_write_ex() error: concurrent call")
return 0
if proxy.send_vec.iov_len > datal:
if DEBUG:
with gil:
print("bio_write_ex() error: short rewrite")
return 0
bio.clear_retry_flags(b)
if proxy.flags & FLAGS_PROXY_SEND_COMPLETED:
proxy.flags &= ~FLAGS_PROXY_SEND_ALL
res = proxy.send_callback.res
if res < 0:
if DEBUG:
with gil:
print("bio_write_ex() error:", -res)
errno.errno = -res
return 0
written[0] = res
if DEBUG:
with gil:
print('bio_write_ex() written:', res)
for i in range(res):
print(
"%02x " % <unsigned char>data[i],
end="" if (i + 1) % 16 and i < res - 1 else "\n",
)
else:
written[0] = 0
bio.set_retry_write(b)
if not proxy.flags & FLAGS_PROXY_SEND_SUBMITTED:
if DEBUG:
with gil:
print("bio_write_ex() submit")
proxy.send_vec.iov_base = data
proxy.send_vec.iov_len = datal
reset_msg(&proxy.send_msg, 0)
if not ring_sq_submit_sendmsg(
&proxy.loop.ring.sq,
proxy.fd,
&proxy.send_msg,
&proxy.send_callback,
):
if DEBUG:
with gil:
print("bio_write_ex() error: SQ full")
return 0
proxy.flags |= FLAGS_PROXY_SEND_SUBMITTED
return 1
cdef int bio_read_ex(
bio.BIO* b, char* data, size_t datal, size_t* readbytes
) nogil:
with gil:
print('bio_read', datal, int(<int>data))
bio.set_retry_read(b)
readbytes[0] = 0
cdef:
Proxy* proxy = <Proxy*>bio.get_data(b)
libc.cmsghdr* cmsg = NULL
int res
int is_ktls = bio.test_flags(b, FLAGS_KTLS_RX)
if DEBUG:
with gil:
print('bio_read_ex(data=%x, datal=%d)' % (<long long>data, datal))
if proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
if proxy.recv_vec.iov_base != (data + 5 if is_ktls else data):
if DEBUG:
with gil:
print("bio_read_ex() error: concurrent call")
return 0
if proxy.recv_vec.iov_len > (datal - 21 if is_ktls else datal):
if DEBUG:
with gil:
print("bio_read_ex() error: short reread")
return 0
bio.clear_retry_flags(b)
if proxy.flags & FLAGS_PROXY_RECV_KTLS:
res = proxy.recv_callback.res
if datal < res + 5:
if DEBUG:
with gil:
print("bio_read_ex() error: datal too short")
errno.errno = errno.EINVAL
return 0
cmsg = libc.CMSG_FIRSTHDR(&proxy.recv_msg)
if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
data[0] = (<unsigned char *> libc.CMSG_DATA(cmsg))[0]
data[1] = 0x03 # TLS1_2_VERSION_MAJOR
data[2] = 0x03 # TLS1_2_VERSION_MINOR
# returned length is limited to msg_iov.iov_len above
data[3] = (res >> 8) & 0xff
data[4] = res & 0xff
string.memcpy(data + 5, proxy.read_buffer, res)
res += 5
else:
string.memcpy(data, proxy.read_buffer, res)
readbytes[0] = res
if DEBUG:
with gil:
print(
"bio_read_ex() read:",
res,
"(forwarded TLS record)"
)
for i in range(res):
print(
"%02x " % <unsigned char>data[i],
end="" if (i + 1) % 16 and i < res - 1 else "\n",
)
elif proxy.flags & FLAGS_PROXY_RECV_COMPLETED:
proxy.flags &= ~FLAGS_PROXY_RECV_ALL
res = proxy.recv_callback.res
if res < 0:
if DEBUG:
with gil:
print("bio_read_ex() error:", -res)
errno.errno = -res
return 0
if is_ktls:
if proxy.recv_msg.msg_controllen:
cmsg = libc.CMSG_FIRSTHDR(&proxy.recv_msg)
if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
data[0] = (<unsigned char*>libc.CMSG_DATA(cmsg))[0]
data[1] = 0x03 # TLS1_2_VERSION_MAJOR
data[2] = 0x03 # TLS1_2_VERSION_MINOR
# returned length is limited to msg_iov.iov_len above
data[3] = (res >> 8) & 0xff
data[4] = res & 0xff
res += 5
if res == 0:
bio.set_flags(b, bio.FLAGS_IN_EOF)
readbytes[0] = res
if DEBUG:
with gil:
print(
"bio_read_ex() read:",
res,
"(TLS record)" if cmsg else ""
)
for i in range(res):
print(
"%02x " % <unsigned char>data[i],
end="" if (i + 1) % 16 and i < res - 1 else "\n",
)
else:
bio.set_retry_read(b)
readbytes[0] = 0
if not proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
if is_ktls:
if datal < 21:
if DEBUG:
with gil:
print("bio_read_ex() error: datal too short")
errno.errno = errno.EINVAL
return 0
proxy.recv_vec.iov_base = data + 5
proxy.recv_vec.iov_len = datal - 21
if DEBUG:
with gil:
print("bio_read_ex() submit(%x, %d)" % (
<long long>proxy.recv_vec.iov_base,
proxy.recv_vec.iov_len,
))
else:
proxy.recv_vec.iov_base = data
proxy.recv_vec.iov_len = datal
if DEBUG:
with gil:
print("bio_read_ex() submit")
reset_msg(&proxy.recv_msg, CMSG_SIZE)
if not ring_sq_submit_recvmsg(
&proxy.loop.ring.sq,
proxy.fd,
&proxy.recv_msg,
&proxy.recv_callback,
):
if DEBUG:
with gil:
print("bio_read_ex() error: SQ full")
return 0
proxy.flags |= FLAGS_PROXY_RECV_SUBMITTED
return 1
cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil:
cdef long ret = 0
with gil:
if cmd == bio.BIO_CTRL_EOF:
print("BIO_CTRL_EOF", ret)
elif cmd == bio.BIO_CTRL_PUSH:
print("BIO_CTRL_PUSH", ret)
elif cmd == bio.BIO_CTRL_FLUSH:
ret = 1
print('BIO_CTRL_FLUSH', ret)
cdef:
ssl_h.ktls_crypto_info_t* crypto_info
long ret = 0
if cmd == bio.BIO_CTRL_EOF:
if DEBUG:
with gil:
print("BIO_CTRL_EOF", ret)
elif cmd == bio.BIO_CTRL_PUSH:
if DEBUG:
with gil:
print("BIO_CTRL_PUSH", ret)
elif cmd == bio.BIO_CTRL_POP:
if DEBUG:
with gil:
print("BIO_CTRL_POP", ret)
elif cmd == bio.BIO_CTRL_FLUSH:
ret = 1
if DEBUG:
with gil:
print('BIO_CTRL_FLUSH', ret)
elif cmd == BIO_CTRL_SET_KTLS:
if DEBUG:
with gil:
print("BIO_CTRL_SET_KTLS", "TX end" if num else "RX end")
crypto_info = <ssl_h.ktls_crypto_info_t*>ptr
if libc.setsockopt(
(<Proxy*>bio.get_data(b)).fd,
libc.SOL_TLS,
linux.TLS_TX if num else linux.TLS_RX,
crypto_info,
crypto_info.tls_crypto_info_len,
) == 0:
bio.set_flags(b, FLAGS_KTLS_TX if num else FLAGS_KTLS_RX)
else:
print('bio_ctrl', cmd, num)
if DEBUG:
with gil:
print(
"BIO_CTRL_SET_KTLS",
"TX end" if num else "RX end",
"failed",
)
elif cmd == BIO_CTRL_GET_KTLS_SEND:
return bio.test_flags(b, FLAGS_KTLS_TX) != 0
elif cmd == BIO_CTRL_GET_KTLS_RECV:
return bio.test_flags(b, FLAGS_KTLS_RX) != 0
else:
if DEBUG:
with gil:
print('bio_ctrl', cmd, num)
return ret
@ -72,6 +341,22 @@ cdef int bio_destroy(bio.BIO* b) nogil:
return 1
cdef int tls_send_cb(RingCallback* cb) nogil except 0:
cdef Proxy* proxy = <Proxy*>cb.data
proxy.flags |= FLAGS_PROXY_SEND_COMPLETED
with gil:
(<TLSTransport>proxy.transport).write_cb(cb.res)
return 1
cdef int tls_recv_cb(RingCallback* cb) nogil except 0:
cdef Proxy* proxy = <Proxy*>cb.data
proxy.flags |= FLAGS_PROXY_RECV_COMPLETED
with gil:
(<TLSTransport>proxy.transport).read_cb(cb.res)
return 1
cdef class TLSTransport:
@staticmethod
def new(
@ -82,11 +367,14 @@ cdef class TLSTransport:
server_side=False,
server_hostname=None,
session=None,
waiter=None,
):
cdef:
TLSTransport rv = TLSTransport.__new__(TLSTransport)
pyssl.PySSLMemoryBIO* c_bio
libc.setsockopt(fd, socket.SOL_TCP, linux.TCP_ULP, b"tls", 3)
py_bio = ssl.MemoryBIO()
c_bio = <pyssl.PySSLMemoryBIO*>py_bio
c_bio.bio, rv.bio = rv.bio, c_bio.bio
@ -99,25 +387,185 @@ cdef class TLSTransport:
del py_bio
ssl_h.set_options(
(<pyssl.PySSLSocket*>rv.sslobj).ssl, ssl_h.OP_ENABLE_KTLS
(<pyssl.PySSLSocket*>rv.sslobj._sslobj).ssl, ssl_h.OP_ENABLE_KTLS
)
rv.fd = fd
rv.protocol = protocol
rv.loop = loop
rv.sslctx = sslctx
rv.proxy.loop = &loop.loop
rv.proxy.fd = fd
rv.waiter = waiter
rv.write_buffer = collections.deque()
try:
rv.sslobj.do_handshake()
except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
pass
rv.do_handshake()
return rv
def __cinit__(self):
self.state = UNWRAPPED
self.bio = bio.new(KTLS_BIO_METHOD)
bio.set_data(self.bio, <void*>self)
self.proxy.transport = <PyObject*>self
self.proxy.send_msg.msg_iov = &self.proxy.send_vec
self.proxy.send_msg.msg_iovlen = 1
self.proxy.send_callback.data = <void*>&self.proxy
self.proxy.send_callback.callback = tls_send_cb
self.proxy.recv_msg.msg_control = PyMem_RawMalloc(CMSG_SIZE)
if self.proxy.recv_msg.msg_control == NULL:
raise MemoryError
self.proxy.recv_msg.msg_controllen = CMSG_SIZE
self.proxy.recv_msg.msg_iov = &self.proxy.recv_vec
self.proxy.recv_msg.msg_iovlen = 1
self.proxy.recv_callback.data = <void*>&self.proxy
self.proxy.recv_callback.callback = tls_recv_cb
bio.set_data(self.bio, <void*>&self.proxy)
def __dealloc__(self):
self.sslobj = None
bio.free(self.bio)
PyMem_RawFree(self.proxy.read_buffer)
PyMem_RawFree(self.proxy.recv_msg.msg_control)
cdef do_handshake(self):
if self.state == UNWRAPPED:
self.state = HANDSHAKING
elif self.state != HANDSHAKING:
raise RuntimeError("Cannot do handshake now")
try:
self.sslobj.do_handshake()
except ssl.SSLWantReadError:
if DEBUG:
print("do_handshake() SSLWantReadError")
except ssl.SSLWantWriteError:
if DEBUG:
print("do_handshake() SSLWantWriteError")
except Exception as ex:
if DEBUG:
print('do_handshake() error:', ex)
raise
else:
if DEBUG:
print('do_handshake() done')
self.state = WRAPPED
if self.waiter:
self.waiter.set_result(self)
self.waiter = None
if bio.test_flags(self.bio, FLAGS_KTLS_RX):
self.proxy.read_buffer = <char*>PyMem_RawMalloc(65536)
if self.proxy.read_buffer == NULL:
raise MemoryError
self.proxy.flags |= FLAGS_PROXY_RECV_KTLS
self.proxy.recv_vec.iov_base = self.proxy.read_buffer
self.proxy.recv_vec.iov_len = 65536
self.do_read_ktls()
else:
self.do_read()
cdef do_read_ktls(self):
cdef:
int res
libc.cmsghdr* cmsg
unsigned char record_type
if self.proxy.flags & FLAGS_PROXY_RECV_COMPLETED:
self.proxy.flags &= ~FLAGS_PROXY_RECV_ALL
res = self.proxy.recv_callback.res
if res < 0:
if DEBUG:
print("do_read_ktls() error:", -res)
self.loop.call_soon(
self.protocol.connection_lost,
IOError(-res, string.strerror(-res))
)
elif res == 0:
if DEBUG:
print("do_read_ktls() EOF")
self.loop.call_soon(self.protocol.eof_received)
self.loop.call_soon(self.protocol.connection_lost, None)
else:
if self.proxy.recv_msg.msg_controllen:
cmsg = libc.CMSG_FIRSTHDR(&self.proxy.recv_msg)
if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
record_type = (<unsigned char*>libc.CMSG_DATA(cmsg))[0]
if record_type != ssl_h.SSL3_RT_APPLICATION_DATA:
if DEBUG:
print("do_read_ktls() forward CMSG")
return self.do_read()
if DEBUG:
print("do_read_ktls() received", res, "bytes")
self.loop.call_soon(
self.protocol.data_received,
bytes(self.proxy.read_buffer[:res]),
)
self.loop.call_soon(self.do_read_ktls, self)
elif not self.proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
if DEBUG:
print("do_read_ktls() submit")
self.proxy.recv_msg.msg_controllen = CMSG_SIZE
reset_msg(&self.proxy.recv_msg, CMSG_SIZE)
if not ring_sq_submit_recvmsg(
&self.proxy.loop.ring.sq,
self.fd,
&self.proxy.recv_msg,
&self.proxy.recv_callback,
):
raise RuntimeError("SQ full")
self.proxy.flags |= FLAGS_PROXY_RECV_SUBMITTED
cdef do_read(self):
try:
data = self.sslobj.read(65536)
except ssl.SSLWantReadError:
if DEBUG:
print("do_read() SSLWantReadError")
except ssl.SSLWantWriteError:
if DEBUG:
print("do_read() SSLWantWriteError")
except Exception as ex:
if DEBUG:
print("do_read() error:", ex)
self.loop.call_soon(self.protocol.connection_lost, ex)
else:
if data:
if DEBUG:
print("do_read() received", len(data), bytes)
print(data)
self.loop.call_soon(self.protocol.data_received, data)
if self.proxy.flags & FLAGS_PROXY_RECV_KTLS:
self.loop.call_soon(self.do_read_ktls, self)
else:
self.loop.call_soon(self.do_read, self)
else:
if DEBUG:
print("do_read() EOF")
self.loop.call_soon(self.protocol.eof_received)
self.loop.call_soon(self.protocol.connection_lost, None)
cdef write_cb(self, int res):
if self.state == HANDSHAKING:
self.do_handshake()
cdef read_cb(self, int res):
if self.state == HANDSHAKING:
self.do_handshake()
elif self.state == WRAPPED:
if self.proxy.flags & FLAGS_PROXY_RECV_KTLS:
self.do_read_ktls()
else:
self.do_read()
def write(self, data):
if self.sending:
self.write_buffer.append(data)
else:
try:
self.sslobj.write(data)
except ssl.SSLWantWriteError:
if DEBUG:
print("write() SSLWantWriteError")
cdef bio.Method* KTLS_BIO_METHOD = bio.meth_new(

View file

@ -56,3 +56,19 @@ cdef struct RingCallback:
void* data
int res
int (*callback)(RingCallback* cb) nogil except 0
cdef int ring_sq_submit_sendmsg(
SubmissionQueue* sq,
int fd,
const libc.msghdr *msg,
RingCallback* callback,
) nogil
cdef int ring_sq_submit_recvmsg(
SubmissionQueue* sq,
int fd,
const libc.msghdr *msg,
RingCallback* callback,
) nogil