import socket
from abc import abstractmethod
from io import IOBase
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
from types import TracebackType
from typing import (
    Any,
    AsyncContextManager,
    Callable,
    Collection,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream
from ._tasks import TaskGroup

# Shared type aliases describing addresses and UDP packets.
IPAddressType = Union[str, IPv4Address, IPv6Address]
IPSockAddrType = Tuple[str, int]
SockAddrType = Union[IPSockAddrType, str]
UDPPacketType = Tuple[bytes, IPSockAddrType]
T_Retval = TypeVar('T_Retval')


class _NullAsyncContextManager:
    """Async context manager that does nothing on entry or on exit."""

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
        # Returning None (falsy) lets any in-flight exception propagate.
        return None


class SocketAttribute(TypedAttributeSet):
    """Typed attribute keys exposed by the socket based classes in this module."""

    #: address family of the underlying socket
    family: AddressFamily = typed_attribute()
    #: local address of the underlying socket
    local_address: SockAddrType = typed_attribute()
    #: local port the underlying socket is bound to (IP sockets only)
    local_port: int = typed_attribute()
    #: the underlying stdlib socket object itself
    raw_socket: socket.socket = typed_attribute()
    #: remote address the underlying socket is connected to
    remote_address: SockAddrType = typed_attribute()
    #: remote port the underlying socket is connected to (IP sockets only)
    remote_port: int = typed_attribute()


class _SocketProvider(TypedAttributeProvider):
    """Mixin that derives :class:`SocketAttribute` values from ``_raw_socket``."""

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Build the mapping of :class:`SocketAttribute` keys to zero-argument getters.

        The family, local address and raw socket getters are always present. The remote
        address is included only when ``getpeername()`` succeeds, and the port getters only
        when the socket family is ``AF_INET`` or ``AF_INET6``.

        :return: a mapping of attribute keys to callables producing their values
        """
        from .._core._sockets import convert_ipv6_sockaddr as convert

        providers: Dict[Any, Callable[[], Any]] = {}
        # These getters look the socket up lazily, at call time.
        providers[SocketAttribute.family] = lambda: self._raw_socket.family
        providers[SocketAttribute.local_address] = (
            lambda: convert(self._raw_socket.getsockname())
        )
        providers[SocketAttribute.raw_socket] = lambda: self._raw_socket

        # The peer address is resolved eagerly; an OSError here is treated as
        # "no remote address available".
        try:
            peer: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername())
        except OSError:
            peer = None

        if peer is not None:
            providers[SocketAttribute.remote_address] = lambda: peer

        # Port attributes are meaningful for IP based sockets only.
        ip_families = (AddressFamily.AF_INET, AddressFamily.AF_INET6)
        if self._raw_socket.family not in ip_families:
            return providers

        providers[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1]
        if peer is not None:
            peer_port = peer[1]
            providers[SocketAttribute.remote_port] = lambda: peer_port

        return providers

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The stdlib socket object the attributes are read from."""


class SocketStream(ByteStream, _SocketProvider):
    """
    A byte stream carried over a socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """


class UNIXSocketStream(SocketStream):
    """A socket stream that can also exchange file descriptors with the peer."""

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[Union[int, IOBase]]) -> None:
        """
        Send a message to the peer together with a set of file descriptors.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open file or
            socket objects)

        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> Tuple[bytes, List[int]]:
        """
        Receive a message from the peer together with any file descriptors sent with it.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors)

        """


class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    A listener for incoming socket connections.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection."""

    async def serve(self, handler: Callable[[T_Stream], Any],
                    task_group: Optional[TaskGroup] = None) -> None:
        """
        Accept connections in an endless loop, starting ``handler`` for each one.

        :param handler: callable started in the task group with each accepted stream
        :param task_group: the task group to start handlers in; if ``None``, a new task
            group is created and entered for the duration of the loop

        """
        from .. import create_task_group

        scope: AsyncContextManager
        if task_group is not None:
            # The caller owns the task group, so there is nothing to enter or exit here.
            # (An AsyncExitStack could do this job once py3.7+ is the minimum.)
            scope = _NullAsyncContextManager()
        else:
            task_group = create_task_group()
            scope = task_group

        async with scope:
            while True:
                connection = await self.accept()
                task_group.start_soon(handler, connection)


class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    An unconnected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Send ``data`` to ``(host, port)``.

        Shorthand for calling :meth:`~.UnreliableObjectSendStream.send` with
        ``(data, (host, port))``.

        :param data: the payload to send
        :param host: the destination host
        :param port: the destination port

        """
        packet: UDPPacketType = (data, (host, port))
        return await self.send(packet)


class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    A connected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """