"""Remote (procedure) calls.
Much like :mod:`asyncio`'s transports and protocols, this module is divided into
low-level and high-level APIs:

* The low-level API, :class:`Node` and its implementations, deals with transporting
  discrete binary messages and managing the underlying transport.
* The high-level API, :class:`Endpoint` and its implementations, implements
  request/response semantics. Most consumers should use the high-level API.

The remote call message format is based on `MessagePack-RPC`_, except that this module
uses :mod:`cbor2` for serialization.

Every application (process) typically creates a single :class:`Handler` bound to one or
more :class:`Service` instances. The handler encapsulates the application's business
logic and state, while each service exposes the handler's methods to a different
transport.
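
For example, a minimal end-to-end sketch (the UDP address, handler, and method names
below are illustrative only, not fixed by this module)::

    class EchoHandler(Handler):
        @route
        async def echo(self, message: str) -> str:
            return message

    async def serve() -> None:
        node = DatagramNode.from_address('udp://localhost:6000', bind=True)
        async with Service(node, handler=EchoHandler()):
            await asyncio.Future()  # Serve until cancelled.

    async def call() -> None:
        node = DatagramNode.from_address('udp://localhost:6000', bind=False)
        async with Client(node) as client:
            print(await client.call.echo('hello'))  # Prints 'hello'.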
.. _MessagePack-RPC:
https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
"""
import abc
import asyncio
import contextlib
import enum
import functools
import inspect
import random
import socket
import types
import typing
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Collection,
Iterator,
MutableMapping,
)
from dataclasses import dataclass, field
from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, Union
from urllib.parse import urlsplit
import cbor2
import structlog
import zmq
import zmq.asyncio
import zmq.error
from .exception import RuntimeBaseException
__all__ = [
'Client',
'DatagramNode',
'Endpoint',
'Handler',
'MessageType',
'Node',
'RemoteCallError',
'RequestTracker',
'Router',
'Service',
'SocketNode',
'route',
]
class RemoteCallError(RuntimeBaseException):
"""Error produced by executing a remote call.
Parameters:
message: A human-readable description of the exception.
context: Machine-readable data.
"""
class MessageType(enum.IntEnum):
"""The message type ID.
Attributes:
REQUEST: Denotes a request message sent by clients. Requires a response.
RESPONSE: Denotes a response message sent by services.
NOTIFICATION: Denotes a notification message sent by clients. Does not require a
response. Unlike the synchronous request-response pattern, notifications may
be pipelined (*i.e.*, multiple notifications in-flight simultaneously) for
increased throughput.
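
    Schematically, the wire messages built by :class:`Client` and :class:`Service` are
    CBOR-encoded lists whose first element is the message type::

        [MessageType.REQUEST, message_id, method, args]
        [MessageType.RESPONSE, message_id, error, result]
        [MessageType.NOTIFICATION, method, args]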
"""
REQUEST = 0
RESPONSE = 1
NOTIFICATION = 2
Segments = tuple[list[bytes], Any]
NodeType = TypeVar('NodeType', bound='Node')
EndpointType = TypeVar('EndpointType', bound='Endpoint')
SocketOptions = dict[int, Union[int, bytes]]
def get_logger(*factory_args: Any, **context: Any) -> structlog.stdlib.AsyncBoundLogger:
"""Get an unbound async-compatible logger."""
logger = structlog.get_logger(
*factory_args,
**context,
wrapper_class=structlog.stdlib.AsyncBoundLogger,
)
return typing.cast(structlog.stdlib.AsyncBoundLogger, logger)
@dataclass # type: ignore[misc]
class Node(abc.ABC): # https://github.com/python/mypy/issues/5374
"""A transceiver of discrete binary messages.
A node wraps an underlying transport, such as a UDP endpoint, that it can repeatedly
open, close, and reopen. :class:`Node` supports the async context manager protocol
(reusable) for automatically managing the transport.
    When the transport is open, a node can send messages to and receive messages from
    one or more peers concurrently.
State Diagram::
start [-> closed]? -> open
[[-> close -> open]? [-> send]? [-> recv]? [-> closed?]]*
-> close -> end
The data segments and address a :class:`Node` sends and receives are opaque to the
node. Their format and semantics depend on the transport and :class:`Endpoint` the
node works with.
Attributes:
send_count: The number of messages sent since the transport was opened.
recv_count: The number of messages received since the transport was opened.
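
    For example, a minimal sketch of the send/receive loop over a concrete node (the
    payload and peer address below are illustrative only)::

        async with node:
            await node.send([b'ping'], address=peer)
            parts, sender = await node.recv()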
"""
recv_queue: asyncio.Queue[Segments] = field(
default_factory=lambda: asyncio.Queue(128),
init=False,
repr=False,
)
send_count: int = field(default=0, init=False, repr=False)
recv_count: int = field(default=0, init=False, repr=False)
async def __aenter__(self: NodeType, /) -> NodeType:
if self.closed:
await self.open()
self.send_count = self.recv_count = 0
return self
async def __aexit__(
self,
_exc_type: Optional[type[BaseException]],
_exc: Optional[BaseException],
_traceback: Optional[types.TracebackType],
/,
) -> None:
if not self.closed:
self.close()
    @abc.abstractmethod
async def send(
self,
parts: list[bytes],
/,
*,
address: Optional[Any] = None,
) -> None:
"""Send a message.
Parameters:
parts: Zero or more data segments.
address: The destination's address.
Raises:
RemoteCallError: If the transport cannot send the message. May reopen the
internal transport.
"""
    async def recv(self, /) -> Segments:
"""Receive a message.
Returns:
Zero or more data segments and an address, which are transport-dependent.
Raises:
RemoteCallError: If the transport cannot receive a message.
Note:
Asking the transport directly for messages may be problematic if it is
reopened or does not support concurrent waiters.
"""
if not self.can_recv:
raise RemoteCallError('transport does not support recv')
item = await self.recv_queue.get()
self.recv_count += 1
return item
    @abc.abstractmethod
async def open(self, /) -> None:
"""Open the internal transport."""
    @abc.abstractmethod
def close(self, /) -> None:
"""Close the internal transport."""
@property
@abc.abstractmethod
def closed(self, /) -> bool:
"""Whether the internal transport is closed."""
@property
@abc.abstractmethod
def can_recv(self, /) -> bool:
"""Whether the transport can receive messages."""
@contextlib.asynccontextmanager
async def _maybe_reopen(
self,
/,
*exc_types: type[Exception],
) -> AsyncIterator[None]:
"""An async context manager for reopening the transport when an error occurs.
Parameters:
exc_types: Exception types to catch. If none are given, defaults to
:class:`Exception`.
Raises:
RemoteCallError: If the transport is reopened.
"""
if self.closed:
raise RemoteCallError('transport is closed')
exc_types = exc_types or (Exception,)
try:
yield
except exc_types as exc:
self.close()
await self.open()
raise RemoteCallError('node transport reopened') from exc
SocketOptionType = tuple[int, int, Union[int, bytes]]
@dataclass
class DatagramNode(Node, asyncio.DatagramProtocol):
"""A wrapper around :mod:`asyncio`'s datagram support.
Parameters:
host: Hostname.
port: Port number.
bind: Whether to bind the socket to a local address or connect to a remote one.
options: Socket options in the form `(level, option, value)` passed to
:meth:`socket.socket.setsockopt`.
"""
host: str = ''
port: int = 8000
bind: bool = True
options: Collection[SocketOptionType] = frozenset()
transport: Optional[asyncio.DatagramTransport] = field(
default=None,
init=False,
repr=False,
)
def datagram_received(self, data: bytes, addr: Any, /) -> None:
with contextlib.suppress(asyncio.QueueFull):
self.recv_queue.put_nowait(([data], addr))
def connection_lost(self, exc: Optional[Exception], /) -> None:
self.close()
async def send(
self,
parts: list[bytes],
/,
*,
address: Optional[tuple[str, int]] = None,
) -> None:
if not self.transport:
raise RemoteCallError('transport is not yet open')
async with self._maybe_reopen():
for part in parts:
self.transport.sendto(part, addr=address)
self.send_count += 1
async def open(self, /) -> None:
loop = asyncio.get_running_loop()
kwargs: dict[str, Any] = {
('local_addr' if self.bind else 'remote_addr'): (self.host, self.port),
'reuse_port': True,
'family': socket.AF_INET,
}
transport, _ = await loop.create_datagram_endpoint(lambda: self, **kwargs)
self.transport = typing.cast(asyncio.DatagramTransport, transport)
sock = self.transport.get_extra_info('socket')
for level, option, value in self.options:
sock.setsockopt(level, option, value)
def close(self, /) -> None:
if self.transport:
self.transport.close()
@property
def closed(self, /) -> bool:
return self.transport.is_closing() if self.transport else True
@property
def can_recv(self, /) -> bool:
return True
    @classmethod
def from_address(
cls,
/,
address: str,
*,
bind: bool = True,
options: Collection[SocketOptionType] = frozenset(),
) -> 'DatagramNode':
"""Build a datagram node from an address.
Parameters:
            address: The address to parse, in the form ``udp://hostname:port``.
bind: True if this is a local address (socket is bound). False for a remote
address (socket connects).
options: Socket options passed to :class:`DatagramNode`.
Raises:
ValueError: If the address is not a valid UDP address.
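
        For example (the address below is illustrative only)::

            node = DatagramNode.from_address('udp://localhost:6000', bind=False)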
"""
components = urlsplit(address)
if components.scheme != 'udp' or not components.hostname or not components.port:
raise ValueError('must provide a UDP address')
return DatagramNode(
host=components.hostname,
port=components.port,
bind=bind,
options=options,
)
@dataclass
class SocketNode(Node):
"""A wrapper around :class:`zmq.asyncio.Socket` for handling timeouts.
    When the underlying socket of a :class:`SocketNode` times out, the socket is closed
    and rebuilt to reset the socket's internal state. For example, a ``REQ`` socket may
become stuck in the "listening" state indefinitely if the message it sent gets lost.
Parameters:
socket_type: The socket type (a constant defined under :mod:`zmq`).
subscriptions: A set of topics to subscribe to (for ``SUB`` sockets only).
options: A mapping of `ZMQ socket option symbols
<http://api.zeromq.org/4-3:zmq-setsockopt>`_ to their values.
connections: A set of addresses to connect to.
bindings: A set of addresses to bind to.
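
    For example, a ``DEALER`` node (the identity and address below are illustrative
    only)::

        node = SocketNode(
            socket_type=zmq.DEALER,
            connections='tcp://localhost:6000',
            options={zmq.IDENTITY: b'example-client'},
        )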
"""
socket_type: int = zmq.DEALER
options: SocketOptions = field(default_factory=dict)
bindings: frozenset[str] = frozenset()
connections: frozenset[str] = frozenset()
subscriptions: set[str] = field(default_factory=set)
socket: zmq.asyncio.Socket = field(init=False, repr=False)
recv_task: asyncio.Future[NoReturn] = field(
default_factory=lambda: asyncio.get_running_loop().create_future(),
init=False,
repr=False,
)
def __post_init__(self, /) -> None:
# TODO: remove this type coercion
for attr in ('bindings', 'connections', 'subscriptions'):
value = getattr(self, attr)
if isinstance(value, str):
setattr(self, attr, {value})
self.bindings = frozenset(self.bindings)
self.connections = frozenset(self.connections)
if self.socket_type == zmq.SUB and not self.subscriptions:
self.subscriptions.add('')
if self.socket_type == zmq.DEALER:
self.options.setdefault(zmq.PROBE_ROUTER, 1)
if self.socket_type == zmq.ROUTER:
self.options.setdefault(zmq.ROUTER_HANDOVER, 1)
@property
def identity(self, /) -> bytes:
"""The ZMQ identity of this socket."""
ident = self.options.get(zmq.IDENTITY)
return ident if isinstance(ident, bytes) else b'(anonymous)'
async def send(
self,
parts: list[bytes],
/,
*,
address: Optional[bytes] = None,
) -> None:
if not address:
raise RemoteCallError('must provide an address')
async with self._maybe_reopen(zmq.error.Again):
await self.socket.send_multipart([address, *parts])
self.send_count += 1
async def _recv_forever(self, /) -> NoReturn:
"""Receive messages indefinitely and enqueue them."""
while True:
with contextlib.suppress(RemoteCallError):
async with self._maybe_reopen(zmq.error.Again):
sender_id, *frames = await self.socket.recv_multipart()
if self.socket_type == zmq.SUB:
sender_id = b''
await self.recv_queue.put((list(frames), sender_id))
async def open(self, /) -> None:
ctx = zmq.asyncio.Context.instance()
self.socket = ctx.socket(self.socket_type)
for name, value in self.options.items():
self.socket.set(name, value)
for address in self.bindings:
self.socket.bind(address)
for address in self.connections:
self.socket.connect(address)
if self.socket_type == zmq.SUB:
for topic in self.subscriptions:
self.socket.subscribe(topic)
if self.can_recv:
self.recv_task = asyncio.create_task(self._recv_forever(), name='recv')
def close(self, /) -> None:
self.recv_task.cancel()
self.socket.close()
@property
def closed(self, /) -> bool:
return bool(self.socket.closed) if getattr(self, 'socket', None) else True
@property
def can_recv(self, /) -> bool:
return self.socket_type != zmq.PUB
    def subscribe(self, /, topic: str = '') -> None:
"""Subscribe to a topic (for ``zmq.SUB`` sockets only).
Parameters:
topic: The topic to subscribe to.
"""
if self.socket_type == zmq.SUB:
self.socket.subscribe(topic)
self.subscriptions.add(topic)
    def unsubscribe(self, /, topic: str = '') -> None:
"""Unsubscribe from a topic (for ``zmq.SUB`` sockets only).
Parameters:
topic: The topic to unsubscribe from.
"""
if self.socket_type == zmq.SUB:
self.socket.unsubscribe(topic)
self.subscriptions.discard(topic)
    def set_option(self, option: int, value: Union[int, bytes], /) -> None:
"""Set a socket option.
Parameters:
option: A socket option symbol defined by :mod:`zmq`.
value: The option value (the type/format depends on the option).
"""
self.socket.set(option, value)
self.options[option] = value
def _check_type(node: Node, /, *allowed_types: int) -> None:
"""Raise a :class:`RemoteCallError` if this socket type is not allowed."""
if isinstance(node, SocketNode) and node.socket_type not in allowed_types:
raise RemoteCallError(
'socket type not allowed',
socket_type=node.socket_type,
allowed_types=allowed_types,
)
async def _encode(obj: Any, /) -> bytes:
"""Encode an object as a CBOR-encoded buffer in the default executor.
Raises:
cbor2.CBOREncodeError: If the encoding fails.
"""
return await asyncio.to_thread(cbor2.dumps, obj)
async def _decode(buf: bytes, /) -> Any:
"""Decode a CBOR-encoded buffer in the default executor.
Raises:
cbor2.CBORDecodeError: If the decoding fails.
"""
return await asyncio.to_thread(cbor2.loads, buf)
Method = Callable[..., Any]
class RemoteMethod(Protocol):
"""A remotely callable method (any signature, any return value)."""
__remote__: str
def __call__(self, /, *args: Any, **kwargs: Any) -> Any:
...
@typing.overload
def route(method_or_name: str, /) -> Callable[[Method], RemoteMethod]:
...
@typing.overload
def route(method_or_name: Method, /) -> RemoteMethod:
...
def route(
method_or_name: Union[str, Method],
/,
) -> Union[RemoteMethod, Callable[[Method], RemoteMethod]]:
"""Decorator for marking a bound method as an RPC target.
Parameters:
method_or_name: Either the method to be registered or the name it should be
registered under. If the former, the method name is exposed to the
:class:`Handler`. The latter is useful for exposing a name that is not a
valid Python identifier.
Returns:
        Either a decorator that marks the method with the given name (if a name was
        provided) or the provided method, marked with its own name.
"""
if isinstance(method_or_name, str):
def decorator(method: Callable[..., Any]) -> RemoteMethod:
remote_method = typing.cast(RemoteMethod, method)
remote_method.__remote__ = typing.cast(str, method_or_name)
return remote_method
return decorator
remote_method = typing.cast(RemoteMethod, method_or_name)
remote_method.__remote__ = method_or_name.__name__
return remote_method
@dataclass # type: ignore[misc]
class Endpoint(abc.ABC): # https://github.com/python/mypy/issues/5374
"""A source or destination of messages.
An :class:`Endpoint` has a number of workers (instances of :class:`asyncio.Task`)
that listen for and process incoming messages. This allows for request pipelining.
Once all workers are busy processing messages, the node wrapped by the endpoint
buffers any additional messages.
Parameters:
node: The message transceiver. Not all node/endpoint pairs are compatible.
concurrency: The number of workers.
logger: A logger instance.
"""
node: Node
concurrency: int = 1
logger: structlog.stdlib.AsyncBoundLogger = field(default_factory=get_logger)
stack: contextlib.AsyncExitStack = field(
default_factory=contextlib.AsyncExitStack,
init=False,
repr=False,
)
def __post_init__(self, /) -> None:
if self.concurrency < 0:
            raise ValueError('concurrency must be a nonnegative integer')
async def __aenter__(self: EndpointType, /) -> EndpointType:
await self.stack.__aenter__()
self.node = await self.stack.enter_async_context(self.node)
for _ in range(self.concurrency):
worker = asyncio.create_task(self._process_forever(), name='process-msg')
self.stack.callback(worker.cancel)
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[types.TracebackType],
/,
) -> Optional[bool]:
return await self.stack.__aexit__(exc_type, exc, traceback)
async def _process_forever(self, /, *, cooldown: float = 0.01) -> NoReturn:
"""Receive messages indefinitely and process them."""
logger = self.logger.bind()
while True:
try:
frames, address = await self.node.recv()
payload, *_ = frames
message_type, *message = await _decode(payload)
message_type = MessageType(message_type)
await logger.debug(
'Endpoint received message', message_type=message_type.name
)
await self.handle_message(address, message_type, *message)
except (ValueError, cbor2.CBORDecodeError, RemoteCallError) as exc:
await logger.error('Endpoint failed to process message', exc_info=exc)
await asyncio.sleep(cooldown)
    @abc.abstractmethod
async def handle_message(
self, address: Any, message_type: MessageType, *message: Any
) -> None:
"""Process a message.
Parameters:
address: The address of the message's sender, if available. The semantics
depend on the node. Pass this argument directly to :meth:`Node.send`.
message_type: The message type.
message: Other message parts. The message type determines the format.
Raises:
ValueError: If the endpoint could not unpack part of the message.
RemoteCallError: If the endpoint could not otherwise process the message.
"""
ResponseType = TypeVar('ResponseType')
@dataclass
class RequestTracker(Generic[ResponseType]):
"""Track outstanding requests and their results.
Every request is associated with a unique request ID (an integer, or an object
serializable as an integer).
Parameters:
futures: A mapping from request IDs to futures representing responses.
lower: Minimum valid request ID.
upper: Maximum valid request ID.
"""
futures: MutableMapping[int, asyncio.Future[ResponseType]] = field(
default_factory=dict,
)
lower: int = 0
upper: int = (1 << 32) - 1
def _try_generate_id(self, /) -> int:
"""Attempt to generate a request ID.
Unlike :meth:`generate_uid`, the candidate ID does not need to be unique.
"""
return random.randint(self.lower, self.upper)
    def generate_uid(self, /, *, attempts: int = 10) -> int:
"""Generate a unique request ID.
Parameters:
attempts: The maximum number of times to try to generate an ID.
Raises:
ValueError: If the tracker could not generate a unique ID. If the ID space
is sufficiently large, this error is exceedingly rare. Increasing the
number of attempts or decreasing the number of in-flight requests should
increase the probability of a unique ID.
"""
for _ in range(attempts):
request_id = self._try_generate_id()
if request_id not in self.futures:
return request_id
raise ValueError('unable to generate a request ID')
    @contextlib.contextmanager
def new_request(
self,
/,
request_id: Optional[int] = None,
) -> Iterator[tuple[int, asyncio.Future[ResponseType]]]:
"""Register a new request.
Parameters:
request_id: A unique request identifier. If not provided, a request ID is
randomly generated.
Returns:
The request ID and a future representing the outcome of the request.
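
        For example (``send_request`` is a hypothetical helper, not part of this
        class)::

            with tracker.new_request() as (request_id, response):
                await send_request(request_id)
                result = await asyncio.wait_for(response, 5)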
"""
if request_id is None:
request_id = self.generate_uid()
elif request_id in self.futures:
raise ValueError('request ID already exists')
self.futures[request_id] = asyncio.get_running_loop().create_future()
try:
yield request_id, self.futures[request_id]
finally:
self.futures.pop(request_id, None)
    def register_response(
self,
/,
request_id: int,
result: Union[BaseException, ResponseType],
) -> None:
"""Register a request's response.
Parameters:
request_id: The request identifier returned from :meth:`new_request`.
result: The response or exception.
"""
future = self.futures[request_id]
if isinstance(result, BaseException):
future.set_exception(result)
else:
future.set_result(result)
Call = Callable[..., Awaitable[Any]]
@dataclass
class CallFactory:
"""
    A factory of remote-call wrappers bound to specific method names.
This wrapper uses currying to partially complete the argument list to
:meth:`Client.issue_call`.
"""
issue_call: Call
cached_partial: Callable[[str], Call] = field(init=False, repr=False)
def __post_init__(self) -> None:
make_cached = functools.lru_cache(maxsize=128)
self.cached_partial: Callable[[str], Call] = make_cached(self._partial)
def _partial(self, method: str) -> Call:
return functools.partial(self.issue_call, method)
def __getitem__(self, method: str) -> Call:
return self.cached_partial(method)
def __getattr__(self, method: str) -> Call:
return self.cached_partial(method)
@dataclass
class Client(Endpoint):
"""An endpoint for issuing remote calls.
A request is matched to its response with a message ID, a 32-bit integer unique
among in-flight requests at any given time.
Parameters:
        requests: A tracker mapping in-flight message IDs to futures. Each future
            represents the outcome of a call (a result or an exception).
node: A node for transporting messages.
concurrency: The number of workers for processing responses.
"""
requests: RequestTracker[Any] = field(default_factory=RequestTracker)
def __post_init__(self, /) -> None:
_check_type(self.node, zmq.PUB, zmq.DEALER)
if not self.node.can_recv:
self.concurrency = 0
super().__post_init__()
async def handle_message(
self,
address: Any,
message_type: MessageType,
/,
*message: Any,
) -> None:
if message_type is not MessageType.RESPONSE:
raise RemoteCallError(
'client only receives RESPONSE messages',
message_type=message_type,
message_parts=message,
)
message_id, error, result = message
if isinstance(error, list):
error_message, context = error
result = RemoteCallError(error_message, **context)
try:
self.requests.register_response(message_id, result)
except KeyError as exc:
raise RemoteCallError(
'client received unexpected response',
message_id=message_id,
) from exc
    async def issue_call(
self,
method: str,
/,
*args: Any,
address: Optional[Any] = None,
notification: bool = False,
timeout: float = 5,
) -> Any:
"""Issue a remote procedure call and possibly wait for the result.
Parameters:
method: Method name.
args: Method arguments.
address: A transport-dependent address.
notification: False iff this call requires a response. Has no effect for
nodes that cannot receive data, which can *only* send notifications.
timeout: Maximum duration (in seconds) to wait for a response.
Raises:
asyncio.TimeoutError: The request was successfully sent, but the response
never arrived in time.
ValueError: If the request tracker could not generate a unique message ID.
RemoteCallError: If the service returned an error.
cbor2.CBOREncodeError: If the arguments were not serializable.
Note:
Notification calls will not raise an exception client-side if the server
fails, even if the node supports duplex communication.
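
        For example, a fire-and-forget notification (the method name below is
        illustrative only)::

            await client.issue_call('record_event', 'started', notification=True)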
"""
if not self.node.can_recv:
notification = True
if isinstance(self.node, SocketNode) and self.node.socket_type == zmq.PUB:
address = address or method.encode()
await self.logger.debug(
'Issuing remote procedure call',
method=method,
notification=notification,
)
if notification:
request = [MessageType.NOTIFICATION.value, method, args]
await self.node.send([await _encode(request)], address=address)
else:
with self.requests.new_request() as (message_id, result):
request = [MessageType.REQUEST.value, message_id, method, args]
await self.node.send([await _encode(request)], address=address)
return await asyncio.wait_for(result, timeout)
@functools.cached_property
def call(self, /) -> CallFactory:
"""Syntactic sugar for issuing remote procedure calls.
Instead of::
await client.issue_call('add', 1, 2)
Replace with either of::
await client.call.add(1, 2)
await client.call['add'](1, 2)
"""
return CallFactory(self.issue_call)
class Handler:
"""An object whose bound methods are exposed to remote callers.
Define a handler by subclassing :class:`Handler` and applying the :func:`route`
decorator:
>>> class CustomHandler(Handler):
... @route
... async def method1(self, arg: int) -> int:
... ...
... @route('non-python-identifier')
... def method2(self):
... ...
"""
@functools.cached_property
def _method_table(self) -> dict[str, types.MethodType]:
"""A mapping of method names to (possibly coroutine) bound methods."""
# Need to use the class to avoid calling `getattr(...)` on this property.
# Accessing bound methods directly can lead to infinite recursion.
funcs = inspect.getmembers(self.__class__, inspect.isfunction)
funcs = [(attr, func) for attr, func in funcs if hasattr(func, '__remote__')]
return {func.__remote__: getattr(self, attr) for attr, func in funcs}
    async def dispatch(self, method: str, *args: Any, timeout: float = 30) -> Any:
"""Dispatch a remote procedure call.
If the method is synchronous (possibly blocking), the default executor performs
the call.
Parameters:
method: The procedure name.
args: Positional arguments for the procedure.
Returns:
The procedure's result, which must be CBOR-serializable.
Raises:
            RemoteCallError: If the procedure does not exist, timed out, or raised an
exception.
"""
func = self._method_table.get(method)
if not func:
raise RemoteCallError('no such method exists', method=method)
try:
if inspect.iscoroutinefunction(func):
call = func(*args)
else:
call = asyncio.to_thread(func, *args)
return await asyncio.wait_for(call, timeout)
except asyncio.TimeoutError as exc:
raise RemoteCallError('method timed out', timeout=timeout) from exc
except Exception as exc:
raise RemoteCallError('method produced an error') from exc
@dataclass
class Service(Endpoint):
"""Responds to RPC requests.
Parameters:
handler: The object whose bound methods this service will call.
timeout: Maximum duration (in seconds) to execute methods for.
"""
handler: Handler = field(default_factory=Handler)
timeout: float = 30
def __post_init__(self) -> None:
        _check_type(self.node, zmq.SUB, zmq.DEALER)
        super().__post_init__()
async def handle_message(
self, address: Any, message_type: MessageType, *message: Any
) -> None:
if message_type is MessageType.REQUEST:
message_id, method, args = message
elif message_type is MessageType.NOTIFICATION:
message_id, (method, args) = None, message
else:
await self.logger.warn(
'Service does not support message',
message_type=message_type.name,
)
return
try:
result = await self.handler.dispatch(method, *args, timeout=self.timeout)
error = None
except RemoteCallError as exc:
result, error = None, [str(exc), exc.context]
await self.logger.error(
'Service was unable to execute call',
message_type=message_type.name,
message_id=message_id,
exc_info=exc,
)
if message_type is MessageType.REQUEST:
response = [MessageType.RESPONSE.value, message_id, error, result]
await self.node.send([await _encode(response)], address=address)
def _render_id(identity: bytes) -> str:
with contextlib.suppress(UnicodeDecodeError):
decoded = identity.decode()
if decoded.isprintable():
return decoded
return identity.hex()
async def _route(recv_socket: SocketNode, send_socket: SocketNode) -> NoReturn:
"""Route messages in one direction.
A :class:`Router` is duplex, but the frame format and implementation for each
direction are identical.
Parameters:
        recv_socket: The receiving socket, which indefinitely reads three-frame
            messages consisting of the sender's ZMQ identity, the recipient's
            identity, and the payload.
send_socket: The sending socket, which simply transposes the
sender/recipient ID frames.
"""
logger = get_logger().bind(
recv_socket=_render_id(recv_socket.identity),
send_socket=_render_id(send_socket.identity),
)
await logger.info('Router started')
while True:
try:
frames, sender_id = await recv_socket.recv()
if frames == [b'']:
await logger.info(
'Router connected to endpoint',
sender_id=_render_id(sender_id),
)
continue
recipient_id, payload = frames
await logger.debug(
'Router received message',
sender_id=_render_id(sender_id),
recipient_id=_render_id(recipient_id),
)
if sender_id == recipient_id:
await logger.warn(
'Loopback not allowed',
sender_id=_render_id(sender_id),
)
continue
await send_socket.send([sender_id, payload], address=recipient_id)
except (ValueError, RemoteCallError) as exc:
await logger.error('Router failed to route message', exc_info=exc)
@dataclass
class Router:
"""Routes messages between :class:`Client`s and :class:`Service`s that use sockets.
Routers are stateless, duplex, and symmetric (*i.e.*, require the same format and
exhibit the same behavior on both ends).
Routers have no error handling and may silently drop messages if the destination is
unreachable. Clients must rely on timeouts to determine when to consider a request
failed.
The payloads themselves are opaque to the router and are not deserialized.
Parameters:
frontend: A ``ROUTER`` socket clients connect to.
backend: A ``ROUTER`` socket services connect to.
route_task: The background task performing the routing. :class:`Router`
implements the async context manager protocol, which automatically schedules
and cancels this task.
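
    For example (the addresses below are illustrative only)::

        async with Router.bind(['tcp://*:6000'], ['tcp://*:6001']):
            await asyncio.Future()  # Route until cancelled.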
"""
frontend: SocketNode
backend: SocketNode
route_task: asyncio.Future[tuple[NoReturn, NoReturn]] = field(
default_factory=lambda: asyncio.get_running_loop().create_future(),
init=False,
repr=False,
)
def __post_init__(self, /) -> None:
_check_type(self.frontend, zmq.ROUTER)
_check_type(self.backend, zmq.ROUTER)
async def __aenter__(self, /) -> 'Router':
await self.frontend.__aenter__()
await self.backend.__aenter__()
self.route_task = asyncio.gather(
asyncio.create_task(_route(self.frontend, self.backend), name='route-req'),
asyncio.create_task(_route(self.backend, self.frontend), name='route-res'),
)
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
self.route_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.route_task
await self.frontend.__aexit__(exc_type, exc, traceback)
await self.backend.__aexit__(exc_type, exc, traceback)
    @classmethod
def bind(
cls,
frontend: Collection[str],
backend: Collection[str],
frontend_options: Optional[SocketOptions] = None,
backend_options: Optional[SocketOptions] = None,
) -> 'Router':
"""Construct a :class:`Router` bound to the provided addresses."""
# pylint: disable=unexpected-keyword-arg; dataclass not recognized
frontend_options = frontend_options or {}
backend_options = backend_options or {}
frontend_options.setdefault(zmq.IDENTITY, b'router-frontend')
backend_options.setdefault(zmq.IDENTITY, b'router-backend')
return Router(
SocketNode(
socket_type=zmq.ROUTER,
bindings=frozenset(frontend),
options=frontend_options,
),
SocketNode(
socket_type=zmq.ROUTER,
bindings=frozenset(backend),
options=backend_options,
),
)