import array import gc import io import os import platform import socket import sys import threading import time from contextlib import suppress from pathlib import Path from socket import AddressFamily from ssl import SSLContext, SSLError from threading import Thread from typing import Any, Iterable, Iterator, List, NoReturn, Tuple, Type, TypeVar, Union, cast import pytest from _pytest.fixtures import SubRequest from _pytest.logging import LogCaptureFixture from _pytest.monkeypatch import MonkeyPatch from _pytest.tmpdir import TempPathFactory from anyio import ( BrokenResourceError, BusyResourceError, ClosedResourceError, Event, ExceptionGroup, TypedAttributeLookupError, connect_tcp, connect_unix, create_connected_udp_socket, create_task_group, create_tcp_listener, create_udp_socket, create_unix_listener, fail_after, getaddrinfo, getnameinfo, move_on_after, sleep, wait_all_tasks_blocked) from anyio.abc import IPSockAddrType, Listener, SocketAttribute, SocketListener, SocketStream from anyio.streams.stapled import MultiListener if sys.version_info >= (3, 8): from typing import Literal else: from typing_extensions import Literal AnyIPAddressFamily = Literal[AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6] pytestmark = pytest.mark.anyio # If a socket can bind to ::1, the current environment has IPv6 properly configured has_ipv6 = False if socket.has_ipv6: try: s = socket.socket(AddressFamily.AF_INET6) try: s.bind(('::1', 0)) finally: s.close() del s except OSError: pass else: has_ipv6 = True @pytest.fixture def fake_localhost_dns(monkeypatch: MonkeyPatch) -> None: def fake_getaddrinfo(*args: Any, **kwargs: Any) -> object: # Make it return IPv4 addresses first so we can test the IPv6 preference results = real_getaddrinfo(*args, **kwargs) return sorted(results, key=lambda item: item[0]) real_getaddrinfo = socket.getaddrinfo monkeypatch.setattr('socket.getaddrinfo', fake_getaddrinfo) @pytest.fixture(params=[ pytest.param(AddressFamily.AF_INET, id='ipv4'), pytest.param(AddressFamily.AF_INET6, id='ipv6', marks=[pytest.mark.skipif(not has_ipv6, reason='no IPv6 support')]) ]) def family(request: SubRequest) -> AnyIPAddressFamily: return request.param @pytest.fixture def check_asyncio_bug(anyio_backend_name: str, family: AnyIPAddressFamily) -> None: if ( anyio_backend_name == 'asyncio' and sys.platform == 'win32' and family == AddressFamily.AF_INET6 ): import asyncio policy = asyncio.get_event_loop_policy() if policy.__class__.__name__ == 'WindowsProactorEventLoopPolicy': pytest.skip('Does not work due to a known bug (39148)') _T = TypeVar("_T") def _identity(v: _T) -> _T: return v # _ProactorBasePipeTransport.abort() after _ProactorBasePipeTransport.close() # does not cancel writes: https://bugs.python.org/issue44428 _ignore_win32_resource_warnings = pytest.mark.filterwarnings( "ignore:unclosed Iterator[socket.socket]: sock = socket.socket(family, socket.SOCK_STREAM) sock.settimeout(1) sock.bind(('localhost', 0)) sock.listen() yield sock sock.close() @pytest.fixture def server_addr(self, server_sock: socket.socket) -> Tuple[str, int]: return server_sock.getsockname()[:2] async def test_extra_attributes(self, server_sock: socket.socket, server_addr: Tuple[str, int], family: AnyIPAddressFamily) -> None: async with await connect_tcp(*server_addr) as stream: raw_socket = stream.extra(SocketAttribute.raw_socket) assert stream.extra(SocketAttribute.family) == family assert stream.extra(SocketAttribute.local_address) == raw_socket.getsockname()[:2] assert stream.extra(SocketAttribute.local_port) == raw_socket.getsockname()[1] assert stream.extra(SocketAttribute.remote_address) == server_addr assert stream.extra(SocketAttribute.remote_port) == server_addr[1] async def test_send_receive(self, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: async with await connect_tcp(*server_addr) as stream: client, _ = server_sock.accept() await stream.send(b'blah') request = client.recv(100) client.sendall(request[::-1]) response = await stream.receive() client.close() assert response == b'halb' async def test_send_large_buffer(self, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: def serve() -> None: client, _ = server_sock.accept() client.sendall(buffer) client.close() buffer = b'\xff' * 1024 * 1024 # should exceed the maximum kernel send buffer size async with await connect_tcp(*server_addr) as stream: thread = Thread(target=serve, daemon=True) thread.start() response = b'' while len(response) < len(buffer): response += await stream.receive() thread.join() assert response == buffer async def test_send_eof(self, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: def serve() -> None: client, _ = server_sock.accept() request = b'' while True: data = client.recv(100) request += data if not data: break client.sendall(request[::-1]) client.close() async with await connect_tcp(*server_addr) as stream: thread = Thread(target=serve, daemon=True) thread.start() await stream.send(b'hello, ') await stream.send(b'world\n') await stream.send_eof() response = await stream.receive() thread.join() assert response == b'\ndlrow ,olleh' async def test_iterate(self, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: def serve() -> None: client, _ = server_sock.accept() client.sendall(b'bl') event.wait(1) client.sendall(b'ah') client.close() event = threading.Event() thread = Thread(target=serve, daemon=True) thread.start() chunks = [] async with await connect_tcp(*server_addr) as stream: async for chunk in stream: chunks.append(chunk) event.set() thread.join() assert chunks == [b'bl', b'ah'] async def test_socket_options(self, family: AnyIPAddressFamily, server_addr: Tuple[str, int]) -> None: async with await connect_tcp(*server_addr) as stream: raw_socket = stream.extra(SocketAttribute.raw_socket) assert raw_socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) != 0 @pytest.mark.skipif(not has_ipv6, reason='IPv6 is not available') @pytest.mark.parametrize('local_addr, expected_client_addr', [ pytest.param('', '::1', id='dualstack'), pytest.param('127.0.0.1', '127.0.0.1', id='ipv4'), pytest.param('::1', '::1', id='ipv6') ]) async def test_happy_eyeballs(self, local_addr: str, expected_client_addr: str, fake_localhost_dns: None) -> None: client_addr = None, None def serve() -> None: nonlocal client_addr client, client_addr = server_sock.accept() client.close() family = AddressFamily.AF_INET if local_addr == '127.0.0.1' else AddressFamily.AF_INET6 server_sock = socket.socket(family) server_sock.bind((local_addr, 0)) server_sock.listen() port = server_sock.getsockname()[1] thread = Thread(target=serve, daemon=True) thread.start() async with await connect_tcp('localhost', port): pass thread.join() server_sock.close() assert client_addr[0] == expected_client_addr @pytest.mark.parametrize('target, exception_class', [ pytest.param( 'localhost', ExceptionGroup, id='multi', marks=[pytest.mark.skipif(not has_ipv6, reason='IPv6 is not available')] ), pytest.param('127.0.0.1', ConnectionRefusedError, id='single') ]) async def test_connection_refused( self, target: str, exception_class: Union[Type[ExceptionGroup], Type[ConnectionRefusedError]], fake_localhost_dns: None, ) -> None: dummy_socket = socket.socket(AddressFamily.AF_INET6) dummy_socket.bind(('::', 0)) free_port = dummy_socket.getsockname()[1] dummy_socket.close() with pytest.raises(OSError) as exc: await connect_tcp(target, free_port) assert exc.match('All connection attempts failed') assert isinstance(exc.value.__cause__, exception_class) if isinstance(exc.value.__cause__, ExceptionGroup): for exception in exc.value.__cause__.exceptions: assert isinstance(exception, ConnectionRefusedError) async def test_receive_timeout(self, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: def serve() -> None: conn, _ = server_sock.accept() time.sleep(1) conn.close() thread = Thread(target=serve, daemon=True) thread.start() async with await connect_tcp(*server_addr) as stream: start_time = time.monotonic() with move_on_after(0.1): while time.monotonic() - start_time < 0.3: await stream.receive(1) pytest.fail('The timeout was not respected') async def test_concurrent_send(self, server_addr: Tuple[str, int]) -> None: async def send_data() -> NoReturn: while True: await stream.send(b'\x00' * 4096) async with await connect_tcp(*server_addr) as stream: async with create_task_group() as tg: tg.start_soon(send_data) await wait_all_tasks_blocked() with pytest.raises(BusyResourceError) as exc: await stream.send(b'foo') exc.match('already writing to') tg.cancel_scope.cancel() async def test_concurrent_receive(self, server_addr: Tuple[str, int]) -> None: async with await connect_tcp(*server_addr) as client: async with create_task_group() as tg: tg.start_soon(client.receive) await wait_all_tasks_blocked() try: with pytest.raises(BusyResourceError) as exc: await client.receive() exc.match('already reading from') finally: tg.cancel_scope.cancel() async def test_close_during_receive(self, server_addr: Tuple[str, int]) -> None: async def interrupt() -> None: await wait_all_tasks_blocked() await stream.aclose() async with await connect_tcp(*server_addr) as stream: async with create_task_group() as tg: tg.start_soon(interrupt) with pytest.raises(ClosedResourceError): await stream.receive() async def test_receive_after_close(self, server_addr: Tuple[str, int]) -> None: stream = await connect_tcp(*server_addr) await stream.aclose() with pytest.raises(ClosedResourceError): await stream.receive() async def test_send_after_close(self, server_addr: Tuple[str, int]) -> None: stream = await connect_tcp(*server_addr) await stream.aclose() with pytest.raises(ClosedResourceError): await stream.send(b'foo') async def test_send_after_peer_closed(self, family: AnyIPAddressFamily) -> None: def serve_once() -> None: client_sock, _ = server_sock.accept() client_sock.close() server_sock.close() server_sock = socket.socket(family, socket.SOCK_STREAM) server_sock.settimeout(1) server_sock.bind(('localhost', 0)) server_addr = server_sock.getsockname()[:2] server_sock.listen() thread = Thread(target=serve_once, daemon=True) thread.start() with pytest.raises(BrokenResourceError): async with await connect_tcp(*server_addr) as stream: for _ in range(1000): await stream.send(b'foo') thread.join() async def test_connect_tcp_with_tls(self, server_context: SSLContext, client_context: SSLContext, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: def serve() -> None: with suppress(socket.timeout): client, addr = server_sock.accept() client.settimeout(1) client = server_context.wrap_socket(client, server_side=True) data = client.recv(100) client.sendall(data[::-1]) client.unwrap() client.close() # The TLSStream tests are more comprehensive than this one! thread = Thread(target=serve, daemon=True) thread.start() async with await connect_tcp(*server_addr, tls_hostname='localhost', ssl_context=client_context) as stream: await stream.send(b'hello') response = await stream.receive() assert response == b'olleh' thread.join() async def test_connect_tcp_with_tls_cert_check_fail(self, server_context: SSLContext, server_sock: socket.socket, server_addr: Tuple[str, int]) -> None: thread_exception = None def serve() -> None: nonlocal thread_exception client, addr = server_sock.accept() with client: client.settimeout(1) try: server_context.wrap_socket(client, server_side=True) except OSError: pass except BaseException as exc: thread_exception = exc thread = Thread(target=serve, daemon=True) thread.start() with pytest.raises(SSLError): await connect_tcp(*server_addr, tls_hostname='localhost') thread.join() assert thread_exception is None @pytest.mark.parametrize('anyio_backend', ['asyncio']) async def test_unretrieved_future_exception_server_crash( self, family: AnyIPAddressFamily, caplog: LogCaptureFixture) -> None: """ Tests that there won't be any leftover Futures that don't get their exceptions retrieved. See https://github.com/encode/httpcore/issues/382 for details. """ def serve() -> None: sock, addr = server_sock.accept() event.wait(3) del sock gc.collect() server_sock = socket.socket(family, socket.SOCK_STREAM) server_sock.settimeout(1) server_sock.bind(('localhost', 0)) server_sock.listen() server_addr = server_sock.getsockname()[:2] event = threading.Event() thread = Thread(target=serve) thread.start() async with await connect_tcp(*server_addr) as stream: await stream.send(b'GET') event.set() with pytest.raises(BrokenResourceError): await stream.receive() thread.join() gc.collect() assert not caplog.text class TestTCPListener: async def test_extra_attributes(self, family: AnyIPAddressFamily) -> None: async with await create_tcp_listener(local_host='localhost', family=family) as multi: assert multi.extra(SocketAttribute.family) == family for listener in multi.listeners: raw_socket = listener.extra(SocketAttribute.raw_socket) assert listener.extra(SocketAttribute.family) == family assert listener.extra(SocketAttribute.local_address) == \ raw_socket.getsockname()[:2] assert listener.extra(SocketAttribute.local_port) == raw_socket.getsockname()[1] pytest.raises(TypedAttributeLookupError, listener.extra, SocketAttribute.remote_address) pytest.raises(TypedAttributeLookupError, listener.extra, SocketAttribute.remote_port) @pytest.mark.parametrize('family', [ pytest.param(AddressFamily.AF_INET, id='ipv4'), pytest.param(AddressFamily.AF_INET6, id='ipv6', marks=[pytest.mark.skipif(not has_ipv6, reason='no IPv6 support')]), pytest.param(socket.AF_UNSPEC, id='both', marks=[pytest.mark.skipif(not has_ipv6, reason='no IPv6 support')]) ]) async def test_accept(self, family: AnyIPAddressFamily) -> None: async with await create_tcp_listener(local_host='localhost', family=family) as multi: for listener in multi.listeners: client = socket.socket(listener.extra(SocketAttribute.family)) client.settimeout(1) client.connect(listener.extra(SocketAttribute.local_address)) assert isinstance(listener, SocketListener) stream = await listener.accept() client.sendall(b'blah') request = await stream.receive() await stream.send(request[::-1]) assert client.recv(100) == b'halb' client.close() await stream.aclose() async def test_accept_after_close(self, family: AnyIPAddressFamily) -> None: async with await create_tcp_listener(local_host='localhost', family=family) as multi: for listener in multi.listeners: await listener.aclose() assert isinstance(listener, SocketListener) with pytest.raises(ClosedResourceError): await listener.accept() async def test_socket_options(self, family: AnyIPAddressFamily) -> None: async with await create_tcp_listener(local_host='localhost', family=family) as multi: for listener in multi.listeners: raw_socket = listener.extra(SocketAttribute.raw_socket) if sys.platform == 'win32': assert raw_socket.getsockopt( socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE) != 0 else: assert raw_socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) != 0 raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 80000) assert raw_socket.getsockopt( socket.SOL_SOCKET, socket.SO_RCVBUF) in (80000, 160000) client = socket.socket(raw_socket.family) client.settimeout(1) client.connect(raw_socket.getsockname()) assert isinstance(listener, SocketListener) async with await listener.accept() as stream: raw_socket = stream.extra(SocketAttribute.raw_socket) assert raw_socket.gettimeout() == 0 assert raw_socket.family == listener.extra(SocketAttribute.family) assert raw_socket.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) != 0 client.close() @pytest.mark.skipif(not hasattr(socket, "SO_REUSEPORT"), reason='SO_REUSEPORT option not supported') async def test_reuse_port(self, family: AnyIPAddressFamily) -> None: multi1 = await create_tcp_listener(local_host='localhost', family=family, reuse_port=True) assert len(multi1.listeners) == 1 multi2 = await create_tcp_listener( local_host='localhost', local_port=multi1.listeners[0].extra(SocketAttribute.local_port), family=family, reuse_port=True) assert len(multi2.listeners) == 1 assert multi1.listeners[0].extra(SocketAttribute.local_address) == \ multi2.listeners[0].extra(SocketAttribute.local_address) await multi1.aclose() await multi2.aclose() async def test_close_from_other_task(self, family: AnyIPAddressFamily) -> None: listener = await create_tcp_listener(local_host='localhost', family=family) with pytest.raises(ClosedResourceError): async with create_task_group() as tg: tg.start_soon(listener.serve, lambda stream: None) await wait_all_tasks_blocked() await listener.aclose() tg.cancel_scope.cancel() async def test_send_after_eof(self, family: AnyIPAddressFamily) -> None: async def handle(stream: SocketStream) -> None: async with stream: await stream.send(b'Hello\n') multi = await create_tcp_listener(family=family, local_host='localhost') async with multi, create_task_group() as tg: tg.start_soon(multi.serve, handle) await wait_all_tasks_blocked() with socket.socket(family) as client: client.connect(multi.extra(SocketAttribute.local_address)) client.shutdown(socket.SHUT_WR) client.setblocking(False) with fail_after(1): while True: try: message = client.recv(10) except BlockingIOError: await sleep(0) else: assert message == b'Hello\n' break tg.cancel_scope.cancel() @pytest.mark.skipif(sys.platform == 'win32', reason='UNIX sockets are not available on Windows') class TestUNIXStream: @pytest.fixture def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: return tmp_path_factory.mktemp('unix').joinpath('socket') @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Union[Path, str]: return socket_path if request.param else str(socket_path) @pytest.fixture def server_sock(self, socket_path: Path) -> Iterable[socket.socket]: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(1) sock.bind(str(socket_path)) sock.listen() yield sock sock.close() async def test_extra_attributes(self, server_sock: socket.socket, socket_path: Path) -> None: async with await connect_unix(socket_path) as stream: raw_socket = stream.extra(SocketAttribute.raw_socket) assert stream.extra(SocketAttribute.family) == socket.AF_UNIX assert stream.extra(SocketAttribute.local_address) == raw_socket.getsockname() assert stream.extra(SocketAttribute.remote_address) == str(socket_path) pytest.raises(TypedAttributeLookupError, stream.extra, SocketAttribute.local_port) pytest.raises(TypedAttributeLookupError, stream.extra, SocketAttribute.remote_port) async def test_send_receive(self, server_sock: socket.socket, socket_path_or_str: Union[Path, str]) -> None: async with await connect_unix(socket_path_or_str) as stream: client, _ = server_sock.accept() await stream.send(b'blah') request = client.recv(100) client.sendall(request[::-1]) response = await stream.receive() client.close() assert response == b'halb' async def test_send_large_buffer(self, server_sock: socket.socket, socket_path: Path) -> None: def serve() -> None: client, _ = server_sock.accept() client.sendall(buffer) client.close() buffer = b'\xff' * 1024 * 1024 # should exceed the maximum kernel send buffer size async with await connect_unix(socket_path) as stream: thread = Thread(target=serve, daemon=True) thread.start() response = b'' while len(response) < len(buffer): response += await stream.receive() thread.join() assert response == buffer async def test_receive_fds(self, server_sock: socket.socket, socket_path: Path, tmp_path: Path) -> None: def serve() -> None: path1 = tmp_path / 'file1' path2 = tmp_path / 'file2' path1.write_text('Hello, ') path2.write_text('World!') with path1.open() as file1, path2.open() as file2: fdarray = array.array('i', [file1.fileno(), file2.fileno()]) client, _ = server_sock.accept() cmsg = (socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray) with client: client.sendmsg([b'test'], [cmsg]) async with await connect_unix(socket_path) as stream: thread = Thread(target=serve, daemon=True) thread.start() message, fds = await stream.receive_fds(10, 2) thread.join() text = '' for fd in fds: with os.fdopen(fd) as file: text += file.read() assert message == b'test' assert text == 'Hello, World!' async def test_receive_fds_bad_args(self, server_sock: socket.socket, socket_path: Path) -> None: async with await connect_unix(socket_path) as stream: for msglen in (-1, 'foo'): with pytest.raises(ValueError, match='msglen must be a non-negative integer'): await stream.receive_fds(msglen, 0) # type: ignore[arg-type] for maxfds in (0, 'foo'): with pytest.raises(ValueError, match='maxfds must be a positive integer'): await stream.receive_fds(0, maxfds) # type: ignore[arg-type] async def test_send_fds(self, server_sock: socket.socket, socket_path: Path, tmp_path: Path) -> None: def serve() -> None: fds = array.array('i') client, _ = server_sock.accept() msg, ancdata, *_ = client.recvmsg(10, socket.CMSG_LEN(2 * fds.itemsize)) client.close() assert msg == b'test' for cmsg_level, cmsg_type, cmsg_data in ancdata: assert cmsg_level == socket.SOL_SOCKET assert cmsg_type == socket.SCM_RIGHTS fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) text = '' for fd in fds: with os.fdopen(fd) as file: text += file.read() assert text == 'Hello, World!' path1 = tmp_path / 'file1' path2 = tmp_path / 'file2' path1.write_text('Hello, ') path2.write_text('World!') with path1.open() as file1, path2.open() as file2, fail_after(2): assert isinstance(file1, io.TextIOWrapper) assert isinstance(file2, io.TextIOWrapper) async with await connect_unix(socket_path) as stream: thread = Thread(target=serve, daemon=True) thread.start() await stream.send_fds(b'test', [file1, file2]) thread.join() async def test_send_eof(self, server_sock: socket.socket, socket_path: Path) -> None: def serve() -> None: client, _ = server_sock.accept() request = b'' while True: data = client.recv(100) request += data if not data: break client.sendall(request[::-1]) client.close() async with await connect_unix(socket_path) as stream: thread = Thread(target=serve, daemon=True) thread.start() await stream.send(b'hello, ') await stream.send(b'world\n') await stream.send_eof() response = await stream.receive() thread.join() assert response == b'\ndlrow ,olleh' async def test_iterate(self, server_sock: socket.socket, socket_path: Path) -> None: def serve() -> None: client, _ = server_sock.accept() client.sendall(b'bl') time.sleep(0.05) client.sendall(b'ah') client.close() thread = Thread(target=serve, daemon=True) thread.start() chunks = [] async with await connect_unix(socket_path) as stream: async for chunk in stream: chunks.append(chunk) thread.join() assert chunks == [b'bl', b'ah'] async def test_send_fds_bad_args(self, server_sock: socket.socket, socket_path: Path) -> None: async with await connect_unix(socket_path) as stream: with pytest.raises(ValueError, match='message must not be empty'): await stream.send_fds(b'', [0]) with pytest.raises(ValueError, match='fds must not be empty'): await stream.send_fds(b'test', []) async def test_concurrent_send(self, server_sock: socket.socket, socket_path: Path) -> None: async def send_data() -> NoReturn: while True: await client.send(b'\x00' * 4096) async with await connect_unix(socket_path) as client: async with create_task_group() as tg: tg.start_soon(send_data) await wait_all_tasks_blocked() with pytest.raises(BusyResourceError) as exc: await client.send(b'foo') exc.match('already writing to') tg.cancel_scope.cancel() async def test_concurrent_receive(self, server_sock: socket.socket, socket_path: Path) -> None: async with await connect_unix(socket_path) as client: async with create_task_group() as tg: tg.start_soon(client.receive) await wait_all_tasks_blocked() try: with pytest.raises(BusyResourceError) as exc: await client.receive() exc.match('already reading from') finally: tg.cancel_scope.cancel() async def test_close_during_receive(self, server_sock: socket.socket, socket_path: Path) -> None: async def interrupt() -> None: await wait_all_tasks_blocked() await stream.aclose() async with await connect_unix(socket_path) as stream: async with create_task_group() as tg: tg.start_soon(interrupt) with pytest.raises(ClosedResourceError): await stream.receive() async def test_receive_after_close(self, server_sock: socket.socket, socket_path: Path) -> None: stream = await connect_unix(socket_path) await stream.aclose() with pytest.raises(ClosedResourceError): await stream.receive() async def test_send_after_close(self, server_sock: socket.socket, socket_path: Path) -> None: stream = await connect_unix(socket_path) await stream.aclose() with pytest.raises(ClosedResourceError): await stream.send(b'foo') async def test_cannot_connect(self, socket_path: Path) -> None: with pytest.raises(FileNotFoundError): await connect_unix(socket_path) @pytest.mark.skipif(sys.platform == 'win32', reason='UNIX sockets are not available on Windows') class TestUNIXListener: @pytest.fixture def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: return tmp_path_factory.mktemp('unix').joinpath('socket') @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Union[Path, str]: return socket_path if request.param else str(socket_path) async def test_extra_attributes(self, socket_path: Path) -> None: async with await create_unix_listener(socket_path) as listener: raw_socket = listener.extra(SocketAttribute.raw_socket) assert listener.extra(SocketAttribute.family) == socket.AF_UNIX assert listener.extra(SocketAttribute.local_address) == raw_socket.getsockname() pytest.raises(TypedAttributeLookupError, listener.extra, SocketAttribute.local_port) pytest.raises(TypedAttributeLookupError, listener.extra, SocketAttribute.remote_address) pytest.raises(TypedAttributeLookupError, listener.extra, SocketAttribute.remote_port) async def test_accept(self, socket_path_or_str: Union[Path, str]) -> None: async with await create_unix_listener(socket_path_or_str) as listener: client = socket.socket(socket.AF_UNIX) client.settimeout(1) client.connect(str(socket_path_or_str)) stream = await listener.accept() client.sendall(b'blah') request = await stream.receive() await stream.send(request[::-1]) assert client.recv(100) == b'halb' client.close() await stream.aclose() async def test_socket_options(self, socket_path: Path) -> None: async with await create_unix_listener(socket_path) as listener: listener_socket = listener.extra(SocketAttribute.raw_socket) assert listener_socket.family == socket.AddressFamily.AF_UNIX listener_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 80000) assert listener_socket.getsockopt( socket.SOL_SOCKET, socket.SO_RCVBUF) in (80000, 160000) client = socket.socket(listener_socket.family) client.settimeout(1) client.connect(listener_socket.getsockname()) async with await listener.accept() as stream: assert stream.extra(SocketAttribute.raw_socket).gettimeout() == 0 assert stream.extra(SocketAttribute.family) == listener_socket.family client.close() async def test_send_after_eof(self, socket_path: Path) -> None: async def handle(stream: SocketStream) -> None: async with stream: await stream.send(b'Hello\n') async with await create_unix_listener(socket_path) as listener, create_task_group() as tg: tg.start_soon(listener.serve, handle) await wait_all_tasks_blocked() with socket.socket(socket.AF_UNIX) as client: client.connect(str(socket_path)) client.shutdown(socket.SHUT_WR) client.setblocking(False) with fail_after(1): while True: try: message = client.recv(10) except BlockingIOError: await sleep(0) else: assert message == b'Hello\n' break tg.cancel_scope.cancel() async def test_bind_twice(self, socket_path: Path) -> None: """Test that the previous socket is removed before binding to the path.""" for _ in range(2): async with await create_unix_listener(socket_path): pass async def test_multi_listener(tmp_path_factory: TempPathFactory) -> None: async def handle(stream: SocketStream) -> None: client_addresses.append(stream.extra(SocketAttribute.remote_address)) event.set() await stream.aclose() client_addresses: List[Union[str, IPSockAddrType]] = [] listeners: List[Listener] = [await create_tcp_listener(local_host='localhost')] if sys.platform != 'win32': socket_path = tmp_path_factory.mktemp('unix').joinpath('socket') listeners.append(await create_unix_listener(socket_path)) expected_addresses: List[Union[str, IPSockAddrType]] = [] async with MultiListener(listeners) as multi_listener: async with create_task_group() as tg: tg.start_soon(multi_listener.serve, handle) for listener in multi_listener.listeners: event = Event() local_address = listener.extra(SocketAttribute.local_address) if sys.platform != 'win32' and listener.extra(SocketAttribute.family) == \ socket.AddressFamily.AF_UNIX: assert isinstance(local_address, str) stream: SocketStream = await connect_unix(local_address) else: assert isinstance(local_address, tuple) stream = await connect_tcp(*local_address) expected_addresses.append(stream.extra(SocketAttribute.local_address)) await event.wait() await stream.aclose() tg.cancel_scope.cancel() assert client_addresses == expected_addresses @pytest.mark.usefixtures('check_asyncio_bug') class TestUDPSocket: async def test_extra_attributes(self, family: AnyIPAddressFamily) -> None: async with await create_udp_socket(family=family, local_host='localhost') as udp: raw_socket = udp.extra(SocketAttribute.raw_socket) assert raw_socket.gettimeout() == 0 assert udp.extra(SocketAttribute.family) == family assert udp.extra(SocketAttribute.local_address) == raw_socket.getsockname()[:2] assert udp.extra(SocketAttribute.local_port) == raw_socket.getsockname()[1] pytest.raises(TypedAttributeLookupError, udp.extra, SocketAttribute.remote_address) pytest.raises(TypedAttributeLookupError, udp.extra, SocketAttribute.remote_port) async def test_send_receive(self, family: AnyIPAddressFamily) -> None: async with await create_udp_socket(local_host='localhost', family=family) as sock: host, port = sock.extra(SocketAttribute.local_address) # type: ignore[misc] await sock.sendto(b'blah', host, port) request, addr = await sock.receive() assert request == b'blah' assert addr == sock.extra(SocketAttribute.local_address) await sock.sendto(b'halb', host, port) response, addr = await sock.receive() assert response == b'halb' assert addr == (host, port) async def test_iterate(self, family: AnyIPAddressFamily) -> None: async def serve() -> None: async for packet, addr in server: await server.send((packet[::-1], addr)) async with await create_udp_socket(family=family, local_host='localhost') as server: host, port = server.extra(SocketAttribute.local_address) # type: ignore[misc] async with await create_udp_socket(family=family, local_host='localhost') as client: async with create_task_group() as tg: tg.start_soon(serve) await client.sendto(b'FOOBAR', host, port) assert await client.receive() == (b'RABOOF', (host, port)) await client.sendto(b'123456', host, port) assert await client.receive() == (b'654321', (host, port)) tg.cancel_scope.cancel() @pytest.mark.skipif(not hasattr(socket, "SO_REUSEPORT"), reason='SO_REUSEPORT option not supported') async def test_reuse_port(self, family: AnyIPAddressFamily) -> None: async with await create_udp_socket(family=family, local_host='localhost', reuse_port=True) as udp: port = udp.extra(SocketAttribute.local_port) assert port != 0 async with await create_udp_socket(family=family, local_host='localhost', local_port=port, reuse_port=True) as udp2: assert port == udp2.extra(SocketAttribute.local_port) async def test_concurrent_receive(self) -> None: async with await create_udp_socket(family=AddressFamily.AF_INET, local_host='localhost') as udp: async with create_task_group() as tg: tg.start_soon(udp.receive) await wait_all_tasks_blocked() try: with pytest.raises(BusyResourceError) as exc: await udp.receive() exc.match('already reading from') finally: tg.cancel_scope.cancel() async def test_close_during_receive(self) -> None: async def close_when_blocked() -> None: await wait_all_tasks_blocked() await udp.aclose() async with await create_udp_socket(family=AddressFamily.AF_INET, local_host='localhost') as udp: async with create_task_group() as tg: tg.start_soon(close_when_blocked) with pytest.raises(ClosedResourceError): await udp.receive() async def test_receive_after_close(self) -> None: udp = await create_udp_socket(family=AddressFamily.AF_INET, local_host='localhost') await udp.aclose() with pytest.raises(ClosedResourceError): await udp.receive() async def test_send_after_close(self) -> None: udp = await create_udp_socket(family=AddressFamily.AF_INET, local_host='localhost') host, port = udp.extra(SocketAttribute.local_address) # type: ignore[misc] await udp.aclose() with pytest.raises(ClosedResourceError): await udp.sendto(b'foo', host, port) async def test_create_unbound_socket(self, family: AnyIPAddressFamily) -> None: """Regression test for #360.""" async with await create_udp_socket(family=family) as udp: local_address = cast(IPSockAddrType, udp.extra(SocketAttribute.local_address)) assert local_address[1] > 0 @pytest.mark.usefixtures('check_asyncio_bug') class TestConnectedUDPSocket: async def test_extra_attributes(self, family: AnyIPAddressFamily) -> None: async with await create_connected_udp_socket('localhost', 5000, family=family) as udp: raw_socket = udp.extra(SocketAttribute.raw_socket) assert udp.extra(SocketAttribute.family) == family assert udp.extra(SocketAttribute.local_address) == raw_socket.getsockname()[:2] assert udp.extra(SocketAttribute.local_port) == raw_socket.getsockname()[1] assert udp.extra(SocketAttribute.remote_address) == raw_socket.getpeername()[:2] assert udp.extra(SocketAttribute.remote_port) == 5000 async def test_send_receive(self, family: AnyIPAddressFamily) -> None: async with await create_udp_socket(family=family, local_host='localhost') as udp1: host, port = udp1.extra(SocketAttribute.local_address) # type: ignore[misc] async with await create_connected_udp_socket( host, port, local_host='localhost', family=family) as udp2: host, port = udp2.extra(SocketAttribute.local_address) # type: ignore[misc] await udp2.send(b'blah') request = await udp1.receive() assert request == (b'blah', (host, port)) await udp1.sendto(b'halb', host, port) response = await udp2.receive() assert response == b'halb' async def test_iterate(self, family: AnyIPAddressFamily) -> None: async def serve() -> None: async for packet in udp2: await udp2.send(packet[::-1]) async with await create_udp_socket(family=family, local_host='localhost') as udp1: host, port = udp1.extra(SocketAttribute.local_address) # type: ignore[misc] async with await create_connected_udp_socket(host, port) as udp2: host, port = udp2.extra(SocketAttribute.local_address) # type: ignore[misc] async with create_task_group() as tg: tg.start_soon(serve) await udp1.sendto(b'FOOBAR', host, port) assert await udp1.receive() == (b'RABOOF', (host, port)) await udp1.sendto(b'123456', host, port) assert await udp1.receive() == (b'654321', (host, port)) tg.cancel_scope.cancel() @pytest.mark.skipif(not hasattr(socket, "SO_REUSEPORT"), reason='SO_REUSEPORT option not supported') async def test_reuse_port(self, family: AnyIPAddressFamily) -> None: async with await create_connected_udp_socket( 'localhost', 6000, family=family, local_host='localhost', reuse_port=True) as udp: port = udp.extra(SocketAttribute.local_port) assert port != 0 async with await create_connected_udp_socket( 'localhost', 6001, family=family, local_host='localhost', local_port=port, reuse_port=True) as udp2: assert port == udp2.extra(SocketAttribute.local_port) async def test_concurrent_receive(self) -> None: async with await create_connected_udp_socket( 'localhost', 5000, local_host='localhost', family=AddressFamily.AF_INET) as udp: async with create_task_group() as tg: tg.start_soon(udp.receive) await wait_all_tasks_blocked() try: with pytest.raises(BusyResourceError) as exc: await udp.receive() exc.match('already reading from') finally: tg.cancel_scope.cancel() async def test_close_during_receive(self) -> None: async def close_when_blocked() -> None: await wait_all_tasks_blocked() await udp.aclose() async with await create_connected_udp_socket( 'localhost', 5000, local_host='localhost', family=AddressFamily.AF_INET) as udp: async with create_task_group() as tg: tg.start_soon(close_when_blocked) with pytest.raises(ClosedResourceError): await udp.receive() async def test_receive_after_close(self, family: AnyIPAddressFamily) -> None: udp = await create_connected_udp_socket('localhost', 5000, local_host='localhost', family=family) await udp.aclose() with pytest.raises(ClosedResourceError): await udp.receive() async def test_send_after_close(self, family: AnyIPAddressFamily) -> None: udp = await create_connected_udp_socket('localhost', 5000, local_host='localhost', family=family) await udp.aclose() with pytest.raises(ClosedResourceError): await udp.send(b'foo') @pytest.mark.network async def test_getaddrinfo() -> None: # IDNA 2003 gets this wrong correct = await getaddrinfo('faß.de', 0) wrong = await getaddrinfo('fass.de', 0) assert correct != wrong @pytest.mark.parametrize('sock_type', [socket.SOCK_STREAM, socket.SocketKind.SOCK_STREAM]) async def test_getaddrinfo_ipv6addr(sock_type: Literal[socket.SocketKind.SOCK_STREAM]) -> None: # IDNA trips up over raw IPv6 addresses proto = 0 if platform.system() == 'Windows' else 6 assert await getaddrinfo('::1', 0, type=sock_type) == [ (socket.AddressFamily.AF_INET6, socket.SocketKind.SOCK_STREAM, proto, '', ('::1', 0)) ] async def test_getnameinfo() -> None: expected_result = socket.getnameinfo(('127.0.0.1', 6666), 0) result = await getnameinfo(('127.0.0.1', 6666)) assert result == expected_result