# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Test cases for L{twisted.protocols.haproxy.HAProxyProtocol}.
"""

from typing import Optional
from unittest import mock

from twisted.internet import address
from twisted.internet.protocol import Factory, Protocol
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.trial import unittest

from .._wrapper import HAProxyWrappingFactory


class StaticProtocol(Protocol):
    """
    Protocol stand-in that maintains test state.

    Every call to L{dataReceived} records the peer and host addresses
    reported by the transport at that moment and appends the delivered bytes
    to C{data}, so tests can inspect what the wrapped protocol observed.
    """

    def __init__(self) -> None:
        # Addresses reported by the transport on the most recent delivery.
        self.source: Optional[address.IAddress] = None
        self.destination: Optional[address.IAddress] = None
        # Every byte delivered to this protocol so far.
        self.data = b""
        self.disconnected = False

    def dataReceived(self, data: bytes) -> None:
        """
        Snapshot the transport's addresses and accumulate C{data}.

        @param data: The bytes delivered by the wrapping protocol.
        """
        assert self.transport
        self.source = self.transport.getPeer()
        self.destination = self.transport.getHost()
        self.data += data


class HAProxyWrappingFactoryV1Tests(unittest.TestCase):
    """
    Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v1 PROXY
    headers.
    """

    def test_invalidHeaderDisconnects(self) -> None:
        """
        A first line that does not start with C{PROXY} results in the
        transport being disconnected.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv4Address("TCP", "127.1.1.1", 8080),
        )
        transport = StringTransportWithDisconnection()
        transport.protocol = proto
        proto.makeConnection(transport)
        proto.dataReceived(b"NOTPROXY anything can go here\r\n")
        self.assertFalse(transport.connected)

    def test_invalidPartialHeaderDisconnects(self) -> None:
        """
        A v1 header whose first line is terminated before all of its fields
        are present results in the transport being disconnected once both
        chunks have been delivered.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv4Address("TCP", "127.1.1.1", 8080),
        )
        transport = StringTransportWithDisconnection()
        transport.protocol = proto
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP4 1.1.1.1\r\n")
        proto.dataReceived(b"2.2.2.2 8080\r\n")
        self.assertFalse(transport.connected)

    def test_preDataReceived_getPeerHost(self) -> None:
        """
        Before any data is received the HAProxy protocol will return the same
        peer and host as the IP connection.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv4Address("TCP", "127.0.0.1", 8080),
        )
        # Sentinels make it unambiguous that the values come straight from
        # the underlying transport.
        transport = StringTransportWithDisconnection(
            hostAddress=mock.sentinel.host_address,
            peerAddress=mock.sentinel.peer_address,
        )
        proto.makeConnection(transport)
        self.assertEqual(proto.getHost(), mock.sentinel.host_address)
        self.assertEqual(proto.getPeer(), mock.sentinel.peer_address)

    def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
        """
        A valid C{TCP4} header makes both the wrapper and the transport seen
        by the wrapped protocol report the source as the peer and the
        destination as the host.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv4Address("TCP", "127.0.0.1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP4 1.1.1.1 2.2.2.2 8080 8888\r\n")
        self.assertEqual(proto.getPeer().host, "1.1.1.1")
        self.assertEqual(proto.getPeer().port, 8080)
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().host,
            "1.1.1.1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().port,
            8080,
        )
        self.assertEqual(proto.getHost().host, "2.2.2.2")
        self.assertEqual(proto.getHost().port, 8888)
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().host,
            "2.2.2.2",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().port,
            8888,
        )

    def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
        """
        A valid C{TCP6} header makes both the wrapper and the transport seen
        by the wrapped protocol report the source as the peer and the
        destination as the host.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\n")
        self.assertEqual(proto.getPeer().host, "::1")
        self.assertEqual(proto.getPeer().port, 8080)
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().host,
            "::1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().port,
            8080,
        )
        self.assertEqual(proto.getHost().host, "::2")
        self.assertEqual(proto.getHost().port, 8888)
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().host,
            "::2",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().port,
            8888,
        )

    def test_overflowBytesSentToWrappedProtocol(self) -> None:
        """
        Bytes that follow the header in the same delivery are passed to the
        wrapped protocol without the header.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP6 ::1 ::2 8080 8888\r\nHTTP/1.1 / GET")
        self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")

    def test_overflowBytesSentToWrappedProtocolChunks(self) -> None:
        """
        When the header arrives split across two deliveries, the bytes that
        follow it are still passed to the wrapped protocol.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
        proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
        self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")

    def test_overflowBytesSentToWrappedProtocolAfter(self) -> None:
        """
        Once the header has been parsed, every later delivery is passed to
        the wrapped protocol.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(b"PROXY TCP6 ::1 ::2 ")
        proto.dataReceived(b"8080 8888\r\nHTTP/1.1 / GET")
        proto.dataReceived(b"\r\n\r\n")
        self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET\r\n\r\n")


class HAProxyWrappingFactoryV2Tests(unittest.TestCase):
    """
    Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v2 PROXY
    headers.
    """

    IPV4HEADER = (
        # V2 Signature
        b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
        # V2 PROXY command
        b"\x21"
        # AF_INET/STREAM
        b"\x11"
        # 12 bytes for 2 IPv4 addresses and two ports
        b"\x00\x0C"
        # 127.0.0.1 for source and destination
        b"\x7F\x00\x00\x01\x7F\x00\x00\x01"
        # 8080 for source 8888 for destination
        b"\x1F\x90\x22\xB8"
    )
    IPV6HEADER = (
        # V2 Signature
        b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
        # V2 PROXY command
        b"\x21"
        # AF_INET6/STREAM
        b"\x21"
        # 36 bytes for 2 IPv6 addresses (16 bytes each) and two ports
        b"\x00\x24"
        # ::1 for source and destination
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
        # 8080 for source 8888 for destination
        b"\x1F\x90\x22\xB8"
    )

    # /home/tests/mysockets/sock (26 bytes) null padded to 108 bytes.
    _SOCK_PATH = (
        b"\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F\x6D\x79\x73\x6F"
        b"\x63\x6B\x65\x74\x73\x2F\x73\x6F\x63\x6B" + (b"\x00" * 82)
    )
    UNIXHEADER = (
        (
            # V2 Signature
            b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
            # V2 PROXY command
            b"\x21"
            # AF_UNIX/STREAM
            b"\x31"
            # 216 bytes for 2 null padded paths of 108 bytes each
            b"\x00\xD8"
            # /home/tests/mysockets/sock for source and destination paths
        )
        + _SOCK_PATH
        + _SOCK_PATH
    )

    def test_invalidHeaderDisconnects(self) -> None:
        """
        A header whose first signature byte has been corrupted results in
        the transport being disconnected.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        transport.protocol = proto
        proto.makeConnection(transport)
        # Replace only the first byte of the signature; the rest of the
        # header is left intact.
        proto.dataReceived(b"\x00" + self.IPV4HEADER[1:])
        self.assertFalse(transport.connected)

    def test_validIPv4HeaderResolves_getPeerHost(self) -> None:
        """
        A valid C{AF_INET} header makes both the wrapper and the transport
        seen by the wrapped protocol report the source as the peer and the
        destination as the host.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv4Address("TCP", "127.0.0.1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(self.IPV4HEADER)
        self.assertEqual(proto.getPeer().host, "127.0.0.1")
        self.assertEqual(proto.getPeer().port, 8080)
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().host,
            "127.0.0.1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().port,
            8080,
        )
        self.assertEqual(proto.getHost().host, "127.0.0.1")
        self.assertEqual(proto.getHost().port, 8888)
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().host,
            "127.0.0.1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().port,
            8888,
        )

    def test_validIPv6HeaderResolves_getPeerHost(self) -> None:
        """
        A valid C{AF_INET6} header makes both the wrapper and the transport
        seen by the wrapped protocol report the source as the peer and the
        destination as the host.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        # "::1" is an IPv6 literal, so it is described with IPv6Address like
        # the other IPv6 tests in this module.
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(self.IPV6HEADER)
        self.assertEqual(proto.getPeer().host, "0:0:0:0:0:0:0:1")
        self.assertEqual(proto.getPeer().port, 8080)
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().host,
            "0:0:0:0:0:0:0:1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().port,
            8080,
        )
        self.assertEqual(proto.getHost().host, "0:0:0:0:0:0:0:1")
        self.assertEqual(proto.getHost().port, 8888)
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().host,
            "0:0:0:0:0:0:0:1",
        )
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().port,
            8888,
        )

    def test_validUNIXHeaderResolves_getPeerHost(self) -> None:
        """
        A valid C{AF_UNIX} header makes both the wrapper and the transport
        seen by the wrapped protocol report the socket path from the header,
        without its null padding, as the peer and host name.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.UNIXAddress(b"/home/test/sockets/server.sock"),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(self.UNIXHEADER)
        self.assertEqual(proto.getPeer().name, b"/home/tests/mysockets/sock")
        self.assertEqual(
            proto.wrappedProtocol.transport.getPeer().name,
            b"/home/tests/mysockets/sock",
        )
        self.assertEqual(proto.getHost().name, b"/home/tests/mysockets/sock")
        self.assertEqual(
            proto.wrappedProtocol.transport.getHost().name,
            b"/home/tests/mysockets/sock",
        )

    def test_overflowBytesSentToWrappedProtocol(self) -> None:
        """
        Bytes that follow the header in the same delivery are passed to the
        wrapped protocol without the header.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        proto.dataReceived(self.IPV6HEADER + b"HTTP/1.1 / GET")
        self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")

    def test_overflowBytesSentToWrappedProtocolChunks(self) -> None:
        """
        When the header arrives split across two deliveries, the bytes that
        follow it are still passed to the wrapped protocol.
        """
        factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
        proto = factory.buildProtocol(
            address.IPv6Address("TCP", "::1", 8080),
        )
        transport = StringTransportWithDisconnection()
        proto.makeConnection(transport)
        # Split inside the address block: the first 18 bytes cover the
        # 16 byte fixed prefix plus 2 bytes of the source address.
        proto.dataReceived(self.IPV6HEADER[:18])
        proto.dataReceived(self.IPV6HEADER[18:] + b"HTTP/1.1 / GET")
        self.assertEqual(proto.wrappedProtocol.data, b"HTTP/1.1 / GET")