# Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for implementations of L{IReactorTCP} and the TCP parts of L{IReactorSocket}. """ import errno import gc import io import os import socket from functools import wraps from typing import Callable, ClassVar, List, Optional, Sequence, Type from unittest import skipIf from zope.interface import Interface, implementer from zope.interface.verify import verifyClass, verifyObject import attr from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import ( Deferred, DeferredList, fail, gatherResults, maybeDeferred, succeed, ) from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint from twisted.internet.error import ( ConnectBindError, ConnectionAborted, ConnectionClosed, ConnectionDone, ConnectionLost, ConnectionRefusedError, DNSLookupError, NoProtocol, UserError, ) from twisted.internet.interfaces import ( IConnector, IHalfCloseableProtocol, ILoggingContext, IPullProducer, IPushProducer, IReactorFDSet, IReactorSocket, IReactorTCP, IResolverSimple, ITLSTransport, ) from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory from twisted.internet.tcp import ( Connection, Server, _BuffersLogs, _FileDescriptorReservation, _IFileDescriptorReservation, _NullFileDescriptorReservation, _resolveIPv6, ) from twisted.internet.test.connectionmixins import ( BrokenContextFactory, ConnectableProtocol, ConnectionTestsMixin, EndpointCreator, LogObserverMixin, Stop, StreamClientTestsMixin, findFreePort, runProtocolsWithReactor, ) from twisted.internet.test.reactormixins import ( ReactorBuilder, needsRunningReactor, stopOnError, ) from twisted.logger import Logger from twisted.python import log from twisted.python.failure import Failure from twisted.python.runtime import platform from twisted.test.proto_helpers import MemoryReactor, StringTransport from twisted.test.test_tcp import ( ClientStartStopFactory, ClosingFactory, MyClientFactory, MyServerFactory, ) from twisted.trial.unittest import SkipTest, SynchronousTestCase, TestCase try: from OpenSSL import SSL except ImportError: useSSL = False else: from twisted.internet.ssl import ClientContextFactory useSSL = True s = None try: s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) s.bind(("::1", 0)) except OSError as e: ipv6Skip = True ipv6SkipReason = f"IPv6 not available. {e}" else: ipv6Skip = False ipv6SkipReason = "" if s is not None: s.close() if platform.isWindows(): from twisted.internet.test import _win32ifaces getLinkLocalIPv6Addresses = _win32ifaces.win32GetLinkLocalIPv6Addresses SKIP_EMFILE = True else: try: from twisted.internet.test import _posixifaces except ImportError: getLinkLocalIPv6Addresses = lambda: [] else: getLinkLocalIPv6Addresses = _posixifaces.posixGetLinkLocalIPv6Addresses SKIP_EMFILE = False def getLinkLocalIPv6Address(): """ Find and return a configured link local IPv6 address including a scope identifier using the % separation syntax. If the system has no link local IPv6 addresses, raise L{SkipTest} instead. @raise SkipTest: if no link local address can be found or if the C{netifaces} module is not available. @return: a C{str} giving the address """ addresses = getLinkLocalIPv6Addresses() if addresses: return addresses[0] raise SkipTest("Link local IPv6 address unavailable") def connect(client, destination): """ Connect a socket to the given destination. @param client: A C{socket.socket}. @param destination: A tuple of (host, port). The host is a C{str}, the port a C{int}. If the C{host} is an IPv6 IP, the address is resolved using C{getaddrinfo} and the first version found is used. """ (host, port) = destination if "%" in host or ":" in host: address = socket.getaddrinfo(host, port)[0][4] else: address = (host, port) client.connect(address) class FakeSocket: """ A fake for L{socket.socket} objects. @ivar data: A C{str} giving the data which will be returned from L{FakeSocket.recv}. @ivar sendBuffer: A C{list} of the objects passed to L{FakeSocket.send}. """ def __init__(self, data): self.data = data self.sendBuffer = [] def setblocking(self, blocking): self.blocking = blocking def recv(self, size): return self.data def send(self, bytes): """ I{Send} all of C{bytes} by accumulating it into C{self.sendBuffer}. @return: The length of C{bytes}, indicating all the data has been accepted. """ self.sendBuffer.append(bytes) return len(bytes) def shutdown(self, how): """ Shutdown is not implemented. The method is provided since real sockets have it and some code expects it. No behavior of L{FakeSocket} is affected by a call to it. """ def close(self): """ Close is not implemented. The method is provided since real sockets have it and some code expects it. No behavior of L{FakeSocket} is affected by a call to it. """ def setsockopt(self, *args): """ Setsockopt is not implemented. The method is provided since real sockets have it and some code expects it. No behavior of L{FakeSocket} is affected by a call to it. """ def fileno(self): """ Return a fake file descriptor. If actually used, this will have no connection to this L{FakeSocket} and will probably cause surprising results. """ return 1 class FakeSocketTests(TestCase): """ Test that the FakeSocket can be used by the doRead method of L{Connection} """ def test_blocking(self): skt = FakeSocket(b"someData") skt.setblocking(0) self.assertEqual(skt.blocking, 0) def test_recv(self): skt = FakeSocket(b"someData") self.assertEqual(skt.recv(10), b"someData") def test_send(self): """ L{FakeSocket.send} accepts the entire string passed to it, adds it to its send buffer, and returns its length. """ skt = FakeSocket(b"") count = skt.send(b"foo") self.assertEqual(count, 3) self.assertEqual(skt.sendBuffer, [b"foo"]) class FakeProtocol(Protocol): """ An L{IProtocol} that returns a value from its dataReceived method. """ def dataReceived(self, data): """ Return something other than L{None} to trigger a deprecation warning for that behavior. """ return () @implementer(IReactorFDSet) class _FakeFDSetReactor: """ An in-memory implementation of L{IReactorFDSet}, which records the current sets of active L{IReadDescriptor} and L{IWriteDescriptor}s. @ivar _readers: The set of L{IReadDescriptor}s active on this L{_FakeFDSetReactor} @type _readers: L{set} @ivar _writers: The set of L{IWriteDescriptor}s active on this L{_FakeFDSetReactor} @ivar _writers: L{set} """ def __init__(self): self._readers = set() self._writers = set() def addReader(self, reader): self._readers.add(reader) def removeReader(self, reader): if reader in self._readers: self._readers.remove(reader) def addWriter(self, writer): self._writers.add(writer) def removeWriter(self, writer): if writer in self._writers: self._writers.remove(writer) def removeAll(self): result = self.getReaders() + self.getWriters() self.__init__() return result def getReaders(self): return list(self._readers) def getWriters(self): return list(self._writers) verifyClass(IReactorFDSet, _FakeFDSetReactor) class TCPServerTests(TestCase): """ Whitebox tests for L{twisted.internet.tcp.Server}. """ def setUp(self): self.reactor = _FakeFDSetReactor() class FakePort: _realPortNumber = 3 self.skt = FakeSocket(b"") self.protocol = Protocol() self.server = Server( self.skt, self.protocol, ("", 0), FakePort(), None, self.reactor ) def test_writeAfterDisconnect(self): """ L{Server.write} discards bytes passed to it if called after it has lost its connection. """ self.server.connectionLost(Failure(Exception("Simulated lost connection"))) self.server.write(b"hello world") self.assertEqual(self.skt.sendBuffer, []) def test_writeAfterDisconnectAfterTLS(self): """ L{Server.write} discards bytes passed to it if called after it has lost its connection when the connection had started TLS. """ self.server.TLS = True self.test_writeAfterDisconnect() def test_writeSequenceAfterDisconnect(self): """ L{Server.writeSequence} discards bytes passed to it if called after it has lost its connection. """ self.server.connectionLost(Failure(Exception("Simulated lost connection"))) self.server.writeSequence([b"hello world"]) self.assertEqual(self.skt.sendBuffer, []) def test_writeSequenceAfterDisconnectAfterTLS(self): """ L{Server.writeSequence} discards bytes passed to it if called after it has lost its connection when the connection had started TLS. """ self.server.TLS = True self.test_writeSequenceAfterDisconnect() class TCPConnectionTests(TestCase): """ Whitebox tests for L{twisted.internet.tcp.Connection}. """ def test_doReadWarningIsRaised(self): """ When an L{IProtocol} implementation that returns a value from its C{dataReceived} method, a deprecated warning is emitted. """ skt = FakeSocket(b"someData") protocol = FakeProtocol() conn = Connection(skt, protocol) conn.doRead() warnings = self.flushWarnings([FakeProtocol.dataReceived]) self.assertEqual(warnings[0]["category"], DeprecationWarning) self.assertEqual( warnings[0]["message"], "Returning a value other than None from " "twisted.internet.test.test_tcp.FakeProtocol.dataReceived " "is deprecated since Twisted 11.0.0.", ) self.assertEqual(len(warnings), 1) def test_noTLSBeforeStartTLS(self): """ The C{TLS} attribute of a L{Connection} instance is C{False} before L{Connection.startTLS} is called. """ skt = FakeSocket(b"") protocol = FakeProtocol() conn = Connection(skt, protocol) self.assertFalse(conn.TLS) @skipIf(not useSSL, "No SSL support available") def test_tlsAfterStartTLS(self): """ The C{TLS} attribute of a L{Connection} instance is C{True} after L{Connection.startTLS} is called. """ skt = FakeSocket(b"") protocol = FakeProtocol() conn = Connection(skt, protocol, reactor=_FakeFDSetReactor()) conn._tlsClientDefault = True conn.startTLS(ClientContextFactory(), True) self.assertTrue(conn.TLS) class TCPCreator(EndpointCreator): """ Create IPv4 TCP endpoints for L{runProtocolsWithReactor}-based tests. """ interface = "127.0.0.1" def server(self, reactor): """ Create a server-side TCP endpoint. """ return TCP4ServerEndpoint(reactor, 0, interface=self.interface) def client(self, reactor, serverAddress): """ Create a client end point that will connect to the given address. @type serverAddress: L{IPv4Address} """ return TCP4ClientEndpoint(reactor, self.interface, serverAddress.port) class TCP6Creator(TCPCreator): """ Create IPv6 TCP endpoints for C{ReactorBuilder.runProtocolsWithReactor}-based tests. The endpoint types in question here are still the TCP4 variety, since these simply pass through IPv6 address literals to the reactor, and we are only testing address literals, not name resolution (as name resolution has not yet been implemented). See http://twistedmatrix.com/trac/ticket/4470 for more specific information about new endpoint classes. The naming is slightly misleading, but presumably if you're passing an IPv6 literal, you know what you're asking for. """ def __init__(self): self.interface = getLinkLocalIPv6Address() @implementer(IResolverSimple) class FakeResolver: """ A resolver implementation based on a C{dict} mapping names to addresses. """ def __init__(self, names): self.names = names def getHostByName(self, name, timeout): """ Return the address mapped to C{name} if it exists, or raise a C{DNSLookupError}. @param name: The name to resolve. @param timeout: The lookup timeout, ignore here. """ try: return succeed(self.names[name]) except KeyError: return fail(DNSLookupError("FakeResolver couldn't find " + name)) class TCPClientTestsBase(ReactorBuilder, ConnectionTestsMixin, StreamClientTestsMixin): """ Base class for builders defining tests related to L{IReactorTCP.connectTCP}. Classes which uses this in must provide all of the documented instance variables in order to specify how the test works. These are documented as instance variables rather than declared as methods due to some peculiar inheritance ordering concerns, but they are effectively abstract methods. @ivar endpoints: A client/server endpoint creator appropriate to the address family being tested. @type endpoints: L{twisted.internet.test.connectionmixins.EndpointCreator} @ivar interface: An IP address literal to locally bind a socket to as well as to connect to. This can be any valid interface for the local host. @type interface: C{str} @ivar port: An unused local listening port to listen on and connect to. This will be used in conjunction with the C{interface}. (Depending on what they're testing, some tests will locate their own port with L{findFreePort} instead.) @type port: C{int} @ivar family: an address family constant, such as L{socket.AF_INET}, L{socket.AF_INET6}, or L{socket.AF_UNIX}, which indicates the address family of the transport type under test. @type family: C{int} @ivar addressClass: the L{twisted.internet.interfaces.IAddress} implementor associated with the transport type under test. Must also be a 3-argument callable which produces an instance of same. @type addressClass: C{type} @ivar fakeDomainName: A fake domain name to use, to simulate hostname resolution and to distinguish between hostnames and IP addresses where necessary. @type fakeDomainName: C{str} """ requiredInterfaces = (IReactorTCP,) _port = None @property def port(self): """ Return the port number to connect to, using C{self._port} set up by C{listen} if available. @return: The port number to connect to. @rtype: C{int} """ if self._port is not None: return self._port.getHost().port return findFreePort(self.interface, self.family)[1] @property def interface(self): """ Return the interface attribute from the endpoints object. """ return self.endpoints.interface def listen(self, reactor, factory): """ Start a TCP server with the given C{factory}. @param reactor: The reactor to create the TCP port in. @param factory: The server factory. @return: A TCP port instance. """ self._port = reactor.listenTCP(0, factory, interface=self.interface) return self._port def connect(self, reactor, factory): """ Start a TCP client with the given C{factory}. @param reactor: The reactor to create the connection in. @param factory: The client factory. @return: A TCP connector instance. """ return reactor.connectTCP(self.interface, self.port, factory) def test_buildProtocolReturnsNone(self): """ When the factory's C{buildProtocol} returns L{None} the connection is gracefully closed. """ connectionLost = Deferred() reactor = self.buildReactor() serverFactory = MyServerFactory() serverFactory.protocolConnectionLost = connectionLost # Make sure the test ends quickly. stopOnError(self, reactor) class NoneFactory(ServerFactory): def buildProtocol(self, address): return None listening = self.endpoints.server(reactor).listen(serverFactory) def listened(port): clientFactory = NoneFactory() endpoint = self.endpoints.client(reactor, port.getHost()) return endpoint.connect(clientFactory) connecting = listening.addCallback(listened) def connectSucceeded(protocol): self.fail( "Stream client endpoint connect succeeded with %r, " "should have failed with NoProtocol." % (protocol,) ) def connectFailed(reason): reason.trap(NoProtocol) connecting.addCallbacks(connectSucceeded, connectFailed) def connected(ignored): # Now that the connection attempt has failed continue waiting for # the server-side connection to be lost. This is the behavior this # test is primarily concerned with. return connectionLost disconnecting = connecting.addCallback(connected) # Make sure any errors that happen in that process get logged quickly. disconnecting.addErrback(log.err) def disconnected(ignored): # The Deferred has to succeed at this point (because log.err always # returns None). If an error got logged it will fail the test. # Stop the reactor now so the test can complete one way or the # other now. reactor.stop() disconnecting.addCallback(disconnected) self.runReactor(reactor) def test_addresses(self): """ A client's transport's C{getHost} and C{getPeer} return L{IPv4Address} instances which have the dotted-quad string form of the resolved address of the local and remote endpoints of the connection respectively as their C{host} attribute, not the hostname originally passed in to L{connectTCP}, if a hostname was used. """ host, ignored = findFreePort(self.interface, self.family)[:2] reactor = self.buildReactor() fakeDomain = self.fakeDomainName reactor.installResolver(FakeResolver({fakeDomain: self.interface})) server = reactor.listenTCP( 0, ServerFactory.forProtocol(Protocol), interface=host ) serverAddress = server.getHost() transportData = {"host": None, "peer": None, "instance": None} class CheckAddress(Protocol): def makeConnection(self, transport): transportData["host"] = transport.getHost() transportData["peer"] = transport.getPeer() transportData["instance"] = transport reactor.stop() clientFactory = Stop(reactor) clientFactory.protocol = CheckAddress def connectMe(): while True: port = findFreePort(self.interface, self.family) bindAddress = (self.interface, port[1]) log.msg(f"Connect attempt with bindAddress {bindAddress}") try: reactor.connectTCP( fakeDomain, server.getHost().port, clientFactory, bindAddress=bindAddress, ) except ConnectBindError: continue else: clientFactory.boundPort = port[1] break needsRunningReactor(reactor, connectMe) self.runReactor(reactor) if clientFactory.failReason: self.fail(clientFactory.failReason.getTraceback()) transportRepr = "<{} to {} at {:x}>".format( transportData["instance"].__class__, transportData["instance"].addr, id(transportData["instance"]), ) boundPort = [host] + list( socket.getaddrinfo(self.interface, clientFactory.boundPort)[0][-1][1:] ) serverPort = [host] + list( socket.getaddrinfo(self.interface, serverAddress.port)[0][-1][1:] ) self.assertEqual(transportData["host"], self.addressClass("TCP", *boundPort)) self.assertEqual(transportData["peer"], self.addressClass("TCP", *serverPort)) self.assertEqual(repr(transportData["instance"]), transportRepr) def test_badContext(self): """ If the context factory passed to L{ITCPTransport.startTLS} raises an exception from its C{getContext} method, that exception is raised by L{ITCPTransport.startTLS}. """ reactor = self.buildReactor() brokenFactory = BrokenContextFactory() results = [] serverFactory = ServerFactory.forProtocol(Protocol) port = reactor.listenTCP(0, serverFactory, interface=self.interface) endpoint = self.endpoints.client(reactor, port.getHost()) clientFactory = ClientFactory() clientFactory.protocol = Protocol connectDeferred = endpoint.connect(clientFactory) def connected(protocol): if not ITLSTransport.providedBy(protocol.transport): results.append("skip") else: results.append( self.assertRaises( ValueError, protocol.transport.startTLS, brokenFactory ) ) def connectFailed(failure): results.append(failure) def whenRun(): connectDeferred.addCallback(connected) connectDeferred.addErrback(connectFailed) connectDeferred.addBoth(lambda ign: reactor.stop()) needsRunningReactor(reactor, whenRun) self.runReactor(reactor) self.assertEqual(len(results), 1, f"more than one callback result: {results}") if isinstance(results[0], Failure): # self.fail(Failure) results[0].raiseException() if results[0] == "skip": raise SkipTest("Reactor does not support ITLSTransport") self.assertEqual(BrokenContextFactory.message, str(results[0])) class TCP4ClientTestsBuilder(TCPClientTestsBase): """ Builder configured with IPv4 parameters for tests related to L{IReactorTCP.connectTCP}. """ fakeDomainName = "some-fake.domain.example.com" family = socket.AF_INET addressClass = IPv4Address endpoints = TCPCreator() @skipIf(ipv6Skip, ipv6SkipReason) class TCP6ClientTestsBuilder(TCPClientTestsBase): """ Builder configured with IPv6 parameters for tests related to L{IReactorTCP.connectTCP}. """ family = socket.AF_INET6 addressClass = IPv6Address def setUp(self): # Only create this object here, so that it won't be created if tests # are being skipped: self.endpoints = TCP6Creator() # This is used by test_addresses to test the distinction between the # resolved name and the name on the socket itself. All the same # invariants should hold, but giving back an IPv6 address from a # resolver is not something the reactor can handle, so instead, we make # it so that the connect call for the IPv6 address test simply uses an # address literal. self.fakeDomainName = self.endpoints.interface class TCPConnectorTestsBuilder(ReactorBuilder): """ Tests for the L{IConnector} provider returned by L{IReactorTCP.connectTCP}. """ requiredInterfaces = (IReactorTCP,) def test_connectorIdentity(self): """ L{IReactorTCP.connectTCP} returns an object which provides L{IConnector}. The destination of the connector is the address which was passed to C{connectTCP}. The same connector object is passed to the factory's C{startedConnecting} method as to the factory's C{clientConnectionLost} method. """ serverFactory = ClosingFactory() reactor = self.buildReactor() tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface) serverFactory.port = tcpPort portNumber = tcpPort.getHost().port seenConnectors = [] seenFailures = [] clientFactory = ClientStartStopFactory() clientFactory.clientConnectionLost = lambda connector, reason: ( seenConnectors.append(connector), seenFailures.append(reason), ) clientFactory.startedConnecting = seenConnectors.append connector = reactor.connectTCP(self.interface, portNumber, clientFactory) self.assertTrue(IConnector.providedBy(connector)) dest = connector.getDestination() self.assertEqual(dest.type, "TCP") self.assertEqual(dest.host, self.interface) self.assertEqual(dest.port, portNumber) clientFactory.whenStopped.addBoth(lambda _: reactor.stop()) self.runReactor(reactor) seenFailures[0].trap(ConnectionDone) self.assertEqual(seenConnectors, [connector, connector]) def test_userFail(self): """ Calling L{IConnector.stopConnecting} in C{Factory.startedConnecting} results in C{Factory.clientConnectionFailed} being called with L{error.UserError} as the reason. """ serverFactory = MyServerFactory() reactor = self.buildReactor() tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface) portNumber = tcpPort.getHost().port fatalErrors = [] def startedConnecting(connector): try: connector.stopConnecting() except Exception: fatalErrors.append(Failure()) reactor.stop() clientFactory = ClientStartStopFactory() clientFactory.startedConnecting = startedConnecting clientFactory.whenStopped.addBoth(lambda _: reactor.stop()) reactor.callWhenRunning( lambda: reactor.connectTCP(self.interface, portNumber, clientFactory) ) self.runReactor(reactor) if fatalErrors: self.fail(fatalErrors[0].getTraceback()) clientFactory.reason.trap(UserError) self.assertEqual(clientFactory.failed, 1) def test_reconnect(self): """ Calling L{IConnector.connect} in C{Factory.clientConnectionLost} causes a new connection attempt to be made. """ serverFactory = ClosingFactory() reactor = self.buildReactor() tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface) serverFactory.port = tcpPort portNumber = tcpPort.getHost().port clientFactory = MyClientFactory() def clientConnectionLost(connector, reason): connector.connect() clientFactory.clientConnectionLost = clientConnectionLost reactor.connectTCP(self.interface, portNumber, clientFactory) protocolMadeAndClosed = [] def reconnectFailed(ignored): p = clientFactory.protocol protocolMadeAndClosed.append((p.made, p.closed)) reactor.stop() clientFactory.failDeferred.addCallback(reconnectFailed) self.runReactor(reactor) clientFactory.reason.trap(ConnectionRefusedError) self.assertEqual(protocolMadeAndClosed, [(1, 1)]) class TCP4ConnectorTestsBuilder(TCPConnectorTestsBuilder): interface = "127.0.0.1" family = socket.AF_INET addressClass = IPv4Address @skipIf(ipv6Skip, ipv6SkipReason) class TCP6ConnectorTestsBuilder(TCPConnectorTestsBuilder): family = socket.AF_INET6 addressClass = IPv6Address def setUp(self): self.interface = getLinkLocalIPv6Address() def createTestSocket(test, addressFamily, socketType): """ Create a socket for the duration of the given test. @param test: the test to add cleanup to. @param addressFamily: an C{AF_*} constant @param socketType: a C{SOCK_*} constant. @return: a socket object. """ skt = socket.socket(addressFamily, socketType) test.addCleanup(skt.close) return skt class _IExhaustsFileDescriptors(Interface): """ A way to trigger C{EMFILE}. """ def exhaust(): """ Open file descriptors until C{EMFILE} is reached. This can raise any exception except an L{OSError} whose C{errno} is C{EMFILE}. Any exception raised to the caller implies L{release}. """ def release(): """ Release all file descriptors opened by L{exhaust}. """ def count(): """ Return the number of opened file descriptors. @return: The number of opened file descriptors; this will be zero if this instance has not opened any. @rtype: L{int} """ @implementer(_IExhaustsFileDescriptors) @attr.s(auto_attribs=True) class _ExhaustsFileDescriptors: """ A class that triggers C{EMFILE} by creating as many file descriptors as necessary. @ivar fileDescriptorFactory: A factory that creates a new file descriptor. @type fileDescriptorFactory: A L{callable} that accepts no arguments and returns an integral file descriptor, suitable for passing to L{os.close}. """ _log: ClassVar[Logger] = Logger() _fileDescriptorFactory: Callable[[], int] = attr.ib( default=lambda: os.dup(0), repr=False ) _close: Callable[[int], None] = attr.ib(default=os.close, repr=False) _fileDescriptors: List[int] = attr.ib( default=attr.Factory(list), init=False, repr=False ) def exhaust(self): """ Open file descriptors until C{EMFILE} is reached. """ # Force a collection to close dangling files. gc.collect() try: while True: try: fd = self._fileDescriptorFactory() except OSError as e: if e.errno == errno.EMFILE: break raise else: self._fileDescriptors.append(fd) except Exception: self.release() raise else: self._log.info( "EMFILE reached by opening" " {openedFileDescriptors} file descriptors.", openedFileDescriptors=self.count(), ) def release(self): """ Release all file descriptors opened by L{exhaust}. """ while self._fileDescriptors: fd = self._fileDescriptors.pop() try: self._close(fd) except OSError as e: if e.errno == errno.EBADF: continue raise def count(self): """ Return the number of opened file descriptors. @return: The number of opened file descriptors; this will be zero if this instance has not opened any. @rtype: L{int} """ return len(self._fileDescriptors) @skipIf(SKIP_EMFILE, "Reserved EMFILE file descriptor not supported on Windows.") class ExhaustsFileDescriptorsTests(SynchronousTestCase): """ Tests for L{_ExhaustsFileDescriptors}. """ def setUp(self): self.exhauster = _ExhaustsFileDescriptors() # This assumes release succeeds when there are no file # descriptors to close. self.addCleanup(self.exhauster.release) def openAFile(self): """ Attempt to open a file; if successful, the file is immediately closed. """ open(os.devnull).close() def test_providesInterface(self): """ L{_ExhaustsFileDescriptors} instances provide L{_IExhaustsFileDescriptors}. """ verifyObject(_IExhaustsFileDescriptors, self.exhauster) def test_count(self): """ L{_ExhaustsFileDescriptors.count} returns the number of open file descriptors. """ self.assertEqual(self.exhauster.count(), 0) self.exhauster.exhaust() self.assertGreater(self.exhauster.count(), 0) self.exhauster.release() self.assertEqual(self.exhauster.count(), 0) def test_exhaustTriggersEMFILE(self): """ L{_ExhaustsFileDescriptors.exhaust} causes the process to exhaust its available file descriptors. """ self.addCleanup(self.exhauster.release) self.exhauster.exhaust() exception = self.assertRaises(IOError, self.openAFile) self.assertEqual(exception.errno, errno.EMFILE) def test_exhaustRaisesOSError(self): """ An L{OSError} raised within L{_ExhaustsFileDescriptors.exhaust} with an C{errno} other than C{EMFILE} is reraised to the caller. """ def raiseOSError(): raise OSError(errno.EMFILE + 1, "Not EMFILE") exhauster = _ExhaustsFileDescriptors(raiseOSError) self.assertRaises(OSError, exhauster.exhaust) def test_release(self): """ L{_ExhaustsFileDescriptors.release} releases all opened file descriptors. """ self.exhauster.exhaust() self.exhauster.release() # Does not fail with EMFILE self.openAFile() def test_fileDescriptorsReleasedOnFailure(self): """ L{_ExhaustsFileDescriptors.exhaust} closes any opened file descriptors if an exception occurs during its exhaustion loop. """ fileDescriptors = [] def failsAfterThree(): if len(fileDescriptors) == 3: raise ValueError( "test_fileDescriptorsReleasedOnFailure" " fake open exception" ) else: fd = os.dup(0) fileDescriptors.append(fd) return fd exhauster = _ExhaustsFileDescriptors(failsAfterThree) self.addCleanup(exhauster.release) self.assertRaises(ValueError, exhauster.exhaust) self.assertEqual(len(fileDescriptors), 3) self.assertEqual(exhauster.count(), 0) for fd in fileDescriptors: exception = self.assertRaises(OSError, os.fstat, fd) self.assertEqual(exception.errno, errno.EBADF) def test_releaseIgnoresEBADF(self): """ L{_ExhaustsFileDescriptors.release} continues to close opened file descriptors even when closing one fails with C{EBADF}. """ fileDescriptors = [] def recordFileDescriptors(): fd = os.dup(0) fileDescriptors.append(fd) return fd exhauster = _ExhaustsFileDescriptors(recordFileDescriptors) self.addCleanup(exhauster.release) exhauster.exhaust() self.assertGreater(exhauster.count(), 0) os.close(fileDescriptors[0]) exhauster.release() self.assertEqual(exhauster.count(), 0) def test_releaseRaisesOSError(self): """ An L{OSError} raised within L{_ExhaustsFileDescriptors.release} with an C{errno} other than C{EBADF} is reraised to the caller. """ fakeFileDescriptors = [] def opensThree(): if len(fakeFileDescriptors) == 3: raise OSError(errno.EMFILE, "Too many files") fakeFileDescriptors.append(-1) return fakeFileDescriptors[-1] def failingClose(fd): raise OSError(11, "test_releaseRaisesOSError fake OSError") exhauster = _ExhaustsFileDescriptors(opensThree, close=failingClose) self.assertEqual(exhauster.count(), 0) exhauster.exhaust() self.assertGreater(exhauster.count(), 0) self.assertRaises(OSError, exhauster.release) def assertPeerClosedOnEMFILE( testCase, exhauster, reactor, runReactor, listen, connect, ): """ Assert that an L{IListeningPort} immediately closes an accepted peer socket when the number of open file descriptors exceeds the soft resource limit. @param testCase: The test case under which to run this assertion. @type testCase: L{trial.unittest.SynchronousTestCase} @param exhauster: The file descriptor exhauster. @type exhauster: L{_ExhaustsFileDescriptors} @param reactor: The reactor under test. @param runReactor: A callable that will synchronously run the provided reactor. @param listen: A callback to bind to a port. @type listen: A L{callable} that accepts two arguments: the provided C{reactor}; and a L{ServerFactory}. It must return an L{IListeningPort} provider. @param connect: A callback to connect a client to the listening port. @type connect: A L{callable} that accepts three arguments: the provided C{reactor}; the address returned by L{IListeningPort.getHost}; and a L{ClientFactory}. Its return value is ignored. """ testCase.addCleanup(exhauster.release) serverFactory = MyServerFactory() serverConnectionMade = Deferred() serverFactory.protocolConnectionMade = serverConnectionMade serverConnectionCompleted = [False] def stopReactorIfServerAccepted(_): reactor.stop() serverConnectionCompleted[0] = True serverConnectionMade.addCallback(stopReactorIfServerAccepted) port = listen(reactor, serverFactory) listeningHost = port.getHost() clientFactory = MyClientFactory() connect(reactor, listeningHost, clientFactory) reactor.callWhenRunning(exhauster.exhaust) def stopReactorAndCloseFileDescriptors(result): exhauster.release() reactor.stop() return result clientFactory.deferred.addBoth(stopReactorAndCloseFileDescriptors) clientFactory.failDeferred.addBoth(stopReactorAndCloseFileDescriptors) runReactor(reactor) noResult = [] serverConnectionMade.addBoth(noResult.append) testCase.assertFalse(noResult, "Server accepted connection; EMFILE not triggered.") testCase.assertNoResult(clientFactory.failDeferred) testCase.successResultOf(clientFactory.deferred) testCase.assertRaises( ConnectionClosed, clientFactory.lostReason.raiseException, ) @skipIf(SKIP_EMFILE, "Reserved EMFILE file descriptor not supported on Windows.") class AssertPeerClosedOnEMFILETests(SynchronousTestCase): """ Tests for L{assertPeerClosedOnEMFILE}. """ @implementer(_IExhaustsFileDescriptors) class NullExhauster: """ An exhauster that does nothing. """ def exhaust(self): """ See L{_IExhaustsFileDescriptors.exhaust} """ def release(self): """ See L{_IExhaustsFileDescriptors.release} """ def count(self): """ See L{_IExhaustsFileDescriptors.count} """ def setUp(self): self.reactor = MemoryReactor() self.testCase = SynchronousTestCase() def test_nullExhausterProvidesInterface(self): """ L{NullExhauster} instances provide L{_IExhaustsFileDescriptors}. """ verifyObject(_IExhaustsFileDescriptors, self.NullExhauster()) def test_reactorStoppedOnSuccessfulConnection(self): """ If the exhauster fails to trigger C{EMFILE} and a connection reaches the server, the reactor is stopped and the test fails. """ exhauster = self.NullExhauster() serverFactory = [None] def runReactor(reactor): reactor.run() proto = serverFactory[0].buildProtocol( IPv4Address("TCP", "127.0.0.1", 4321) ) proto.makeConnection(StringTransport()) def listen(reactor, factory): port = reactor.listenTCP("127.0.0.1", 1234, factory) factory.doStart() serverFactory[0] = factory return port def connect(reactor, address, factory): reactor.connectTCP("127.0.0.1", 0, factory) exception = self.assertRaises( self.testCase.failureException, assertPeerClosedOnEMFILE, testCase=self.testCase, exhauster=exhauster, reactor=self.reactor, runReactor=runReactor, listen=listen, connect=connect, ) self.assertIn("EMFILE", str(exception)) self.assertFalse(self.reactor.running) class StreamTransportTestsMixin(LogObserverMixin): """ Mixin defining tests which apply to any port/connection based transport. """ def test_startedListeningLogMessage(self): """ When a port starts, a message including a description of the associated factory is logged. """ loggedMessages = self.observe() reactor = self.buildReactor() @implementer(ILoggingContext) class SomeFactory(ServerFactory): def logPrefix(self): return "Crazy Factory" factory = SomeFactory() p = self.getListeningPort(reactor, factory) expectedMessage = self.getExpectedStartListeningLogMessage(p, "Crazy Factory") self.assertEqual((expectedMessage,), loggedMessages[0]["message"]) def test_connectionLostLogMsg(self): """ When a connection is lost, an informative message should be logged (see L{getExpectedConnectionLostLogMsg}): an address identifying the port and the fact that it was closed. """ loggedMessages = [] def logConnectionLostMsg(eventDict): loggedMessages.append(log.textFromEventDict(eventDict)) reactor = self.buildReactor() p = self.getListeningPort(reactor, ServerFactory()) expectedMessage = self.getExpectedConnectionLostLogMsg(p) log.addObserver(logConnectionLostMsg) def stopReactor(ignored): log.removeObserver(logConnectionLostMsg) reactor.stop() def doStopListening(): log.addObserver(logConnectionLostMsg) maybeDeferred(p.stopListening).addCallback(stopReactor) reactor.callWhenRunning(doStopListening) reactor.run() self.assertIn(expectedMessage, loggedMessages) @skipIf(SKIP_EMFILE, "Reserved EMFILE file descriptor not supported on Windows.") def test_closePeerOnEMFILE(self): """ See L{assertPeerClosedOnEMFILE}. """ assertPeerClosedOnEMFILE( testCase=self, exhauster=_ExhaustsFileDescriptors(), reactor=self.buildReactor(), runReactor=self.runReactor, listen=self.getListeningPort, connect=self.connectToListener, ) class ConnectToTCPListenerMixin: """ Provides L{connectToListener} for TCP transports. @ivar LISTENER_HOST: The host on which the port is expected to be listening. This is specific to avoid compatibility issues with Windows, which cannot connect to the wildcard host. @type LISTENER_HOST: L{str} @see: U{http://twistedmatrix.com/trac/ticket/1472} """ LISTENER_HOST = "127.0.0.1" def connectToListener(self, reactor, address, factory): """ Connect to the given listening TCP port. @param reactor: The reactor under test. @type reactor: L{IReactorTCP} @param address: The listening port's address. Only the C{port} component is used; see L{LISTENER_HOST}. @type address: L{IPv4Address} or L{IPv6Address} @param factory: The client factory. @type factory: L{ClientFactory} @return: The connector """ return reactor.connectTCP(self.LISTENER_HOST, address.port, factory) class ListenTCPMixin(ConnectToTCPListenerMixin): """ Mixin which uses L{IReactorTCP.listenTCP} to hand out listening TCP ports. """ def getListeningPort(self, reactor, factory, port=0, interface=""): """ Get a TCP port from a reactor. """ return reactor.listenTCP(port, factory, interface=interface) class SocketTCPMixin(ConnectToTCPListenerMixin): """ Mixin which uses L{IReactorSocket.adoptStreamPort} to hand out listening TCP ports. """ def getListeningPort(self, reactor, factory, port=0, interface=""): """ Get a TCP port from a reactor, wrapping an already-initialized file descriptor. """ if IReactorSocket.providedBy(reactor): if ":" in interface: domain = socket.AF_INET6 address = socket.getaddrinfo(interface, port)[0][4] else: domain = socket.AF_INET address = (interface, port) portSock = socket.socket(domain) portSock.bind(address) portSock.listen(3) portSock.setblocking(False) try: return reactor.adoptStreamPort( portSock.fileno(), portSock.family, factory ) finally: # The socket should still be open; fileno will raise if it is # not. portSock.fileno() # Now clean it up, because the rest of the test does not need # it. portSock.close() else: raise SkipTest("Reactor does not provide IReactorSocket") class TCPPortTestsMixin: """ Tests for L{IReactorTCP.listenTCP} """ requiredInterfaces: Optional[Sequence[Type[Interface]]] = (IReactorTCP,) def getExpectedStartListeningLogMessage(self, port, factory): """ Get the message expected to be logged when a TCP port starts listening. """ return "%s starting on %d" % (factory, port.getHost().port) def getExpectedConnectionLostLogMsg(self, port): """ Get the expected connection lost message for a TCP port. """ return f"(TCP Port {port.getHost().port} Closed)" def test_portGetHostOnIPv4(self): """ When no interface is passed to L{IReactorTCP.listenTCP}, the returned listening port listens on an IPv4 address. """ reactor = self.buildReactor() port = self.getListeningPort(reactor, ServerFactory()) address = port.getHost() self.assertIsInstance(address, IPv4Address) @skipIf(ipv6Skip, ipv6SkipReason) def test_portGetHostOnIPv6(self): """ When listening on an IPv6 address, L{IListeningPort.getHost} returns an L{IPv6Address} with C{host} and C{port} attributes reflecting the address the port is bound to. """ reactor = self.buildReactor() host, portNumber = findFreePort(family=socket.AF_INET6, interface="::1")[:2] port = self.getListeningPort(reactor, ServerFactory(), portNumber, host) address = port.getHost() self.assertIsInstance(address, IPv6Address) self.assertEqual("::1", address.host) self.assertEqual(portNumber, address.port) @skipIf(ipv6Skip, ipv6SkipReason) def test_portGetHostOnIPv6ScopeID(self): """ When a link-local IPv6 address including a scope identifier is passed as the C{interface} argument to L{IReactorTCP.listenTCP}, the resulting L{IListeningPort} reports its address as an L{IPv6Address} with a host value that includes the scope identifier. """ linkLocal = getLinkLocalIPv6Address() reactor = self.buildReactor() port = self.getListeningPort(reactor, ServerFactory(), 0, linkLocal) address = port.getHost() self.assertIsInstance(address, IPv6Address) self.assertEqual(linkLocal, address.host) def _buildProtocolAddressTest(self, client, interface): """ Connect C{client} to a server listening on C{interface} started with L{IReactorTCP.listenTCP} and return the address passed to the factory's C{buildProtocol} method. @param client: A C{SOCK_STREAM} L{socket.socket} created with an address family such that it will be able to connect to a server listening on C{interface}. @param interface: A C{str} giving an address for a server to listen on. This should almost certainly be the loopback address for some address family supported by L{IReactorTCP.listenTCP}. @return: Whatever object, probably an L{IAddress} provider, is passed to a server factory's C{buildProtocol} method when C{client} establishes a connection. """ class ObserveAddress(ServerFactory): def buildProtocol(self, address): reactor.stop() self.observedAddress = address return Protocol() factory = ObserveAddress() reactor = self.buildReactor() port = self.getListeningPort(reactor, factory, 0, interface) client.setblocking(False) try: connect(client, (port.getHost().host, port.getHost().port)) except OSError as e: self.assertIn(e.errno, (errno.EINPROGRESS, errno.EWOULDBLOCK)) self.runReactor(reactor) return factory.observedAddress def test_buildProtocolIPv4Address(self): """ When a connection is accepted over IPv4, an L{IPv4Address} is passed to the factory's C{buildProtocol} method giving the peer's address. """ interface = "127.0.0.1" client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM) observedAddress = self._buildProtocolAddressTest(client, interface) self.assertEqual(IPv4Address("TCP", *client.getsockname()), observedAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_buildProtocolIPv6Address(self): """ When a connection is accepted to an IPv6 address, an L{IPv6Address} is passed to the factory's C{buildProtocol} method giving the peer's address. """ interface = "::1" client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) observedAddress = self._buildProtocolAddressTest(client, interface) peer = client.getsockname() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, peer[1]), observedAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_buildProtocolIPv6AddressScopeID(self): """ When a connection is accepted to a link-local IPv6 address, an L{IPv6Address} is passed to the factory's C{buildProtocol} method giving the peer's address, including a scope identifier. """ interface = getLinkLocalIPv6Address() client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) observedAddress = self._buildProtocolAddressTest(client, interface) peer = client.getsockname() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, *peer[1:]), observedAddress) def _serverGetConnectionAddressTest(self, client, interface, which): """ Connect C{client} to a server listening on C{interface} started with L{IReactorTCP.listenTCP} and return the address returned by one of the server transport's address lookup methods, C{getHost} or C{getPeer}. @param client: A C{SOCK_STREAM} L{socket.socket} created with an address family such that it will be able to connect to a server listening on C{interface}. @param interface: A C{str} giving an address for a server to listen on. This should almost certainly be the loopback address for some address family supported by L{IReactorTCP.listenTCP}. @param which: A C{str} equal to either C{"getHost"} or C{"getPeer"} determining which address will be returned. @return: Whatever object, probably an L{IAddress} provider, is returned from the method indicated by C{which}. """ class ObserveAddress(Protocol): def makeConnection(self, transport): reactor.stop() self.factory.address = getattr(transport, which)() reactor = self.buildReactor() factory = ServerFactory() factory.protocol = ObserveAddress port = self.getListeningPort(reactor, factory, 0, interface) client.setblocking(False) try: connect(client, (port.getHost().host, port.getHost().port)) except OSError as e: self.assertIn(e.errno, (errno.EINPROGRESS, errno.EWOULDBLOCK)) self.runReactor(reactor) return factory.address def test_serverGetHostOnIPv4(self): """ When a connection is accepted over IPv4, the server L{ITransport.getHost} method returns an L{IPv4Address} giving the address on which the server accepted the connection. """ interface = "127.0.0.1" client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM) hostAddress = self._serverGetConnectionAddressTest(client, interface, "getHost") self.assertEqual(IPv4Address("TCP", *client.getpeername()), hostAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_serverGetHostOnIPv6(self): """ When a connection is accepted over IPv6, the server L{ITransport.getHost} method returns an L{IPv6Address} giving the address on which the server accepted the connection. """ interface = "::1" client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) hostAddress = self._serverGetConnectionAddressTest(client, interface, "getHost") peer = client.getpeername() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, *peer[1:]), hostAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_serverGetHostOnIPv6ScopeID(self): """ When a connection is accepted over IPv6, the server L{ITransport.getHost} method returns an L{IPv6Address} giving the address on which the server accepted the connection, including the scope identifier. """ interface = getLinkLocalIPv6Address() client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) hostAddress = self._serverGetConnectionAddressTest(client, interface, "getHost") peer = client.getpeername() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, *peer[1:]), hostAddress) def test_serverGetPeerOnIPv4(self): """ When a connection is accepted over IPv4, the server L{ITransport.getPeer} method returns an L{IPv4Address} giving the address of the remote end of the connection. """ interface = "127.0.0.1" client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM) peerAddress = self._serverGetConnectionAddressTest(client, interface, "getPeer") self.assertEqual(IPv4Address("TCP", *client.getsockname()), peerAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_serverGetPeerOnIPv6(self): """ When a connection is accepted over IPv6, the server L{ITransport.getPeer} method returns an L{IPv6Address} giving the address on the remote end of the connection. """ interface = "::1" client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) peerAddress = self._serverGetConnectionAddressTest(client, interface, "getPeer") peer = client.getsockname() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, *peer[1:]), peerAddress) @skipIf(ipv6Skip, ipv6SkipReason) def test_serverGetPeerOnIPv6ScopeID(self): """ When a connection is accepted over IPv6, the server L{ITransport.getPeer} method returns an L{IPv6Address} giving the address on the remote end of the connection, including the scope identifier. """ interface = getLinkLocalIPv6Address() client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM) peerAddress = self._serverGetConnectionAddressTest(client, interface, "getPeer") peer = client.getsockname() hostname = socket.getnameinfo(peer, socket.NI_NUMERICHOST)[0] self.assertEqual(IPv6Address("TCP", hostname, *peer[1:]), peerAddress) class TCPPortTestsBuilder( ReactorBuilder, ListenTCPMixin, TCPPortTestsMixin, StreamTransportTestsMixin, ): pass class TCPFDPortTestsBuilder( ReactorBuilder, SocketTCPMixin, TCPPortTestsMixin, StreamTransportTestsMixin, ): pass class StopStartReadingProtocol(Protocol): """ Protocol that pauses and resumes the transport a few times """ def connectionMade(self): self.data = b"" self.pauseResumeProducing(3) def pauseResumeProducing(self, counter): """ Toggle transport read state, then count down. """ self.transport.pauseProducing() self.transport.resumeProducing() if counter: self.factory.reactor.callLater(0, self.pauseResumeProducing, counter - 1) else: self.factory.reactor.callLater(0, self.factory.ready.callback, self) def dataReceived(self, data): log.msg("got data", len(data)) self.data += data if len(self.data) == 4 * 4096: self.factory.stop.callback(self.data) def oneTransportTest(testMethod): """ Decorate a L{ReactorBuilder} test function which tests one reactor and one connected transport. Run that test method in the context of C{connectionMade}, and immediately drop the connection (and end the test) when that completes. @param testMethod: A unit test method on a L{ReactorBuilder} test suite; taking two additional parameters; a C{reactor} as built by the L{ReactorBuilder}, and an L{ITCPTransport} provider. @type testMethod: 3-argument C{function} @return: a no-argument test method. @rtype: 1-argument C{function} """ @wraps(testMethod) def actualTestMethod(builder): other = ConnectableProtocol() class ServerProtocol(ConnectableProtocol): def connectionMade(self): try: testMethod(builder, self.reactor, self.transport) finally: if self.transport is not None: self.transport.loseConnection() if other.transport is not None: other.transport.loseConnection() serverProtocol = ServerProtocol() runProtocolsWithReactor(builder, serverProtocol, other, TCPCreator()) return actualTestMethod def assertReading(testCase, reactor, transport): """ Use the given test to assert that the given transport is actively reading in the given reactor. @note: Maintainers; for more information on why this is a function rather than a method on a test case, see U{this document on how we structure test tools } @param testCase: a test case to perform the assertion upon. @type testCase: L{TestCase} @param reactor: A reactor, possibly one providing L{IReactorFDSet}, or an IOCP reactor. @param transport: An L{ITCPTransport} """ if IReactorFDSet.providedBy(reactor): testCase.assertIn(transport, reactor.getReaders()) else: # IOCP. testCase.assertIn(transport, reactor.handles) testCase.assertTrue(transport.reading) def assertNotReading(testCase, reactor, transport): """ Use the given test to assert that the given transport is I{not} actively reading in the given reactor. @note: Maintainers; for more information on why this is a function rather than a method on a test case, see U{this document on how we structure test tools } @param testCase: a test case to perform the assertion upon. @type testCase: L{TestCase} @param reactor: A reactor, possibly one providing L{IReactorFDSet}, or an IOCP reactor. @param transport: An L{ITCPTransport} """ if IReactorFDSet.providedBy(reactor): testCase.assertNotIn(transport, reactor.getReaders()) else: # IOCP. testCase.assertFalse(transport.reading) class TCPConnectionTestsBuilder(ReactorBuilder): """ Builder defining tests relating to L{twisted.internet.tcp.Connection}. """ requiredInterfaces = (IReactorTCP,) def test_stopStartReading(self): """ This test verifies transport socket read state after multiple pause/resumeProducing calls. """ sf = ServerFactory() reactor = sf.reactor = self.buildReactor() skippedReactors = ["Glib2Reactor", "Gtk2Reactor"] reactorClassName = reactor.__class__.__name__ if reactorClassName in skippedReactors and platform.isWindows(): raise SkipTest("This test is broken on gtk/glib under Windows.") sf.protocol = StopStartReadingProtocol sf.ready = Deferred() sf.stop = Deferred() p = reactor.listenTCP(0, sf) port = p.getHost().port def proceed(protos, port): """ Send several IOCPReactor's buffers' worth of data. """ self.assertTrue(protos[0]) self.assertTrue(protos[1]) protos = protos[0][1], protos[1][1] protos[0].transport.write(b"x" * (2 * 4096) + b"y" * (2 * 4096)) return sf.stop.addCallback(cleanup, protos, port).addCallback( lambda ign: reactor.stop() ) def cleanup(data, protos, port): """ Make sure IOCPReactor didn't start several WSARecv operations that clobbered each other's results. """ self.assertEqual( data, b"x" * (2 * 4096) + b"y" * (2 * 4096), "did not get the right data", ) return DeferredList( [ maybeDeferred(protos[0].transport.loseConnection), maybeDeferred(protos[1].transport.loseConnection), maybeDeferred(port.stopListening), ] ) cc = TCP4ClientEndpoint(reactor, "127.0.0.1", port) cf = ClientFactory() cf.protocol = Protocol d = DeferredList([cc.connect(cf), sf.ready]).addCallback(proceed, p) d.addErrback(log.err) self.runReactor(reactor) @oneTransportTest def test_resumeProducing(self, reactor, server): """ When a L{Server} is connected, its C{resumeProducing} method adds it as a reader to the reactor. """ server.pauseProducing() assertNotReading(self, reactor, server) server.resumeProducing() assertReading(self, reactor, server) @oneTransportTest def test_resumeProducingWhileDisconnecting(self, reactor, server): """ When a L{Server} has already started disconnecting via C{loseConnection}, its C{resumeProducing} method does not add it as a reader to its reactor. """ server.loseConnection() server.resumeProducing() assertNotReading(self, reactor, server) @oneTransportTest def test_resumeProducingWhileDisconnected(self, reactor, server): """ When a L{Server} has already lost its connection, its C{resumeProducing} method does not add it as a reader to its reactor. """ server.connectionLost(Failure(Exception("dummy"))) assertNotReading(self, reactor, server) server.resumeProducing() assertNotReading(self, reactor, server) def test_connectionLostAfterPausedTransport(self): """ Alice connects to Bob. Alice writes some bytes and then shuts down the connection. Bob receives the bytes from the connection and then pauses the transport object. Shortly afterwards Bob resumes the transport object. At that point, Bob is notified that the connection has been closed. This is no problem for most reactors. The underlying event notification API will probably just remind them that the connection has been closed. It is a little tricky for win32eventreactor (MsgWaitForMultipleObjects). MsgWaitForMultipleObjects will only deliver the close notification once. The reactor needs to remember that notification until Bob resumes the transport. """ class Pauser(ConnectableProtocol): def __init__(self): self.events = [] def dataReceived(self, bytes): self.events.append("paused") self.transport.pauseProducing() self.reactor.callLater(0, self.resume) def resume(self): self.events.append("resumed") self.transport.resumeProducing() def connectionLost(self, reason): # This is the event you have been waiting for. self.events.append("lost") ConnectableProtocol.connectionLost(self, reason) class Client(ConnectableProtocol): def connectionMade(self): self.transport.write(b"some bytes for you") self.transport.loseConnection() pauser = Pauser() runProtocolsWithReactor(self, pauser, Client(), TCPCreator()) self.assertEqual(pauser.events, ["paused", "resumed", "lost"]) def test_doubleHalfClose(self): """ If one side half-closes its connection, and then the other side of the connection calls C{loseWriteConnection}, and then C{loseConnection} in {writeConnectionLost}, the connection is closed correctly. This rather obscure case used to fail (see ticket #3037). """ @implementer(IHalfCloseableProtocol) class ListenerProtocol(ConnectableProtocol): def readConnectionLost(self): self.transport.loseWriteConnection() def writeConnectionLost(self): self.transport.loseConnection() class Client(ConnectableProtocol): def connectionMade(self): self.transport.loseConnection() # If test fails, reactor won't stop and we'll hit timeout: runProtocolsWithReactor(self, ListenerProtocol(), Client(), TCPCreator()) class WriteSequenceTestsMixin: """ Test for L{twisted.internet.abstract.FileDescriptor.writeSequence}. """ requiredInterfaces: Optional[Sequence[Type[Interface]]] = (IReactorTCP,) def setWriteBufferSize(self, transport, value): """ Set the write buffer size for the given transport, mananing possible differences (ie, IOCP). Bug #4322 should remove the need of that hack. """ if getattr(transport, "writeBufferSize", None) is not None: transport.writeBufferSize = value else: transport.bufferSize = value def test_writeSequeceWithoutWrite(self): """ C{writeSequence} sends the data even if C{write} hasn't been called. """ def connected(protocols): client, server, port = protocols def dataReceived(data): log.msg("data received: %r" % data) self.assertEqual(data, b"Some sequence splitted") client.transport.loseConnection() server.dataReceived = dataReceived client.transport.writeSequence([b"Some ", b"sequence ", b"splitted"]) reactor = self.buildReactor() d = self.getConnectedClientAndServer(reactor, "127.0.0.1", socket.AF_INET) d.addCallback(connected) d.addErrback(log.err) self.runReactor(reactor) def test_writeSequenceWithUnicodeRaisesException(self): """ C{writeSequence} with an element in the sequence of type unicode raises C{TypeError}. """ def connected(protocols): client, server, port = protocols exc = self.assertRaises( TypeError, server.transport.writeSequence, ["Unicode is not kosher"] ) self.assertEqual(str(exc), "Data must be bytes") server.transport.loseConnection() reactor = self.buildReactor() d = self.getConnectedClientAndServer(reactor, "127.0.0.1", socket.AF_INET) d.addCallback(connected) d.addErrback(log.err) self.runReactor(reactor) def test_streamingProducer(self): """ C{writeSequence} pauses its streaming producer if too much data is buffered, and then resumes it. """ @implementer(IPushProducer) class SaveActionProducer: client = None server = None def __init__(self): self.actions = [] def pauseProducing(self): self.actions.append("pause") def resumeProducing(self): self.actions.append("resume") # Unregister the producer so the connection can close self.client.transport.unregisterProducer() # This is why the code below waits for the server connection # first - so we have it to close here. We close the server # side because win32evenreactor cannot reliably observe us # closing the client side (#5285). self.server.transport.loseConnection() def stopProducing(self): self.actions.append("stop") producer = SaveActionProducer() def connected(protocols): client, server = protocols[:2] producer.client = client producer.server = server # Register a streaming producer and verify that it gets paused # after it writes more than the local send buffer can hold. client.transport.registerProducer(producer, True) self.assertEqual(producer.actions, []) self.setWriteBufferSize(client.transport, 500) client.transport.writeSequence([b"x" * 50] * 20) self.assertEqual(producer.actions, ["pause"]) reactor = self.buildReactor() d = self.getConnectedClientAndServer(reactor, "127.0.0.1", socket.AF_INET) d.addCallback(connected) d.addErrback(log.err) self.runReactor(reactor) # After the send buffer gets a chance to empty out a bit, the producer # should be resumed. self.assertEqual(producer.actions, ["pause", "resume"]) def test_nonStreamingProducer(self): """ C{writeSequence} pauses its producer if too much data is buffered only if this is a streaming producer. """ test = self @implementer(IPullProducer) class SaveActionProducer: client = None def __init__(self): self.actions = [] def resumeProducing(self): self.actions.append("resume") if self.actions.count("resume") == 2: self.client.transport.stopConsuming() else: test.setWriteBufferSize(self.client.transport, 500) self.client.transport.writeSequence([b"x" * 50] * 20) def stopProducing(self): self.actions.append("stop") producer = SaveActionProducer() def connected(protocols): client = protocols[0] producer.client = client # Register a non-streaming producer and verify that it is resumed # immediately. client.transport.registerProducer(producer, False) self.assertEqual(producer.actions, ["resume"]) reactor = self.buildReactor() d = self.getConnectedClientAndServer(reactor, "127.0.0.1", socket.AF_INET) d.addCallback(connected) d.addErrback(log.err) self.runReactor(reactor) # After the local send buffer empties out, the producer should be # resumed again. self.assertEqual(producer.actions, ["resume", "resume"]) class TCPTransportServerAddressTestMixin: """ Test mixing for TCP server address building and log prefix. """ def getConnectedClientAndServer(self, reactor, interface, addressFamily): """ Helper method returnine a L{Deferred} firing with a tuple of a client protocol, a server protocol, and a running TCP port. """ raise NotImplementedError() def _testServerAddress(self, interface, addressFamily, adressClass): """ Helper method to test TCP server addresses on either IPv4 or IPv6. """ def connected(protocols): client, server, port = protocols try: self.assertEqual( "" % (server.transport.sessionno, port.getHost().port), str(server.transport), ) self.assertEqual( "AccumulatingProtocol,%s,%s" % (server.transport.sessionno, interface), server.transport.logstr, ) [peerAddress] = server.factory.peerAddresses self.assertIsInstance(peerAddress, adressClass) self.assertEqual("TCP", peerAddress.type) self.assertEqual(interface, peerAddress.host) finally: # Be certain to drop the connection so the test completes. server.transport.loseConnection() reactor = self.buildReactor() d = self.getConnectedClientAndServer(reactor, interface, addressFamily) d.addCallback(connected) d.addErrback(log.err) self.runReactor(reactor) def test_serverAddressTCP4(self): """ L{Server} instances have a string representation indicating on which port they're running, and the connected address is stored on the C{peerAddresses} attribute of the factory. """ return self._testServerAddress("127.0.0.1", socket.AF_INET, IPv4Address) @skipIf(ipv6Skip, ipv6SkipReason) def test_serverAddressTCP6(self): """ IPv6 L{Server} instances have a string representation indicating on which port they're running, and the connected address is stored on the C{peerAddresses} attribute of the factory. """ return self._testServerAddress( getLinkLocalIPv6Address(), socket.AF_INET6, IPv6Address ) class TCPTransportTestsBuilder( TCPTransportServerAddressTestMixin, WriteSequenceTestsMixin, ReactorBuilder ): """ Test standard L{ITCPTransport}s built with C{listenTCP} and C{connectTCP}. """ def getConnectedClientAndServer(self, reactor, interface, addressFamily): """ Return a L{Deferred} firing with a L{MyClientFactory} and L{MyServerFactory} connected pair, and the listening C{Port}. """ server = MyServerFactory() server.protocolConnectionMade = Deferred() server.protocolConnectionLost = Deferred() client = MyClientFactory() client.protocolConnectionMade = Deferred() client.protocolConnectionLost = Deferred() port = reactor.listenTCP(0, server, interface=interface) lostDeferred = gatherResults( [client.protocolConnectionLost, server.protocolConnectionLost] ) def stop(result): reactor.stop() return result lostDeferred.addBoth(stop) startDeferred = gatherResults( [client.protocolConnectionMade, server.protocolConnectionMade] ) deferred = Deferred() def start(protocols): client, server = protocols log.msg("client connected %s" % client) log.msg("server connected %s" % server) deferred.callback((client, server, port)) startDeferred.addCallback(start) reactor.connectTCP(interface, port.getHost().port, client) return deferred class AdoptStreamConnectionTestsBuilder( TCPTransportServerAddressTestMixin, WriteSequenceTestsMixin, ReactorBuilder ): """ Test server transports built using C{adoptStreamConnection}. """ requiredInterfaces = (IReactorFDSet, IReactorSocket) def getConnectedClientAndServer(self, reactor, interface, addressFamily): """ Return a L{Deferred} firing with a L{MyClientFactory} and L{MyServerFactory} connected pair, and the listening C{Port}. The particularity is that the server protocol has been obtained after doing a C{adoptStreamConnection} against the original server connection. """ firstServer = MyServerFactory() firstServer.protocolConnectionMade = Deferred() server = MyServerFactory() server.protocolConnectionMade = Deferred() server.protocolConnectionLost = Deferred() client = MyClientFactory() client.protocolConnectionMade = Deferred() client.protocolConnectionLost = Deferred() port = reactor.listenTCP(0, firstServer, interface=interface) def firtServerConnected(proto): reactor.removeReader(proto.transport) reactor.removeWriter(proto.transport) reactor.adoptStreamConnection( proto.transport.fileno(), addressFamily, server ) firstServer.protocolConnectionMade.addCallback(firtServerConnected) lostDeferred = gatherResults( [client.protocolConnectionLost, server.protocolConnectionLost] ) def stop(result): if reactor.running: reactor.stop() return result lostDeferred.addBoth(stop) deferred = Deferred() deferred.addErrback(stop) startDeferred = gatherResults( [client.protocolConnectionMade, server.protocolConnectionMade] ) def start(protocols): client, server = protocols log.msg("client connected %s" % client) log.msg("server connected %s" % server) deferred.callback((client, server, port)) startDeferred.addCallback(start) reactor.connectTCP(interface, port.getHost().port, client) return deferred globals().update(TCP4ClientTestsBuilder.makeTestCaseClasses()) globals().update(TCP6ClientTestsBuilder.makeTestCaseClasses()) globals().update(TCPPortTestsBuilder.makeTestCaseClasses()) globals().update(TCPFDPortTestsBuilder.makeTestCaseClasses()) globals().update(TCPConnectionTestsBuilder.makeTestCaseClasses()) globals().update(TCP4ConnectorTestsBuilder.makeTestCaseClasses()) globals().update(TCP6ConnectorTestsBuilder.makeTestCaseClasses()) globals().update(TCPTransportTestsBuilder.makeTestCaseClasses()) globals().update(AdoptStreamConnectionTestsBuilder.makeTestCaseClasses()) class ServerAbortsTwice(ConnectableProtocol): """ Call abortConnection() twice. """ def dataReceived(self, data): self.transport.abortConnection() self.transport.abortConnection() class ServerAbortsThenLoses(ConnectableProtocol): """ Call abortConnection() followed by loseConnection(). """ def dataReceived(self, data): self.transport.abortConnection() self.transport.loseConnection() class AbortServerWritingProtocol(ConnectableProtocol): """ Protocol that writes data upon connection. """ def connectionMade(self): """ Tell the client that the connection is set up and it's time to abort. """ self.transport.write(b"ready") class ReadAbortServerProtocol(AbortServerWritingProtocol): """ Server that should never receive any data, except 'X's which are written by the other side of the connection before abortConnection, and so might possibly arrive. """ def dataReceived(self, data): if data.replace(b"X", b""): raise Exception("Unexpectedly received data.") class NoReadServer(ConnectableProtocol): """ Stop reading immediately on connection. This simulates a lost connection that will cause the other side to time out, and therefore call abortConnection(). """ def connectionMade(self): self.transport.stopReading() class EventualNoReadServer(ConnectableProtocol): """ Like NoReadServer, except we Wait until some bytes have been delivered before stopping reading. This means TLS handshake has finished, where applicable. """ gotData = False stoppedReading = False def dataReceived(self, data): if not self.gotData: self.gotData = True self.transport.registerProducer(self, False) self.transport.write(b"hello") def resumeProducing(self): if self.stoppedReading: return self.stoppedReading = True # We've written out the data: self.transport.stopReading() def pauseProducing(self): pass def stopProducing(self): pass class BaseAbortingClient(ConnectableProtocol): """ Base class for abort-testing clients. """ inReactorMethod = False def connectionLost(self, reason): if self.inReactorMethod: raise RuntimeError("BUG: connectionLost was called re-entrantly!") ConnectableProtocol.connectionLost(self, reason) class WritingButNotAbortingClient(BaseAbortingClient): """ Write data, but don't abort. """ def connectionMade(self): self.transport.write(b"hello") class AbortingClient(BaseAbortingClient): """ Call abortConnection() after writing some data. """ def dataReceived(self, data): """ Some data was received, so the connection is set up. """ self.inReactorMethod = True self.writeAndAbort() self.inReactorMethod = False def writeAndAbort(self): # X is written before abortConnection, and so there is a chance it # might arrive. Y is written after, and so no Ys should ever be # delivered: self.transport.write(b"X" * 10000) self.transport.abortConnection() self.transport.write(b"Y" * 10000) class AbortingTwiceClient(AbortingClient): """ Call abortConnection() twice, after writing some data. """ def writeAndAbort(self): AbortingClient.writeAndAbort(self) self.transport.abortConnection() class AbortingThenLosingClient(AbortingClient): """ Call abortConnection() and then loseConnection(). """ def writeAndAbort(self): AbortingClient.writeAndAbort(self) self.transport.loseConnection() class ProducerAbortingClient(ConnectableProtocol): """ Call abortConnection from doWrite, via resumeProducing. """ inReactorMethod = True producerStopped = False def write(self): self.transport.write(b"lalala" * 127000) self.inRegisterProducer = True self.transport.registerProducer(self, False) self.inRegisterProducer = False def connectionMade(self): self.write() def resumeProducing(self): self.inReactorMethod = True if not self.inRegisterProducer: self.transport.abortConnection() self.inReactorMethod = False def stopProducing(self): self.producerStopped = True def connectionLost(self, reason): if not self.producerStopped: raise RuntimeError("BUG: stopProducing() was never called.") if self.inReactorMethod: raise RuntimeError("BUG: connectionLost called re-entrantly!") ConnectableProtocol.connectionLost(self, reason) class StreamingProducerClient(ConnectableProtocol): """ Call abortConnection() when the other side has stopped reading. In particular, we want to call abortConnection() only once our local socket hits a state where it is no longer writeable. This helps emulate the most common use case for abortConnection(), closing a connection after a timeout, with write buffers being full. Since it's very difficult to know when this actually happens, we just write a lot of data, and assume at that point no more writes will happen. """ paused = False extraWrites = 0 inReactorMethod = False def connectionMade(self): self.write() def write(self): """ Write large amount to transport, then wait for a while for buffers to fill up. """ self.transport.registerProducer(self, True) for i in range(100): self.transport.write(b"1234567890" * 32000) def resumeProducing(self): self.paused = False def stopProducing(self): pass def pauseProducing(self): """ Called when local buffer fills up. The goal is to hit the point where the local file descriptor is not writeable (or the moral equivalent). The fact that pauseProducing has been called is not sufficient, since that can happen when Twisted's buffers fill up but OS hasn't gotten any writes yet. We want to be as close as possible to every buffer (including OS buffers) being full. So, we wait a bit more after this for Twisted to write out a few chunks, then abortConnection. """ if self.paused: return self.paused = True # The amount we wait is arbitrary, we just want to make sure some # writes have happened and outgoing OS buffers filled up -- see # http://twistedmatrix.com/trac/ticket/5303 for details: self.reactor.callLater(0.01, self.doAbort) def doAbort(self): if not self.paused: log.err(RuntimeError("BUG: We should be paused a this point.")) self.inReactorMethod = True self.transport.abortConnection() self.inReactorMethod = False def connectionLost(self, reason): # Tell server to start reading again so it knows to go away: self.otherProtocol.transport.startReading() ConnectableProtocol.connectionLost(self, reason) class StreamingProducerClientLater(StreamingProducerClient): """ Call abortConnection() from dataReceived, after bytes have been exchanged. """ def connectionMade(self): self.transport.write(b"hello") self.gotData = False def dataReceived(self, data): if not self.gotData: self.gotData = True self.write() class ProducerAbortingClientLater(ProducerAbortingClient): """ Call abortConnection from doWrite, via resumeProducing. Try to do so after some bytes have already been exchanged, so we don't interrupt SSL handshake. """ def connectionMade(self): # Override base class connectionMade(). pass def dataReceived(self, data): self.write() class DataReceivedRaisingClient(AbortingClient): """ Call abortConnection(), and then throw exception, from dataReceived. """ def dataReceived(self, data): self.transport.abortConnection() raise ZeroDivisionError("ONO") class ResumeThrowsClient(ProducerAbortingClient): """ Call abortConnection() and throw exception from resumeProducing(). """ def resumeProducing(self): if not self.inRegisterProducer: self.transport.abortConnection() raise ZeroDivisionError("ono!") def connectionLost(self, reason): # Base class assertion about stopProducing being called isn't valid; # if the we blew up in resumeProducing, consumers are justified in # giving up on the producer and not calling stopProducing. ConnectableProtocol.connectionLost(self, reason) class AbortConnectionMixin: """ Unit tests for L{ITransport.abortConnection}. """ # Override in subclasses, should be an EndpointCreator instance: endpoints: Optional[EndpointCreator] = None def runAbortTest(self, clientClass, serverClass, clientConnectionLostReason=None): """ A test runner utility function, which hooks up a matched pair of client and server protocols. We then run the reactor until both sides have disconnected, and then verify that the right exception resulted. """ clientExpectedExceptions = (ConnectionAborted, ConnectionLost) serverExpectedExceptions = (ConnectionLost, ConnectionDone) # In TLS tests we may get SSL.Error instead of ConnectionLost, # since we're trashing the TLS protocol layer. if useSSL: clientExpectedExceptions = clientExpectedExceptions + (SSL.Error,) serverExpectedExceptions = serverExpectedExceptions + (SSL.Error,) client = clientClass() server = serverClass() client.otherProtocol = server server.otherProtocol = client reactor = runProtocolsWithReactor(self, server, client, self.endpoints) # Make sure everything was shutdown correctly: self.assertEqual(reactor.removeAll(), []) self.assertEqual(reactor.getDelayedCalls(), []) if clientConnectionLostReason is not None: self.assertIsInstance( client.disconnectReason.value, (clientConnectionLostReason,) + clientExpectedExceptions, ) else: self.assertIsInstance( client.disconnectReason.value, clientExpectedExceptions ) self.assertIsInstance(server.disconnectReason.value, serverExpectedExceptions) def test_dataReceivedAbort(self): """ abortConnection() is called in dataReceived. The protocol should be disconnected, but connectionLost should not be called re-entrantly. """ return self.runAbortTest(AbortingClient, ReadAbortServerProtocol) def test_clientAbortsConnectionTwice(self): """ abortConnection() is called twice by client. No exception should be thrown, and the connection will be closed. """ return self.runAbortTest(AbortingTwiceClient, ReadAbortServerProtocol) def test_clientAbortsConnectionThenLosesConnection(self): """ Client calls abortConnection(), followed by loseConnection(). No exception should be thrown, and the connection will be closed. """ return self.runAbortTest(AbortingThenLosingClient, ReadAbortServerProtocol) def test_serverAbortsConnectionTwice(self): """ abortConnection() is called twice by server. No exception should be thrown, and the connection will be closed. """ return self.runAbortTest( WritingButNotAbortingClient, ServerAbortsTwice, clientConnectionLostReason=ConnectionLost, ) def test_serverAbortsConnectionThenLosesConnection(self): """ Server calls abortConnection(), followed by loseConnection(). No exception should be thrown, and the connection will be closed. """ return self.runAbortTest( WritingButNotAbortingClient, ServerAbortsThenLoses, clientConnectionLostReason=ConnectionLost, ) def test_resumeProducingAbort(self): """ abortConnection() is called in resumeProducing, before any bytes have been exchanged. The protocol should be disconnected, but connectionLost should not be called re-entrantly. """ self.runAbortTest(ProducerAbortingClient, ConnectableProtocol) def test_resumeProducingAbortLater(self): """ abortConnection() is called in resumeProducing, after some bytes have been exchanged. The protocol should be disconnected. """ return self.runAbortTest( ProducerAbortingClientLater, AbortServerWritingProtocol ) def test_fullWriteBuffer(self): """ abortConnection() triggered by the write buffer being full. In particular, the server side stops reading. This is supposed to simulate a realistic timeout scenario where the client notices the server is no longer accepting data. The protocol should be disconnected, but connectionLost should not be called re-entrantly. """ self.runAbortTest(StreamingProducerClient, NoReadServer) def test_fullWriteBufferAfterByteExchange(self): """ abortConnection() is triggered by a write buffer being full. However, this buffer is filled after some bytes have been exchanged, allowing a TLS handshake if we're testing TLS. The connection will then be lost. """ return self.runAbortTest(StreamingProducerClientLater, EventualNoReadServer) def test_dataReceivedThrows(self): """ dataReceived calls abortConnection(), and then raises an exception. The connection will be lost, with the thrown exception (C{ZeroDivisionError}) as the reason on the client. The idea here is that bugs should not be masked by abortConnection, in particular unexpected exceptions. """ self.runAbortTest( DataReceivedRaisingClient, AbortServerWritingProtocol, clientConnectionLostReason=ZeroDivisionError, ) errors = self.flushLoggedErrors(ZeroDivisionError) self.assertEqual(len(errors), 1) def test_resumeProducingThrows(self): """ resumeProducing calls abortConnection(), and then raises an exception. The connection will be lost, with the thrown exception (C{ZeroDivisionError}) as the reason on the client. The idea here is that bugs should not be masked by abortConnection, in particular unexpected exceptions. """ self.runAbortTest( ResumeThrowsClient, ConnectableProtocol, clientConnectionLostReason=ZeroDivisionError, ) errors = self.flushLoggedErrors(ZeroDivisionError) self.assertEqual(len(errors), 1) class AbortConnectionTests(ReactorBuilder, AbortConnectionMixin): """ TCP-specific L{AbortConnectionMixin} tests. """ requiredInterfaces = (IReactorTCP,) endpoints = TCPCreator() globals().update(AbortConnectionTests.makeTestCaseClasses()) @skipIf(ipv6Skip, ipv6SkipReason) class SimpleUtilityTests(TestCase): """ Simple, direct tests for helpers within L{twisted.internet.tcp}. """ def test_resolveNumericHost(self): """ L{_resolveIPv6} raises a L{socket.gaierror} (L{socket.EAI_NONAME}) when invoked with a non-numeric host. (In other words, it is passing L{socket.AI_NUMERICHOST} to L{socket.getaddrinfo} and will not accidentally block if it receives bad input.) """ err = self.assertRaises(socket.gaierror, _resolveIPv6, "localhost", 1) self.assertEqual(err.args[0], socket.EAI_NONAME) @skipIf( platform.isWindows(), "The AI_NUMERICSERV flag is not supported by Microsoft providers.", ) # http://msdn.microsoft.com/en-us/library/windows/desktop/ms738520.aspx def test_resolveNumericService(self): """ L{_resolveIPv6} raises a L{socket.gaierror} (L{socket.EAI_NONAME}) when invoked with a non-numeric port. (In other words, it is passing L{socket.AI_NUMERICSERV} to L{socket.getaddrinfo} and will not accidentally block if it receives bad input.) """ err = self.assertRaises(socket.gaierror, _resolveIPv6, "::1", "http") self.assertEqual(err.args[0], socket.EAI_NONAME) def test_resolveIPv6(self): """ L{_resolveIPv6} discovers the flow info and scope ID of an IPv6 address. """ result = _resolveIPv6("::1", 2) self.assertEqual(len(result), 4) # We can't say anything more useful about these than that they're # integers, because the whole point of getaddrinfo is that you can # never know a-priori know _anything_ about the network interfaces # of the computer that you're on and you have to ask it. self.assertIsInstance(result[2], int) # flow info self.assertIsInstance(result[3], int) # scope id # but, luckily, IP presentation format and what it means to be a port # number are a little better specified. self.assertEqual(result[:2], ("::1", 2)) class BuffersLogsTests(SynchronousTestCase): """ Tests for L{_BuffersLogs}. """ def setUp(self): self.namespace = "name.space" self.events = [] self.logBuffer = _BuffersLogs(self.namespace, self.events.append) def test_buffersInBlock(self): """ The context manager's logger does not log to provided observer inside the block. """ with self.logBuffer as logger: logger.info("An event") self.assertFalse(self.events) def test_flushesOnExit(self): """ The context manager flushes its buffered logs when the block terminates without an exception. """ with self.logBuffer as logger: logger.info("An event") self.assertFalse(self.events) self.assertEqual(1, len(self.events)) [event] = self.events self.assertEqual(event["log_format"], "An event") self.assertEqual(event["log_namespace"], self.namespace) def test_flushesOnExitWithException(self): """ The context manager flushes its buffered logs when the block terminates because of an exception. """ class TestException(Exception): """ An exception only raised by this test. """ with self.assertRaises(TestException): with self.logBuffer as logger: logger.info("An event") self.assertFalse(self.events) raise TestException() self.assertEqual(1, len(self.events)) [event] = self.events self.assertEqual(event["log_format"], "An event") self.assertEqual(event["log_namespace"], self.namespace) @skipIf(SKIP_EMFILE, "Reserved EMFILE file descriptor not supported on Windows.") class FileDescriptorReservationTests(SynchronousTestCase): """ Tests for L{_FileDescriptorReservation}. """ def setUp(self): self.reservedFileObjects = [] self.tempfile = self.mktemp() def fakeFileFactory(): self.reservedFileObjects.append(open(self.tempfile, "w")) return self.reservedFileObjects[-1] self.reservedFD = _FileDescriptorReservation(fakeFileFactory) def test_providesInterface(self): """ L{_FileDescriptorReservation} instances provide L{_IFileDescriptorReservation}. """ verifyObject(_IFileDescriptorReservation, self.reservedFD) def test_reserveOpensFileOnce(self): """ Multiple acquisitions without releases open the reservation file exactly once. """ self.assertEqual(len(self.reservedFileObjects), 0) for _ in range(10): self.reservedFD.reserve() self.assertEqual(len(self.reservedFileObjects), 1) self.assertFalse(self.reservedFileObjects[0].closed) def test_reserveEMFILELogged(self): """ If reserving the file descriptor fails because of C{EMFILE}, the exception is suppressed but logged and the reservation remains unavailable. """ exhauster = _ExhaustsFileDescriptors() self.addCleanup(exhauster.release) exhauster.exhaust() self.assertFalse(self.reservedFD.available()) self.reservedFD.reserve() self.assertFalse(self.reservedFD.available()) errors = self.flushLoggedErrors(OSError, IOError) self.assertEqual(len(errors), 1) self.assertEqual(errors[0].value.errno, errno.EMFILE) def test_reserveRaisesNonEMFILEExceptions(self): """ Any exception raised while opening the reserve file that is not an L{OSError} or L{IOError} whose errno is C{EMFILE} is allowed through to the caller. """ for errorClass in (OSError, IOError, ValueError): # Note that the ValueError will present the errno as its # value. def failsWith(errorClass=errorClass): raise errorClass(errno.EMFILE + 1, "message") reserveFD = _FileDescriptorReservation(failsWith) self.assertRaises(errorClass, reserveFD.reserve) def test_available(self): """ The reservation is available after the file descriptor is reserved. """ self.assertFalse(self.reservedFD.available()) self.reservedFD.reserve() self.assertTrue(self.reservedFD.available()) def test_enterFailsWithoutFile(self): """ A reservation without an open file used as a context manager raises a L{RuntimeError}. """ with self.assertRaises(RuntimeError): with self.reservedFD: """This string cannot raise an exception.""" def test_enterClosesFileExitOpensFile(self): """ Entering a reservation closes its file for the duration of the context manager's block. """ self.reservedFD.reserve() self.assertTrue(self.reservedFD.available()) with self.reservedFD: self.assertFalse(self.reservedFD.available()) self.assertTrue(self.reservedFD.available()) def test_exitOpensFileOnException(self): """ An exception raised within a reservation context manager's block does not prevent the file from being reopened. """ class TestException(Exception): """ An exception only used by this test. """ self.reservedFD.reserve() with self.assertRaises(TestException): with self.reservedFD: raise TestException() def test_exitSuppressesReservationException(self): """ An exception raised while re-opening the reserve file exiting a reservation's context manager block is suppressed but logged, allowing an exception raised within the block through. """ class AllowedException(Exception): """ The exception allowed out of the block. """ class SuppressedException(Exception): """ An exception raised by the file descriptor factory. """ called = [False] def failsWithSuppressedExceptionAfterSecondOpen(): if called[0]: raise SuppressedException() else: called[0] = True return io.BytesIO() reservedFD = _FileDescriptorReservation( failsWithSuppressedExceptionAfterSecondOpen ) reservedFD.reserve() self.assertTrue(reservedFD.available()) with self.assertRaises(AllowedException): with reservedFD: raise AllowedException() errors = self.flushLoggedErrors(SuppressedException) self.assertEqual(len(errors), 1) class NullFileDescriptorReservationTests(SynchronousTestCase): """ Tests for L{_NullFileDescriptorReservation}. """ def setUp(self): self.nullReservedFD = _NullFileDescriptorReservation() def test_providesInterface(self): """ L{_NullFileDescriptorReservation} provides L{_IFileDescriptorReservation}. """ verifyObject(_IFileDescriptorReservation, self.nullReservedFD) def test_available(self): """ The null reserved file descriptor is never available. """ self.assertFalse(self.nullReservedFD.available()) def test_contextManager(self): """ The null reserved file descriptor is a null context manager. """ self.assertFalse(self.nullReservedFD.available()) with self.nullReservedFD: self.assertFalse(self.nullReservedFD.available()) self.assertFalse(self.nullReservedFD.available())