diff --git a/src/kloop/loop.pyx b/src/kloop/loop.pyx index 40cce19..3dc7c92 100644 --- a/src/kloop/loop.pyx +++ b/src/kloop/loop.pyx @@ -516,9 +516,20 @@ cdef class KLoopImpl: cdef: int fd + if ssl is False: + ssl = None + elif ssl is not None: + from . import tls + if ssl is True: + import ssl as ssl_mod + ssl = ssl_mod.create_default_context() fd = await tcp_connect(self, host, port) protocol = protocol_factory() - return TCPTransport.new(fd, protocol, self), protocol + if ssl is not None: + transport = tls.TLSTransport.new(fd, protocol, self, ssl) + else: + transport = TCPTransport.new(fd, protocol, self) + return transport, protocol class KLoop(KLoopImpl, asyncio.AbstractEventLoop): diff --git a/src/kloop/tls.pxd b/src/kloop/tls.pxd index 4f3cefa..daa796f 100644 --- a/src/kloop/tls.pxd +++ b/src/kloop/tls.pxd @@ -9,5 +9,15 @@ # See the Mulan PSL v2 for more details. -cdef struct BIO: - int data +from .includes.openssl cimport bio +from .loop cimport KLoopImpl + + +cdef class TLSTransport: + cdef: + KLoopImpl loop + int fd + bio.BIO* bio + object protocol + object sslctx + object sslobj diff --git a/src/kloop/tls.pyx b/src/kloop/tls.pyx index 7c2d8c6..0db3184 100644 --- a/src/kloop/tls.pyx +++ b/src/kloop/tls.pyx @@ -63,52 +63,61 @@ cdef long bio_ctrl(bio.BIO* b, int cmd, long num, void* ptr) nogil: cdef int bio_create(bio.BIO* b) nogil: - cdef BIO* obj = PyMem_RawMalloc(sizeof(BIO)) - if obj == NULL: - return 0 - string.memset(obj, 0, sizeof(BIO)) - bio.set_data(b, obj) bio.set_init(b, 1) return 1 cdef int bio_destroy(bio.BIO* b) nogil: - cdef void* obj = bio.get_data(b) - if obj != NULL: - PyMem_RawFree(obj) bio.set_shutdown(b, 1) return 1 -cdef object wrap_bio( - bio.BIO* b, - object ssl_context, - bint server_side=False, - object server_hostname=None, - object session=None, -): - cdef pyssl.PySSLMemoryBIO* c_bio - py_bio = ssl.MemoryBIO() - c_bio = py_bio - c_bio.bio, b = b, c_bio.bio - rv = ssl_context.wrap_bio( - py_bio, py_bio, server_side, server_hostname, session - ) - c_bio.bio, b = b, c_bio.bio - ssl_h.set_options( - (rv._sslobj).ssl, ssl_h.OP_ENABLE_KTLS - ) - return rv +cdef class TLSTransport: + @staticmethod + def new( + int fd, + protocol, + KLoopImpl loop, + sslctx, + server_side=False, + server_hostname=None, + session=None, + ): + cdef: + TLSTransport rv = TLSTransport.__new__(TLSTransport) + pyssl.PySSLMemoryBIO* c_bio + py_bio = ssl.MemoryBIO() + c_bio = py_bio + c_bio.bio, rv.bio = rv.bio, c_bio.bio + try: + rv.sslobj = sslctx.wrap_bio( + py_bio, py_bio, server_side, server_hostname, session + ) + finally: + c_bio.bio, rv.bio = rv.bio, c_bio.bio + del py_bio -def test(): - cdef BIO* b - with nogil: - b = bio.new(KTLS_BIO_METHOD) - if b == NULL: - raise fromOpenSSLError(RuntimeError) - ctx = ssl.create_default_context() - return wrap_bio(b, ctx) + ssl_h.set_options( + (rv.sslobj).ssl, ssl_h.OP_ENABLE_KTLS + ) + rv.fd = fd + rv.protocol = protocol + rv.loop = loop + rv.sslctx = sslctx + + try: + rv.sslobj.do_handshake() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + return rv + + def __cinit__(self): + self.bio = bio.new(KTLS_BIO_METHOD) + bio.set_data(self.bio, self) + + def __dealloc__(self): + bio.free(self.bio) cdef bio.Method* KTLS_BIO_METHOD = bio.meth_new( diff --git a/tests/test_loop.py b/tests/test_loop.py index 96a4c45..e3c07e6 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -38,9 +38,9 @@ class TestLoop(unittest.TestCase): def test_connect(self): ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + # ctx.check_hostname = False + # ctx.verify_mode = ssl.CERT_NONE + # ctx.minimum_version = ssl.TLSVersion.TLSv1_3 host = "www.google.com" r, w = self.loop.run_until_complete( # asyncio.open_connection("127.0.0.1", 8080, ssl=ctx)