mirror of
https://gitee.com/fantix/kloop.git
synced 2025-01-05 07:28:43 +00:00
TLS: fix buffer usage and reset message
Also improved debug
This commit is contained in:
parent
74f0062154
commit
761f741d5d
3 changed files with 84 additions and 65 deletions
|
@ -20,6 +20,11 @@ cdef extern from "openssl/ssl.h" nogil:
|
|||
int OP_ENABLE_KTLS "SSL_OP_ENABLE_KTLS"
|
||||
int set_options "SSL_set_options" (SSL* ssl, int options)
|
||||
|
||||
long SSL_MODE_RELEASE_BUFFERS
|
||||
long SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
|
||||
long clear_mode "SSL_clear_mode" (SSL* ssl, long mode)
|
||||
int free_buffers "SSL_free_buffers" (SSL *ssl)
|
||||
|
||||
|
||||
cdef extern from *:
|
||||
"""
|
||||
|
|
|
@ -25,6 +25,7 @@ cdef struct Proxy:
|
|||
RingCallback recv_callback
|
||||
unsigned char flags
|
||||
char* read_buffer
|
||||
void* cmsg
|
||||
|
||||
Loop* loop
|
||||
int fd
|
||||
|
@ -51,6 +52,7 @@ cdef class TLSTransport:
|
|||
object write_buffer
|
||||
bint sending
|
||||
|
||||
|
||||
cdef do_handshake(self)
|
||||
cdef do_read(self)
|
||||
cdef do_read_ktls(self)
|
||||
|
|
|
@ -15,7 +15,7 @@ import ssl
|
|||
from cpython cimport PyMem_RawMalloc, PyMem_RawFree
|
||||
from libc cimport errno, string
|
||||
|
||||
from .includes.openssl cimport bio, err, ssl as ssl_h
|
||||
from .includes.openssl cimport err, ssl as ssl_h
|
||||
from .includes cimport pyssl, linux
|
||||
from .loop cimport ring_sq_submit_sendmsg, ring_sq_submit_recvmsg
|
||||
|
||||
|
@ -48,11 +48,16 @@ cdef size_t CMSG_SIZE = libc.CMSG_SPACE(sizeof(unsigned char))
|
|||
DEF DEBUG = 0
|
||||
|
||||
|
||||
cdef inline void reset_msg(libc.msghdr* msg, size_t controllen) nogil:
|
||||
cdef inline void reset_msg(libc.msghdr* msg, void* cmsg) nogil:
|
||||
msg.msg_name = NULL
|
||||
msg.msg_namelen = 0
|
||||
msg.msg_flags = 0
|
||||
msg.msg_controllen = controllen
|
||||
if cmsg == NULL:
|
||||
msg.msg_control = NULL
|
||||
msg.msg_controllen = 0
|
||||
else:
|
||||
msg.msg_control = cmsg
|
||||
msg.msg_controllen = CMSG_SIZE
|
||||
|
||||
|
||||
cdef object fromOpenSSLError(object err_type):
|
||||
|
@ -72,18 +77,18 @@ cdef int bio_write_ex(
|
|||
Proxy* proxy = <Proxy*>bio.get_data(b)
|
||||
int res
|
||||
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex(data=%x, datal=%d)" % (<long long>data, datal))
|
||||
|
||||
if proxy.flags & FLAGS_PROXY_SEND_SUBMITTED:
|
||||
if proxy.send_vec.iov_base != data:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex() error: concurrent call")
|
||||
return 0
|
||||
if proxy.send_vec.iov_len > datal:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex() error: short rewrite")
|
||||
return 0
|
||||
|
@ -93,40 +98,42 @@ cdef int bio_write_ex(
|
|||
proxy.flags &= ~FLAGS_PROXY_SEND_ALL
|
||||
res = proxy.send_callback.res
|
||||
if res < 0:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex() error:", -res)
|
||||
errno.errno = -res
|
||||
return 0
|
||||
|
||||
written[0] = res
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print('bio_write_ex() written:', res)
|
||||
print("bio_write_ex() written:", res)
|
||||
print(">>> ", end="")
|
||||
for i in range(res):
|
||||
print(
|
||||
"%02x " % <unsigned char>data[i],
|
||||
end="" if (i + 1) % 16 and i < res - 1 else "\n",
|
||||
end="" if (i + 1) % 16 or i == res - 1 else "\n>>> ",
|
||||
)
|
||||
print()
|
||||
|
||||
else:
|
||||
written[0] = 0
|
||||
bio.set_retry_write(b)
|
||||
|
||||
if not proxy.flags & FLAGS_PROXY_SEND_SUBMITTED:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex() submit")
|
||||
proxy.send_vec.iov_base = data
|
||||
proxy.send_vec.iov_len = datal
|
||||
reset_msg(&proxy.send_msg, 0)
|
||||
reset_msg(&proxy.send_msg, NULL)
|
||||
if not ring_sq_submit_sendmsg(
|
||||
&proxy.loop.ring.sq,
|
||||
proxy.fd,
|
||||
&proxy.send_msg,
|
||||
&proxy.send_callback,
|
||||
):
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_write_ex() error: SQ full")
|
||||
return 0
|
||||
|
@ -144,18 +151,18 @@ cdef int bio_read_ex(
|
|||
int res
|
||||
int is_ktls = bio.test_flags(b, FLAGS_KTLS_RX)
|
||||
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print('bio_read_ex(data=%x, datal=%d)' % (<long long>data, datal))
|
||||
|
||||
if proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
|
||||
if proxy.recv_vec.iov_base != (data + 5 if is_ktls else data):
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error: concurrent call")
|
||||
return 0
|
||||
if proxy.recv_vec.iov_len > (datal - 21 if is_ktls else datal):
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error: short reread")
|
||||
return 0
|
||||
|
@ -164,7 +171,7 @@ cdef int bio_read_ex(
|
|||
if proxy.flags & FLAGS_PROXY_RECV_KTLS:
|
||||
res = proxy.recv_callback.res
|
||||
if datal < res + 5:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error: datal too short")
|
||||
errno.errno = errno.EINVAL
|
||||
|
@ -184,24 +191,22 @@ cdef int bio_read_ex(
|
|||
string.memcpy(data, proxy.read_buffer, res)
|
||||
readbytes[0] = res
|
||||
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print(
|
||||
"bio_read_ex() read:",
|
||||
res,
|
||||
"(forwarded TLS record)"
|
||||
)
|
||||
print("bio_read_ex() read:", res, "(forwarded TLS record)")
|
||||
print("<<< ", end="")
|
||||
for i in range(res):
|
||||
print(
|
||||
"%02x " % <unsigned char>data[i],
|
||||
end="" if (i + 1) % 16 and i < res - 1 else "\n",
|
||||
end="" if (i + 1) % 16 or i == res - 1 else "\n<<< ",
|
||||
)
|
||||
print()
|
||||
|
||||
elif proxy.flags & FLAGS_PROXY_RECV_COMPLETED:
|
||||
proxy.flags &= ~FLAGS_PROXY_RECV_ALL
|
||||
res = proxy.recv_callback.res
|
||||
if res < 0:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error:", -res)
|
||||
errno.errno = -res
|
||||
|
@ -223,25 +228,25 @@ cdef int bio_read_ex(
|
|||
bio.set_flags(b, bio.FLAGS_IN_EOF)
|
||||
readbytes[0] = res
|
||||
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print(
|
||||
"bio_read_ex() read:",
|
||||
res,
|
||||
"(TLS record)" if cmsg else ""
|
||||
"bio_read_ex() read:", res, "(TLS record)" if cmsg else ""
|
||||
)
|
||||
print("<<< ", end="")
|
||||
for i in range(res):
|
||||
print(
|
||||
"%02x " % <unsigned char>data[i],
|
||||
end="" if (i + 1) % 16 and i < res - 1 else "\n",
|
||||
end="" if (i + 1) % 16 or i == res - 1 else "\n<<< ",
|
||||
)
|
||||
print()
|
||||
else:
|
||||
bio.set_retry_read(b)
|
||||
readbytes[0] = 0
|
||||
if not proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
|
||||
if is_ktls:
|
||||
if datal < 21:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error: datal too short")
|
||||
errno.errno = errno.EINVAL
|
||||
|
@ -249,7 +254,8 @@ cdef int bio_read_ex(
|
|||
|
||||
proxy.recv_vec.iov_base = data + 5
|
||||
proxy.recv_vec.iov_len = datal - 21
|
||||
if DEBUG:
|
||||
reset_msg(&proxy.recv_msg, proxy.cmsg)
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() submit(%x, %d)" % (
|
||||
<long long>proxy.recv_vec.iov_base,
|
||||
|
@ -258,10 +264,10 @@ cdef int bio_read_ex(
|
|||
else:
|
||||
proxy.recv_vec.iov_base = data
|
||||
proxy.recv_vec.iov_len = datal
|
||||
if DEBUG:
|
||||
reset_msg(&proxy.recv_msg, NULL)
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() submit")
|
||||
reset_msg(&proxy.recv_msg, CMSG_SIZE)
|
||||
|
||||
if not ring_sq_submit_recvmsg(
|
||||
&proxy.loop.ring.sq,
|
||||
|
@ -269,7 +275,7 @@ cdef int bio_read_ex(
|
|||
&proxy.recv_msg,
|
||||
&proxy.recv_callback,
|
||||
):
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("bio_read_ex() error: SQ full")
|
||||
return 0
|
||||
|
@ -283,24 +289,24 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil:
|
|||
ssl_h.ktls_crypto_info_t* crypto_info
|
||||
long ret = 0
|
||||
if cmd == bio.BIO_CTRL_EOF:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("BIO_CTRL_EOF", ret)
|
||||
elif cmd == bio.BIO_CTRL_PUSH:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("BIO_CTRL_PUSH", ret)
|
||||
elif cmd == bio.BIO_CTRL_POP:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("BIO_CTRL_POP", ret)
|
||||
elif cmd == bio.BIO_CTRL_FLUSH:
|
||||
ret = 1
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print('BIO_CTRL_FLUSH', ret)
|
||||
elif cmd == BIO_CTRL_SET_KTLS:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print("BIO_CTRL_SET_KTLS", "TX end" if num else "RX end")
|
||||
crypto_info = <ssl_h.ktls_crypto_info_t*>ptr
|
||||
|
@ -313,7 +319,7 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil:
|
|||
) == 0:
|
||||
bio.set_flags(b, FLAGS_KTLS_TX if num else FLAGS_KTLS_RX)
|
||||
else:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print(
|
||||
"BIO_CTRL_SET_KTLS",
|
||||
|
@ -325,7 +331,7 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil:
|
|||
elif cmd == BIO_CTRL_GET_KTLS_RECV:
|
||||
return bio.test_flags(b, FLAGS_KTLS_RX) != 0
|
||||
else:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
with gil:
|
||||
print('bio_ctrl', cmd, num)
|
||||
return ret
|
||||
|
@ -372,6 +378,7 @@ cdef class TLSTransport:
|
|||
cdef:
|
||||
TLSTransport rv = TLSTransport.__new__(TLSTransport)
|
||||
pyssl.PySSLMemoryBIO* c_bio
|
||||
pyssl.SSL* s
|
||||
|
||||
libc.setsockopt(fd, socket.SOL_TCP, linux.TCP_ULP, b"tls", 3)
|
||||
|
||||
|
@ -386,8 +393,12 @@ cdef class TLSTransport:
|
|||
c_bio.bio, rv.bio = rv.bio, c_bio.bio
|
||||
del py_bio
|
||||
|
||||
ssl_h.set_options(
|
||||
(<pyssl.PySSLSocket*>rv.sslobj._sslobj).ssl, ssl_h.OP_ENABLE_KTLS
|
||||
s = (<pyssl.PySSLSocket*>rv.sslobj._sslobj).ssl
|
||||
ssl_h.set_options(s, ssl_h.OP_ENABLE_KTLS)
|
||||
ssl_h.clear_mode(
|
||||
s,
|
||||
ssl_h.SSL_MODE_RELEASE_BUFFERS |
|
||||
ssl_h.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
|
||||
)
|
||||
rv.fd = fd
|
||||
rv.protocol = protocol
|
||||
|
@ -409,10 +420,9 @@ cdef class TLSTransport:
|
|||
self.proxy.send_msg.msg_iovlen = 1
|
||||
self.proxy.send_callback.data = <void*>&self.proxy
|
||||
self.proxy.send_callback.callback = tls_send_cb
|
||||
self.proxy.recv_msg.msg_control = PyMem_RawMalloc(CMSG_SIZE)
|
||||
if self.proxy.recv_msg.msg_control == NULL:
|
||||
self.proxy.cmsg = PyMem_RawMalloc(CMSG_SIZE)
|
||||
if self.proxy.cmsg == NULL:
|
||||
raise MemoryError
|
||||
self.proxy.recv_msg.msg_controllen = CMSG_SIZE
|
||||
self.proxy.recv_msg.msg_iov = &self.proxy.recv_vec
|
||||
self.proxy.recv_msg.msg_iovlen = 1
|
||||
self.proxy.recv_callback.data = <void*>&self.proxy
|
||||
|
@ -423,7 +433,7 @@ cdef class TLSTransport:
|
|||
self.sslobj = None
|
||||
bio.free(self.bio)
|
||||
PyMem_RawFree(self.proxy.read_buffer)
|
||||
PyMem_RawFree(self.proxy.recv_msg.msg_control)
|
||||
PyMem_RawFree(self.proxy.cmsg)
|
||||
|
||||
cdef do_handshake(self):
|
||||
if self.state == UNWRAPPED:
|
||||
|
@ -431,20 +441,23 @@ cdef class TLSTransport:
|
|||
elif self.state != HANDSHAKING:
|
||||
raise RuntimeError("Cannot do handshake now")
|
||||
|
||||
|
||||
try:
|
||||
IF DEBUG:
|
||||
print("do_handshake()")
|
||||
self.sslobj.do_handshake()
|
||||
except ssl.SSLWantReadError:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_handshake() SSLWantReadError")
|
||||
except ssl.SSLWantWriteError:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_handshake() SSLWantWriteError")
|
||||
except Exception as ex:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print('do_handshake() error:', ex)
|
||||
raise
|
||||
else:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print('do_handshake() done')
|
||||
|
||||
self.state = WRAPPED
|
||||
|
@ -473,14 +486,14 @@ cdef class TLSTransport:
|
|||
self.proxy.flags &= ~FLAGS_PROXY_RECV_ALL
|
||||
res = self.proxy.recv_callback.res
|
||||
if res < 0:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read_ktls() error:", -res)
|
||||
self.loop.call_soon(
|
||||
self.protocol.connection_lost,
|
||||
IOError(-res, string.strerror(-res))
|
||||
)
|
||||
elif res == 0:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read_ktls() EOF")
|
||||
self.loop.call_soon(self.protocol.eof_received)
|
||||
self.loop.call_soon(self.protocol.connection_lost, None)
|
||||
|
@ -490,10 +503,10 @@ cdef class TLSTransport:
|
|||
if cmsg.cmsg_type == linux.TLS_GET_RECORD_TYPE:
|
||||
record_type = (<unsigned char*>libc.CMSG_DATA(cmsg))[0]
|
||||
if record_type != ssl_h.SSL3_RT_APPLICATION_DATA:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read_ktls() forward CMSG")
|
||||
return self.do_read()
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read_ktls() received", res, "bytes")
|
||||
self.loop.call_soon(
|
||||
self.protocol.data_received,
|
||||
|
@ -502,10 +515,9 @@ cdef class TLSTransport:
|
|||
self.loop.call_soon(self.do_read_ktls, self)
|
||||
|
||||
elif not self.proxy.flags & FLAGS_PROXY_RECV_SUBMITTED:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read_ktls() submit")
|
||||
self.proxy.recv_msg.msg_controllen = CMSG_SIZE
|
||||
reset_msg(&self.proxy.recv_msg, CMSG_SIZE)
|
||||
reset_msg(&self.proxy.recv_msg, self.proxy.cmsg)
|
||||
if not ring_sq_submit_recvmsg(
|
||||
&self.proxy.loop.ring.sq,
|
||||
self.fd,
|
||||
|
@ -519,18 +531,18 @@ cdef class TLSTransport:
|
|||
try:
|
||||
data = self.sslobj.read(65536)
|
||||
except ssl.SSLWantReadError:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read() SSLWantReadError")
|
||||
except ssl.SSLWantWriteError:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read() SSLWantWriteError")
|
||||
except Exception as ex:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read() error:", ex)
|
||||
self.loop.call_soon(self.protocol.connection_lost, ex)
|
||||
else:
|
||||
if data:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read() received", len(data), bytes)
|
||||
print(data)
|
||||
self.loop.call_soon(self.protocol.data_received, data)
|
||||
|
@ -539,7 +551,7 @@ cdef class TLSTransport:
|
|||
else:
|
||||
self.loop.call_soon(self.do_read, self)
|
||||
else:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("do_read() EOF")
|
||||
self.loop.call_soon(self.protocol.eof_received)
|
||||
self.loop.call_soon(self.protocol.connection_lost, None)
|
||||
|
@ -564,7 +576,7 @@ cdef class TLSTransport:
|
|||
try:
|
||||
self.sslobj.write(data)
|
||||
except ssl.SSLWantWriteError:
|
||||
if DEBUG:
|
||||
IF DEBUG:
|
||||
print("write() SSLWantWriteError")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue