# Source code for distributed.comm.core

from __future__ import print_function, division, absolute_import

from abc import ABCMeta, abstractmethod, abstractproperty
from datetime import timedelta
import logging

from six import string_types, with_metaclass

from tornado import gen
from tornado.ioloop import IOLoop

from ..config import config
from ..metrics import time


logger = logging.getLogger(__name__)


# Registry of connector objects keyed by URI scheme (e.g. 'tcp', 'zmq').
# connect() looks the scheme of an address up here and calls the entry's
# ``connect(loc, deserialize=...)`` method.  It starts out empty; presumably
# each transport implementation registers itself here -- TODO confirm.
connectors = dict()


# Registry of listener *classes* keyed by URI scheme (e.g. 'tcp', 'zmq').
# listen() looks the scheme up here and instantiates
# ``listener_class(loc, handle_comm, deserialize)``.  Also starts out empty.
listeners = dict()


# Scheme assumed for addresses written without an explicit "scheme://"
# prefix: the 'default-scheme' config value, falling back to 'tcp'.
DEFAULT_SCHEME = config.get('default-scheme', 'tcp')


class CommClosedError(IOError):
    """
    Error type signalling that a communication channel is closed.

    It is not raised anywhere in this module; presumably concrete ``Comm``
    implementations raise it when used after closing -- TODO confirm.
    """


class Comm(with_metaclass(ABCMeta)):
    """
    Abstract base for an established, message-oriented communication
    channel.

    At most one reader and one writer may use a comm at any moment.  To
    hold several concurrent conversations -- even with the same peer --
    create separate ``Comm`` objects.

    Each message is an arbitrary Python object.  The serialization scheme
    is left to concrete subclasses, which can pick whatever suits the
    characteristics of their underlying transport.
    """

    # XXX add set_close_callback()?

    @abstractmethod
    def read(self, deserialize=None):
        """
        Receive one message and return it as a Python object.

        A non-None *deserialize* takes precedence over this comm's default
        deserialization setting.

        Coroutine.
        """

    @abstractmethod
    def write(self, msg):
        """
        Send *msg*, an arbitrary Python object.

        Coroutine.
        """

    @abstractmethod
    def close(self):
        """
        Shut the comm down gracefully: outgoing buffers are flushed, where
        possible, before the underlying transport is closed.

        Coroutine.
        """

    @abstractmethod
    def abort(self):
        """
        Shut the comm down immediately, without any attempt to flush.

        Handy in destructors or in a generator's ``finally`` block.
        """

    @abstractmethod
    def closed(self):
        """
        Return whether the stream is closed.
        """

    @abstractproperty
    def peer_address(self):
        """
        Address of the remote peer.  Meant only for logging and debugging.
        """
class Listener(with_metaclass(ABCMeta)):
    """
    Abstract base for an object that accepts incoming communications.

    In this module, listen() instantiates subclasses as
    ``listener_class(loc, handle_comm, deserialize)``.
    """

    @abstractmethod
    def start(self):
        """
        Start listening for incoming connections.
        """

    @abstractmethod
    def stop(self):
        """
        Stop listening.  This does not shutdown already established
        communications, but prevents accepting new ones.

        The base body stops a ``tcp_server`` attribute if the subclass has
        set one, and does nothing otherwise.  Subclasses may therefore
        call it through ``super()`` whether or not they are TCP-based.
        """
        # This is an abstract base: don't assume the TCP-specific attribute
        # exists, otherwise non-TCP subclasses calling super().stop() would
        # crash with AttributeError.
        tcp_server = getattr(self, 'tcp_server', None)
        if tcp_server is not None:
            # Clear the attribute before stopping, so a repeated stop()
            # call finds None and is a no-op.
            self.tcp_server = None
            tcp_server.stop()

    @abstractproperty
    def listen_address(self):
        """
        The listening address as a URI string.
        """

    @abstractproperty
    def contact_address(self):
        """
        An address this listener can be contacted on.  This can be
        different from `listen_address` if the latter is some wildcard
        address such as 'tcp://0.0.0.0:123'.
        """
def parse_address(addr):
    """
    Split address into its scheme and scheme-dependent location string.

    If *addr* has no "scheme://" separator, DEFAULT_SCHEME is used as the
    scheme and the whole string becomes the location.

    Raises TypeError if *addr* is not a string.
    """
    if not isinstance(addr, string_types):
        raise TypeError("expected str, got %r" % addr.__class__.__name__)
    scheme, sep, loc = addr.rpartition('://')
    if not sep:
        scheme = DEFAULT_SCHEME
    return scheme, loc


def unparse_address(scheme, loc):
    """
    Undo parse_address(): join *scheme* and *loc* as "scheme://loc".
    """
    return '%s://%s' % (scheme, loc)


