import asyncio
import logging
import itertools
import contextlib
from types import MappingProxyType
from functools import cached_property
from . import config, transport
logger = logging.getLogger(__name__)
[docs]class Network:
"""Maintains connections over the raft network as well as queues for messages"""
def __init__(self, peer_id):
self.peer_id = peer_id
self._active_connections = {}
self.inbox = asyncio.Queue()
self._active_ids = self._active_connections.keys()
self._all_ids = frozenset(config.SERVERS) - {peer_id}
# have a queue per peer in the network
self.outbox = MappingProxyType({i: asyncio.Queue() for i in self._all_ids})
[docs] def send(self, dest: int, msg):
self.outbox[dest].put_nowait(msg)
async def _send_messages(self):
suppresser = contextlib.suppress(KeyError)
for peer, queue in itertools.cycle(self.outbox.items()):
try:
msg = queue.get_nowait()
except asyncio.QueueEmpty: # no message here, come back later
await asyncio.sleep(.1)
else:
queue.task_done()
with suppresser:
reader, writer = self._active_connections[peer]
logger.debug(f"Sending to {peer=} {msg=}")
await transport.send(writer, msg)
@cached_property
def peers(self) -> frozenset:
return self._all_ids
[docs] async def recv(self) -> bytes:
return await self.inbox.get()
async def _on_connetion(self, reader, writer):
logger.info(f"Handling client: {writer.get_extra_info('peername')}")
connected_peer = await transport.recv(reader) # first message is always peer_id
logger.debug(f"Starting to receive from {connected_peer=}")
with contextlib.suppress(asyncio.IncompleteReadError):
while message:= await transport.recv(reader):
await self.inbox.put(message)
logger.warning(f"Broken peer {connected_peer}. Closing connection and clearing queue.")
writer.close()
await writer.wait_closed()
# remove all that was going to be sent
peer_q = self.outbox[connected_peer]
while peer_q.qsize():
peer_q.get_nowait()
peer_q.task_done()
logger.warning(f"Queue {peer_q} empty. Resurrecting connection for {connected_peer}")
del self._active_connections[connected_peer]
awaitable = self._connect(connected_peer)
asyncio.create_task(awaitable) # resurrect connection and hope for the best
async def _connect(self, target_peer):
"""Connect to a target_peer by its target_peer ID"""
host, port = config.SERVERS[target_peer]
logger.debug(f"Attempting to connect to: {target_peer=}, {host=}, {port=}")
timeout = 5
while True:
awaitable = asyncio.open_connection(host, port)
try:
reader, writer = await asyncio.wait_for(awaitable, timeout)
break
except (asyncio.TimeoutError, ConnectionRefusedError):
await asyncio.sleep(1)
# register ourselves on source connection
await transport.send(writer, self.peer_id)
self._active_connections[target_peer] = (reader, writer)
logger.debug(f"Connected to: {target_peer=}, {host=}, {port=}")
async def _connect_to_peers(self):
"""Attempt to connect to any missing peers"""
while missing_peers := self._all_ids - self._active_ids:
logger.warning(f"{missing_peers=}")
await asyncio.gather(*(self._connect(peer) for peer in missing_peers))
[docs] async def start(self):
host, port = config.SERVERS[self.peer_id]
server = await asyncio.start_server(self._on_connetion, host=host, port=port)
addr = server.sockets[0].getsockname()
logger.info(f'Serving on {addr}')
async with server:
await asyncio.gather(
self._connect_to_peers(),
self._send_messages(),
server.serve_forever(),
)
if __name__ == "__main__":
import argparse
logger.setLevel(logging.INFO)
transport.logger.setLevel(logging.INFO) # tmp: debug too verbose for this module
parser = argparse.ArgumentParser(description='Start server arguments.')
parser.add_argument('node', type=int, help='Server node to start')
parsedargs = parser.parse_args()
from datetime import datetime
async def main(peer_id):
# Have every server send a message to every other server
async def consume():
while msg:= await net.recv():
logger.critical(msg)
async def broadcast():
while await asyncio.sleep(peer_id, result=True):
for peer in net.peers:
net.send(peer, f"Hello from: {peer_id}, {datetime.now()}")
net = Network(peer_id)
await asyncio.gather(
net.start(),
broadcast(),
consume(),
)
asyncio.run(main(parsedargs.node))