"""Tests for the TLS stream and TLS listener wrappers in ``anyio.streams.tls``.

Most ``TestTLSStream`` tests start a blocking ``ssl`` server socket on
127.0.0.1 (ephemeral port), serve exactly one connection from a helper
thread, and talk to it through ``TLSStream.wrap()`` on the client side.

The ``server_context``, ``client_context`` and ``ca`` arguments are pytest
fixtures that are defined outside this file.
"""
import socket
import ssl
from contextlib import ExitStack
from threading import Thread
from typing import ContextManager, NoReturn

import pytest
from trustme import CA

from anyio import (
    BrokenResourceError, EndOfStream, Event, connect_tcp, create_task_group, create_tcp_listener)
from anyio.abc import AnyByteStream, SocketAttribute, SocketStream
from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream

# Apply the ``anyio`` marker to every test in this module.
pytestmark = pytest.mark.anyio


class TestTLSStream:
    """Client-side ``TLSStream`` tests against a threaded blocking TLS server."""

    async def test_send_receive(self, server_context: ssl.SSLContext,
                                client_context: ssl.SSLContext) -> None:
        """Send ``b'hello'`` and expect the server to echo it back reversed."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            data = conn.recv(10)
            conn.send(data[::-1])
            conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                await wrapper.send(b'hello')
                response = await wrapper.receive()
        finally:
            # Always reap the server thread and release the listening socket,
            # even when the client side of the test raised.
            server_thread.join()
            server_sock.close()

        assert response == b'olleh'

    async def test_extra_attributes(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """Check the typed extra attributes exposed by a wrapped client stream.

        Every public ``SocketAttribute`` must be passed through unchanged from
        the underlying TCP stream, and each ``TLSAttribute`` must have the
        expected value or type. Both contexts advertise the ``h2`` ALPN protocol.
        """
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                # Wait for the single byte the client sends after its assertions
                conn.recv(1)

        server_context.set_alpn_protocols(['h2'])
        client_context.set_alpn_protocols(['h2'])

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=True)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                async with wrapper:
                    for name, attribute in SocketAttribute.__dict__.items():
                        if not name.startswith('_'):
                            assert wrapper.extra(attribute) == stream.extra(attribute)

                    assert wrapper.extra(TLSAttribute.alpn_protocol) == 'h2'
                    assert isinstance(wrapper.extra(TLSAttribute.channel_binding_tls_unique),
                                      bytes)
                    assert isinstance(wrapper.extra(TLSAttribute.cipher), tuple)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate), dict)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate_binary),
                                      bytes)
                    assert wrapper.extra(TLSAttribute.server_side) is False
                    assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list)
                    assert isinstance(wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject)
                    assert wrapper.extra(TLSAttribute.standard_compatible) is False
                    assert wrapper.extra(TLSAttribute.tls_version).startswith('TLSv')
                    await wrapper.send(b'\x00')
        finally:
            # A failed assertion above must not leak the thread or the socket
            server_thread.join()
            server_sock.close()

    async def test_unwrap(self, server_context: ssl.SSLContext,
                          client_context: ssl.SSLContext) -> None:
        """Receive one encrypted message, unwrap TLS, then read the plaintext one."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            conn.send(b'encrypted')
            unencrypted = conn.unwrap()
            unencrypted.send(b'unencrypted')
            unencrypted.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                msg1 = await wrapper.receive()
                unwrapped_stream, msg2 = await wrapper.unwrap()
                # unwrap() may hand back none, part or all of the plaintext that
                # followed the TLS shutdown; read once more if it is incomplete.
                if msg2 != b'unencrypted':
                    msg2 += await unwrapped_stream.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert msg1 == b'encrypted'
        assert msg2 == b'unencrypted'

    @pytest.mark.skipif(not ssl.HAS_ALPN, reason='ALPN support not available')
    async def test_alpn_negotiation(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """Both peers must agree on ``dummy2``, the only ALPN protocol they share."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            selected_alpn_protocol = conn.selected_alpn_protocol()
            assert selected_alpn_protocol is not None
            # Report the server's view of the negotiation back to the client
            conn.send(selected_alpn_protocol.encode())
            conn.close()

        server_context.set_alpn_protocols(['dummy1', 'dummy2'])
        client_context.set_alpn_protocols(['dummy2', 'dummy3'])

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                assert wrapper.extra(TLSAttribute.alpn_protocol) == 'dummy2'
                server_alpn_protocol = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_alpn_protocol == b'dummy2'

    @pytest.mark.parametrize('server_compatible, client_compatible', [
        pytest.param(True, True, id='both_standard'),
        pytest.param(True, False, id='server_standard'),
        pytest.param(False, True, id='client_standard'),
        pytest.param(False, False, id='neither_standard')
    ])
    async def test_ragged_eofs(self, server_context: ssl.SSLContext,
                               client_context: ssl.SSLContext, server_compatible: bool,
                               client_compatible: bool) -> None:
        """Exercise connection shutdown for every combination of compatibility modes.

        The server performs ``unwrap()`` only when ``server_compatible`` is set
        and suppresses ragged EOFs otherwise. Expectations encoded below:

        * client standard-compatible, server not: the client side raises
          ``BrokenResourceError``;
        * server standard-compatible, client not: the server thread records an
          ``OSError`` that is not a ``socket.timeout``;
        * any other combination: the server thread records no exception.
        """
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
                if server_compatible:
                    conn.unwrap()
            except BaseException as exc:
                # Hand the failure over to the test body for inspection
                server_exc = exc
            finally:
                conn.close()

        # An empty ExitStack serves as a no-op context manager
        client_cm: ContextManager = ExitStack()
        if client_compatible and not server_compatible:
            client_cm = pytest.raises(BrokenResourceError)

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=not server_compatible)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=client_compatible)
                with client_cm:
                    assert await wrapper.receive() == b'hello'
                    await wrapper.aclose()
        finally:
            server_thread.join()
            server_sock.close()

        if not client_compatible and server_compatible:
            assert isinstance(server_exc, OSError)
            assert not isinstance(server_exc, socket.timeout)
        else:
            assert server_exc is None

    async def test_ragged_eof_on_receive(self, server_context: ssl.SSLContext,
                                         client_context: ssl.SSLContext) -> None:
        """Expect ``EndOfStream`` when the server closes without a TLS shutdown.

        The client is wrapped with ``standard_compatible=False`` and the server
        closes the connection right after sending its payload, without calling
        ``unwrap()``.
        """
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
            except BaseException as exc:
                server_exc = exc
            finally:
                conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=True)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                assert await wrapper.receive() == b'hello'
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_exc is None

    async def test_receive_send_after_eof(self, server_context: ssl.SSLContext,
                                          client_context: ssl.SSLContext) -> None:
        """Expect ``EndOfStream`` from ``receive()`` after the server's ``unwrap()``."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.sendall(b'hello')
            conn.unwrap()
            conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            stream = await connect_tcp(*server_sock.getsockname())
            async with await TLSStream.wrap(stream, hostname='localhost',
                                            ssl_context=client_context) as wrapper:
                assert await wrapper.receive() == b'hello'
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

    @pytest.mark.parametrize('force_tlsv12', [
        pytest.param(False, marks=[pytest.mark.skipif(not getattr(ssl, 'HAS_TLSv1_3', False),
                                                      reason='No TLS 1.3 support')]),
        pytest.param(True)
    ], ids=['tlsv13', 'tlsv12'])
    async def test_send_eof_not_implemented(self, server_context: ssl.SSLContext, ca: CA,
                                            force_tlsv12: bool) -> None:
        """Expect ``send_eof()`` to raise ``NotImplementedError``.

        The expected error message depends on whether the client context has
        been capped at TLS 1.2 (``force_tlsv12``) or left at its defaults. A
        fresh client context is built here instead of using a fixture, so that
        its protocol version limits can be changed.
        """
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.sendall(b'hello')
            conn.unwrap()
            conn.close()

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        ca.configure_trust(client_context)
        if force_tlsv12:
            expected_pattern = r'send_eof\(\) requires at least TLSv1.3'
            if hasattr(ssl, 'TLSVersion'):
                client_context.maximum_version = ssl.TLSVersion.TLSv1_2
            else:  # Python 3.6
                client_context.options |= ssl.OP_NO_TLSv1_3
        else:
            expected_pattern = r'send_eof\(\) has not yet been implemented for TLS streams'

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            stream = await connect_tcp(*server_sock.getsockname())
            async with await TLSStream.wrap(stream, hostname='localhost',
                                            ssl_context=client_context) as wrapper:
                assert await wrapper.receive() == b'hello'
                with pytest.raises(NotImplementedError) as exc:
                    await wrapper.send_eof()

                exc.match(expected_pattern)
        finally:
            server_thread.join()
            server_sock.close()


class TestTLSListener:
    """Server-side ``TLSListener`` tests."""

    async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None:
        """Check that a failed handshake is routed to ``handle_handshake_error()``.

        A plain TCP socket connects to the listener and is closed right away
        without any TLS traffic. The overridden hook must then be called with a
        ``BrokenResourceError`` and a ``SocketStream``, and the connection
        handler must never run.
        """
        def handler(stream: object) -> NoReturn:
            pytest.fail('This function should never be called in this scenario')

        exception = None

        class CustomTLSListener(TLSListener):
            @staticmethod
            async def handle_handshake_error(exc: BaseException,
                                             stream: AnyByteStream) -> None:
                nonlocal exception
                await TLSListener.handle_handshake_error(exc, stream)
                assert isinstance(stream, SocketStream)
                exception = exc
                # Wake up the test body, which is waiting on this event
                event.set()

        event = Event()
        listener = await create_tcp_listener(local_host='127.0.0.1')
        tls_listener = CustomTLSListener(listener, server_context)
        async with tls_listener, create_task_group() as tg:
            tg.start_soon(tls_listener.serve, handler)

            # Connect and disconnect immediately; the socket is closed even if
            # connect() itself raises.
            with socket.socket() as sock:
                sock.connect(listener.extra(SocketAttribute.local_address))

            await event.wait()
            tg.cancel_scope.cancel()

        assert isinstance(exception, BrokenResourceError)