def parse_host_port(address, default_port=None):
    """
    Parse an endpoint address given in the form "host:port".

    IPv6 hosts must be bracketed, e.g. "[::1]:8786".  A leading
    "scheme://" prefix (or the legacy bare "tcp:" prefix) is stripped
    first.  A tuple argument is returned unchanged.

    Returns ``(host, int(port))``.  When the port is missing,
    *default_port* is used.  ValueError is raised if it is None or if the
    address is malformed.
    """
    if isinstance(address, tuple):
        return address

    _, sep, rest = address.rpartition('://')
    if sep:
        # Accept full URIs such as "tcp://host:port"; stripping only "tcp:"
        # would leave "//host" as the host name.
        address = rest
    elif address.startswith('tcp:'):
        address = address[4:]

    def _fail():
        raise ValueError("invalid address %r" % (address,))

    def _default():
        if default_port is None:
            raise ValueError("missing port number in address %r" % (address,))
        return default_port

    if address.startswith('['):
        # Bracketed (IPv6) host: "[host]" or "[host]:port".
        host, sep, tail = address[1:].partition(']')
        if not sep:
            _fail()
        if not tail:
            port = _default()
        else:
            if not tail.startswith(':'):
                _fail()
            port = tail[1:]
    else:
        host, sep, port = address.partition(':')
        if not sep:
            port = _default()
        elif ':' in host:
            _fail()

    return host, int(port)


def unparse_host_port(host, port=None):
    """
    Undo parse_host_port(): format *host* and optional *port* as
    "host:port", bracketing unbracketed IPv6 hosts.

    Port 0 is kept (it is a valid port number); only ``None`` means "no
    port".
    """
    if ':' in host and not host.startswith('['):
        host = '[%s]' % host
    if port is not None:
        return '%s:%s' % (host, port)
    else:
        return host


def get_address_host_port(addr):
    """
    Get a (host, port) tuple out of the given address.

    Raises ValueError for schemes other than 'tcp' and 'zmq'.
    """
    scheme, loc = parse_address(addr)
    if scheme not in ('tcp', 'zmq'):
        raise ValueError("don't know how to extract host and port "
                         "for address %r" % (addr,))
    return parse_host_port(loc)


def normalize_address(addr):
    """
    Canonicalize address, adding a default scheme if necessary.
    """
    return unparse_address(*parse_address(addr))


def resolve_address(addr):
    """
    Apply scheme-specific address resolution to *addr*, ensuring all
    symbolic references are replaced with concrete location specifiers.

    In practice, this means hostnames are resolved to IP addresses.
    Raises ValueError for schemes other than 'tcp' and 'zmq'.
    """
    # XXX circular import; reorganize APIs into a distributed.comms.addressing module?
    from ..utils import ensure_ip
    scheme, loc = parse_address(addr)
    if scheme not in ('tcp', 'zmq'):
        raise ValueError("don't know how to extract host and port "
                         "for address %r" % (addr,))
    host, port = parse_host_port(loc)
    loc = unparse_host_port(ensure_ip(host), port)
    addr = unparse_address(scheme, loc)
    return addr


@gen.coroutine
def connect(addr, timeout=3, deserialize=True):
    """
    Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``)
    and yield a ``Comm`` object.

    If the connection attempt fails with an EnvironmentError, it is retried
    every 10 ms until *timeout* seconds have elapsed.  After that, IOError
    is raised with the last error message.  Raises ValueError if no
    connector is registered for the address's scheme.

    This function is a coroutine.
    """
    scheme, loc = parse_address(addr)
    connector = connectors.get(scheme)
    if connector is None:
        raise ValueError("unknown scheme %r in address %r" % (scheme, addr))

    deadline = time() + timeout
    error = None

    def _raise(error):
        error = error or "connect() didn't finish in time"
        msg = ("Timed out trying to connect to %r after %s s: %s"
               % (addr, timeout, error))
        raise IOError(msg)

    while True:
        try:
            future = connector.connect(loc, deserialize=deserialize)
            # Bound each attempt by the time remaining until the deadline.
            comm = yield gen.with_timeout(timedelta(seconds=deadline - time()),
                                          future,
                                          quiet_exceptions=EnvironmentError)
        except EnvironmentError as e:
            error = str(e)
            if time() < deadline:
                logger.debug("sleeping on connect")
                yield gen.sleep(0.01)
            else:
                _raise(error)
        except gen.TimeoutError:
            _raise(error)
        else:
            break

    raise gen.Return(comm)
def listen(addr, handle_comm, deserialize=True):
    """
    Build a listener for *addr* without starting it.

    Once the returned object's ``start()`` is called, it listens on *addr*
    (a URI such as ``tcp://0.0.0.0``) and invokes *handle_comm* with a
    ``Comm`` object for every incoming connection.  *handle_comm* may be
    either a plain function or a coroutine.

    Raises ValueError if no listener class is registered for the scheme.
    """
    scheme, location = parse_address(addr)
    cls = listeners.get(scheme)
    if cls is not None:
        return cls(location, handle_comm, deserialize)
    raise ValueError("unknown scheme %r in address %r" % (scheme, addr))