diff --git a/src/kloop/loop.pxd b/src/kloop/loop.pxd index 3f40754..6e0e14d 100644 --- a/src/kloop/loop.pxd +++ b/src/kloop/loop.pxd @@ -43,6 +43,7 @@ cdef class KLoopImpl: bint closed object thread_id Loop loop + Resolver resolver cpdef create_future(self) cdef inline check_closed(self) diff --git a/src/kloop/loop.pyx b/src/kloop/loop.pyx index 2706b6a..40cce19 100644 --- a/src/kloop/loop.pyx +++ b/src/kloop/loop.pyx @@ -514,24 +514,11 @@ cdef class KLoopImpl: interleave=None, ): cdef: - TCPTransport transport - Resolve resolve - object waiter - size_t i + int fd - resolve = await self.resolver.lookup_ip(host, port) - if not resolve.r.result_len: - raise RuntimeError(f"Cannot resolve host: {host!r}") - - transport = TCPTransport.new(protocol_factory, self) - exceptions = [] - for i in range(resolve.r.result_len): - try: - waiter = transport.connect(resolve.r.result + i) - return transport, await waiter - except OSError as e: - exceptions.append(e) - raise exceptions[0] + fd = await tcp_connect(self, host, port) + protocol = protocol_factory() + return TCPTransport.new(fd, protocol, self), protocol class KLoop(KLoopImpl, asyncio.AbstractEventLoop): diff --git a/src/kloop/resolver.pyx b/src/kloop/resolver.pyx index 7b43260..78a5c44 100644 --- a/src/kloop/resolver.pyx +++ b/src/kloop/resolver.pyx @@ -116,6 +116,9 @@ cdef class Resolver: cdef init_cb(self): cdef int res = self.resolver.res + if self.waiter.done(): + return + if res < 0: try: errno.errno = -res @@ -128,7 +131,8 @@ cdef class Resolver: cdef err_cb(self, exc): waiter, self.waiter = self.waiter, None if waiter is not None: - waiter.set_exception(exc) + if not waiter.done(): + waiter.set_exception(exc) async def lookup_ip(self, host, port): await self.ensure_initialized() diff --git a/src/kloop/tcp.pxd b/src/kloop/tcp.pxd index 201cf79..bc4c217 100644 --- a/src/kloop/tcp.pxd +++ b/src/kloop/tcp.pxd @@ -10,8 +10,6 @@ cdef struct TCPConnect: - int fd - libc.sockaddr* addr RingCallback ring_cb Loop* loop Callback* cb @@ -21,13 +19,7 @@ cdef class TCPTransport: cdef: KLoopImpl loop int fd - TCPConnect connector - object waiter - object protocol_factory - Handle handle + object protocol @staticmethod - cdef TCPTransport new(object protocol_factory, KLoopImpl loop) - - cdef connect(self, libc.sockaddr* addr) - cdef connect_cb(self) + cdef TCPTransport new(int fd, object protocol, KLoopImpl loop) diff --git a/src/kloop/tcp.pyx b/src/kloop/tcp.pyx index 6e6fd2e..d2bc05c 100644 --- a/src/kloop/tcp.pyx +++ b/src/kloop/tcp.pyx @@ -9,13 +9,54 @@ # See the Mulan PSL v2 for more details. -cdef int tcp_connect(TCPConnect* connector) nogil: - return ring_sq_submit_connect( - &connector.loop.ring.sq, - connector.fd, - connector.addr, - &connector.ring_cb, - ) +async def tcp_connect(KLoopImpl loop, host, port): + cdef: + Resolve resolve + TCPConnect connector + int fd, res + libc.sockaddr * addr + Handle handle + size_t i + + resolve = await loop.resolver.lookup_ip(host, port) + if not resolve.r.result_len: + raise RuntimeError(f"Cannot resolve host: {host!r}") + + connector.loop = &loop.loop + connector.ring_cb.callback = tcp_connect_cb + connector.ring_cb.data = &connector + + exceptions = [] + for i in range(resolve.r.result_len): + addr = resolve.r.result + i + fd = libc.socket(addr.sa_family, libc.SOCK_STREAM, 0) + if fd == -1: + raise IOError("Cannot create socket") + + try: + waiter = loop.create_future() + handle = Handle(waiter.set_result, (None,), loop, None) + connector.cb = &handle.cb + + if not ring_sq_submit_connect( + &loop.loop.ring.sq, + fd, + addr, + &connector.ring_cb, + ): + raise ValueError("Submission queue is full!") + + await waiter + + res = abs(connector.ring_cb.res) + if res != 0: + raise IOError(res, string.strerror(res)) + return fd + + except Exception as e: + os.close(fd) + exceptions.append(e) + raise exceptions[0] cdef int tcp_connect_cb(RingCallback* cb) nogil except 0: @@ -25,50 +66,13 @@ cdef int tcp_connect_cb(RingCallback* cb) nogil except 0: cdef class TCPTransport: @staticmethod - cdef TCPTransport new(object protocol_factory, KLoopImpl loop): + cdef TCPTransport new(int fd, object protocol, KLoopImpl loop): cdef TCPTransport rv = TCPTransport.__new__(TCPTransport) - rv.protocol_factory = protocol_factory + rv.fd = fd + rv.protocol = protocol rv.loop = loop - rv.connector.loop = &loop.loop - rv.connector.ring_cb.callback = tcp_connect_cb - rv.connector.ring_cb.data = &rv.connector + loop.call_soon(protocol.connection_made, rv) return rv - cdef connect(self, libc.sockaddr* addr): - cdef: - int fd - TCPConnect* c = &self.connector - - fd = libc.socket(addr.sa_family, libc.SOCK_STREAM, 0) - if fd == -1: - PyErr_SetFromErrno(IOError) - return - c.addr = addr - c.fd = self.fd = fd - self.handle = Handle(self.connect_cb, (self,), self.loop, None) - c.cb = &self.handle.cb - if not tcp_connect(c): - raise ValueError("Submission queue is full!") - self.waiter = self.loop.create_future() - return self.waiter - - cdef connect_cb(self): - if self.connector.ring_cb.res != 0: - if not ring_sq_submit_close( - &self.loop.loop.ring.sq, self.fd, NULL - ): - # TODO: fd not closed? - pass - try: - errno.errno = abs(self.connector.ring_cb.res) - PyErr_SetFromErrno(IOError) - except IOError as e: - self.waiter.set_exception(e) - return - - protocol = self.protocol_factory() - self.waiter.set_result(protocol) - self.loop.call_soon(protocol.connection_made, self) - def get_extra_info(self, x): return None