# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Test cases for L{twisted.names.server}.
"""


from zope.interface.verify import verifyClass

from twisted.internet import defer
from twisted.internet.interfaces import IProtocolFactory
from twisted.names import dns, error, resolve, server
from twisted.python import failure, log
from twisted.trial import unittest


class RaisedArguments(Exception):
    """
    An exception containing the arguments raised by L{raiser}.
    """

    def __init__(self, args, kwargs):
        # Store the call arguments verbatim so tests can compare them.
        self.args = args
        self.kwargs = kwargs


def raiser(*args, **kwargs):
    """
    Raise a L{RaisedArguments} exception containing the supplied arguments.

    Used as a fake when testing the call signatures of methods and functions.
    """
    raise RaisedArguments(args, kwargs)


class NoResponseDNSServerFactory(server.DNSServerFactory):
    """
    A L{server.DNSServerFactory} subclass which does not attempt to reply to
    any received messages.

    Used for testing logged messages in C{messageReceived} without having to
    fake or patch the preceding code which attempts to deliver a response
    message.
    """

    def allowQuery(self, message, protocol, address):
        """
        Deny all queries.

        @param message: See L{server.DNSServerFactory.allowQuery}
        @param protocol: See L{server.DNSServerFactory.allowQuery}
        @param address: See L{server.DNSServerFactory.allowQuery}

        @return: L{False}
        @rtype: L{bool}
        """
        return False

    def sendReply(self, protocol, message, address):
        """
        A noop send reply.

        @param protocol: See L{server.DNSServerFactory.sendReply}
        @param message: See L{server.DNSServerFactory.sendReply}
        @param address: See L{server.DNSServerFactory.sendReply}
        """


class RaisingDNSServerFactory(server.DNSServerFactory):
    """
    A L{server.DNSServerFactory} subclass whose methods raise an exception
    containing the supplied arguments.

    Used for stopping L{messageReceived} and testing the arguments supplied to
    L{allowQuery}.
    """

    class AllowQueryArguments(Exception):
        """
        Contains positional and keyword arguments in C{args}.
        """

    def allowQuery(self, *args, **kwargs):
        """
        Raise the arguments supplied to L{allowQuery}.

        @param args: Positional arguments which will be recorded in the raised
            exception.
        @type args: L{tuple}

        @param kwargs: Keyword args which will be recorded in the raised
            exception.
        @type kwargs: L{dict}
        """
        raise self.AllowQueryArguments(args, kwargs)


class RaisingProtocol:
    """
    A partial fake L{IProtocol} whose methods raise an exception containing the
    supplied arguments.
    """

    class WriteMessageArguments(Exception):
        """
        Contains positional and keyword arguments in C{args}.
        """

    def writeMessage(self, *args, **kwargs):
        """
        Raises the supplied arguments.

        @param args: Positional arguments
        @type args: L{tuple}

        @param kwargs: Keyword args
        @type kwargs: L{dict}
        """
        raise self.WriteMessageArguments(args, kwargs)


class NoopProtocol:
    """
    A partial fake L{dns.DNSProtocolMixin} with a noop L{writeMessage} method.
    """

    def writeMessage(self, *args, **kwargs):
        """
        A noop version of L{dns.DNSProtocolMixin.writeMessage}.

        @param args: Positional arguments
        @type args: L{tuple}

        @param kwargs: Keyword args
        @type kwargs: L{dict}
        """


class RaisingResolver:
    """
    A partial fake L{IResolver} whose methods raise an exception containing the
    supplied arguments.
    """

    class QueryArguments(Exception):
        """
        Contains positional and keyword arguments in C{args}.
        """

    def query(self, *args, **kwargs):
        """
        Raises the supplied arguments.

        @param args: Positional arguments
        @type args: L{tuple}

        @param kwargs: Keyword args
        @type kwargs: L{dict}
        """
        raise self.QueryArguments(args, kwargs)


class RaisingCache:
    """
    A partial fake L{twisted.names.cache.Cache} whose methods raise an
    exception containing the supplied arguments.
    """

    class CacheResultArguments(Exception):
        """
        Contains positional and keyword arguments in C{args}.
        """

    def cacheResult(self, *args, **kwargs):
        """
        Raises the supplied arguments.

        @param args: Positional arguments
        @type args: L{tuple}

        @param kwargs: Keyword args
        @type kwargs: L{dict}
        """
        raise self.CacheResultArguments(args, kwargs)


def assertLogMessage(testCase, expectedMessages, callable, *args, **kwargs):
    """
    Assert that the callable logs the expected messages when called.

    XXX: Put this somewhere where it can be re-used elsewhere. See #6677.

    @param testCase: The test case controlling the test which triggers the
        logged messages and on which assertions will be called.
    @type testCase: L{unittest.SynchronousTestCase}

    @param expectedMessages: A L{list} of the expected log messages
    @type expectedMessages: L{list}

    @param callable: The function which is expected to produce the
        C{expectedMessages} when called.
    @type callable: L{callable}

    @param args: Positional arguments to be passed to C{callable}.
    @type args: L{tuple}

    @param kwargs: Keyword arguments to be passed to C{callable}.
    @type kwargs: L{dict}
    """
    loggedMessages = []
    log.addObserver(loggedMessages.append)
    # Remove the observer even if the callable or the assertion fails.
    testCase.addCleanup(log.removeObserver, loggedMessages.append)

    callable(*args, **kwargs)

    testCase.assertEqual([m["message"][0] for m in loggedMessages], expectedMessages)


class DNSServerFactoryTests(unittest.TestCase):
    """
    Tests for L{server.DNSServerFactory}.
    """

    def test_resolverType(self):
        """
        L{server.DNSServerFactory.resolver} is a L{resolve.ResolverChain}
        instance
        """
        self.assertIsInstance(server.DNSServerFactory().resolver, resolve.ResolverChain)

    def test_resolverDefaultEmpty(self):
        """
        L{server.DNSServerFactory.resolver} is an empty L{resolve.ResolverChain}
        by default.
        """
        self.assertEqual(server.DNSServerFactory().resolver.resolvers, [])

    def test_authorities(self):
        """
        L{server.DNSServerFactory.__init__} accepts an C{authorities} argument.
        The value of this argument is a list and is used to extend the
        C{resolver} L{resolve.ResolverChain}.
        """
        dummyResolver = object()
        self.assertEqual(
            server.DNSServerFactory(authorities=[dummyResolver]).resolver.resolvers,
            [dummyResolver],
        )

    def test_caches(self):
        """
        L{server.DNSServerFactory.__init__} accepts a C{caches} argument. The
        value of this argument is a list and is used to extend the C{resolver}
        L{resolve.ResolverChain}.
        """
        dummyResolver = object()
        self.assertEqual(
            server.DNSServerFactory(caches=[dummyResolver]).resolver.resolvers,
            [dummyResolver],
        )

    def test_clients(self):
        """
        L{server.DNSServerFactory.__init__} accepts a C{clients} argument. The
        value of this argument is a list and is used to extend the C{resolver}
        L{resolve.ResolverChain}.
        """
        dummyResolver = object()
        self.assertEqual(
            server.DNSServerFactory(clients=[dummyResolver]).resolver.resolvers,
            [dummyResolver],
        )

    def test_resolverOrder(self):
        """
        L{server.DNSServerFactory.resolver} contains an ordered list of
        authorities, caches and clients.
        """

        # Use classes here so that we can see meaningful names in test results
        class DummyAuthority:
            pass

        class DummyCache:
            pass

        class DummyClient:
            pass

        self.assertEqual(
            server.DNSServerFactory(
                authorities=[DummyAuthority], caches=[DummyCache], clients=[DummyClient]
            ).resolver.resolvers,
            [DummyAuthority, DummyCache, DummyClient],
        )

    def test_cacheDefault(self):
        """
        L{server.DNSServerFactory.cache} is L{None} by default.
        """
        self.assertIsNone(server.DNSServerFactory().cache)

    def test_cacheOverride(self):
        """
        L{server.DNSServerFactory.__init__} assigns the last object in the
        C{caches} list to L{server.DNSServerFactory.cache}.
        """
        dummyResolver = object()
        self.assertEqual(
            server.DNSServerFactory(caches=[object(), dummyResolver]).cache,
            dummyResolver,
        )

    def test_canRecurseDefault(self):
        """
        L{server.DNSServerFactory.canRecurse} is a flag indicating that this
        server is capable of performing recursive DNS lookups. It defaults to
        L{False}.
        """
        self.assertFalse(server.DNSServerFactory().canRecurse)

    def test_canRecurseOverride(self):
        """
        L{server.DNSServerFactory.__init__} sets C{canRecurse} to L{True} if it
        is supplied with C{clients}.
        """
        self.assertEqual(server.DNSServerFactory(clients=[None]).canRecurse, True)

    def test_verboseDefault(self):
        """
        L{server.DNSServerFactory.verbose} defaults to L{False}.
        """
        self.assertFalse(server.DNSServerFactory().verbose)

    def test_verboseOverride(self):
        """
        L{server.DNSServerFactory.__init__} accepts a C{verbose} argument which
        overrides L{server.DNSServerFactory.verbose}.
        """
        self.assertTrue(server.DNSServerFactory(verbose=True).verbose)

    def test_interface(self):
        """
        L{server.DNSServerFactory} implements L{IProtocolFactory}.
        """
        self.assertTrue(verifyClass(IProtocolFactory, server.DNSServerFactory))

    def test_defaultProtocol(self):
        """
        L{server.DNSServerFactory.protocol} defaults to L{dns.DNSProtocol}.
        """
        self.assertIs(server.DNSServerFactory.protocol, dns.DNSProtocol)

    def test_buildProtocolProtocolOverride(self):
        """
        L{server.DNSServerFactory.buildProtocol} builds a protocol by calling
        L{server.DNSServerFactory.protocol} with itself as a positional
        argument.
        """

        class FakeProtocol:
            factory = None
            args = None
            kwargs = None

        stubProtocol = FakeProtocol()

        def fakeProtocolFactory(*args, **kwargs):
            stubProtocol.args = args
            stubProtocol.kwargs = kwargs
            return stubProtocol

        f = server.DNSServerFactory()
        f.protocol = fakeProtocolFactory
        p = f.buildProtocol(addr=None)

        self.assertEqual((stubProtocol, (f,), {}), (p, p.args, p.kwargs))

    def test_verboseLogQuiet(self):
        """
        L{server.DNSServerFactory._verboseLog} does not log messages unless
        C{verbose > 0}.
        """
        f = server.DNSServerFactory()
        assertLogMessage(self, [], f._verboseLog, "Foo Bar")

    def test_verboseLogVerbose(self):
        """
        L{server.DNSServerFactory._verboseLog} logs a message if C{verbose > 0}.
        """
        f = server.DNSServerFactory(verbose=1)
        assertLogMessage(self, ["Foo Bar"], f._verboseLog, "Foo Bar")

    def test_messageReceivedLoggingNoQuery(self):
        """
        L{server.DNSServerFactory.messageReceived} logs about an empty query if
        the message had no queries and C{verbose} is C{>0}.
        """
        m = dns.Message()
        f = NoResponseDNSServerFactory(verbose=1)

        assertLogMessage(
            self,
            ["Empty query from ('192.0.2.100', 53)"],
            f.messageReceived,
            message=m,
            proto=None,
            address=("192.0.2.100", 53),
        )

    def test_messageReceivedLogging1(self):
        """
        L{server.DNSServerFactory.messageReceived} logs the query types of all
        queries in the message if C{verbose} is set to C{1}.
        """
        m = dns.Message()
        m.addQuery(name="example.com", type=dns.MX)
        m.addQuery(name="example.com", type=dns.AAAA)
        f = NoResponseDNSServerFactory(verbose=1)

        assertLogMessage(
            self,
            ["MX AAAA query from ('192.0.2.100', 53)"],
            f.messageReceived,
            message=m,
            proto=None,
            address=("192.0.2.100", 53),
        )

    def test_messageReceivedLogging2(self):
        """
        L{server.DNSServerFactory.messageReceived} logs the repr of all queries
        in the message if C{verbose} is set to C{2}.
        """
        m = dns.Message()
        m.addQuery(name="example.com", type=dns.MX)
        m.addQuery(name="example.com", type=dns.AAAA)
        f = NoResponseDNSServerFactory(verbose=2)

        assertLogMessage(
            self,
            [
                "<Query example.com MX IN> "
                "<Query example.com AAAA IN> query from ('192.0.2.100', 53)"
            ],
            f.messageReceived,
            message=m,
            proto=None,
            address=("192.0.2.100", 53),
        )

    def test_messageReceivedTimestamp(self):
        """
        L{server.DNSServerFactory.messageReceived} assigns a unix timestamp to
        the received message.
        """
        m = dns.Message()
        f = NoResponseDNSServerFactory()
        t = object()
        self.patch(server.time, "time", lambda: t)
        f.messageReceived(message=m, proto=None, address=None)

        self.assertEqual(m.timeReceived, t)

    def test_messageReceivedAllowQuery(self):
        """
        L{server.DNSServerFactory.messageReceived} passes all messages to
        L{server.DNSServerFactory.allowQuery} along with the receiving protocol
        and origin address.
        """
        message = dns.Message()
        dummyProtocol = object()
        dummyAddress = object()

        f = RaisingDNSServerFactory()
        e = self.assertRaises(
            RaisingDNSServerFactory.AllowQueryArguments,
            f.messageReceived,
            message=message,
            proto=dummyProtocol,
            address=dummyAddress,
        )
        args, kwargs = e.args
        self.assertEqual(args, (message, dummyProtocol, dummyAddress))
        self.assertEqual(kwargs, {})

    def test_allowQueryFalse(self):
        """
        If C{allowQuery} returns C{False},
        L{server.DNSServerFactory.messageReceived} calls L{server.sendReply}
        with a message whose C{rCode} is L{dns.EREFUSED}.
        """

        class SendReplyException(Exception):
            pass

        # Named distinctly so it does not shadow the module-level
        # RaisingDNSServerFactory fake.
        class RefusingDNSServerFactory(server.DNSServerFactory):
            def allowQuery(self, *args, **kwargs):
                return False

            def sendReply(self, *args, **kwargs):
                raise SendReplyException(args, kwargs)

        f = RefusingDNSServerFactory()
        e = self.assertRaises(
            SendReplyException,
            f.messageReceived,
            message=dns.Message(),
            proto=None,
            address=None,
        )
        (proto, message, address), kwargs = e.args
        self.assertEqual(message.rCode, dns.EREFUSED)

    def _messageReceivedTest(self, methodName, message):
        """
        Assert that the named method is called with the given message when it is
        passed to L{DNSServerFactory.messageReceived}.

        @param methodName: The name of the method which is expected to be
            called.
        @type methodName: L{str}

        @param message: The message which is expected to be passed to the
            C{methodName} method.
        @type message: L{dns.Message}
        """
        # Make it appear to have some queries so that
        # DNSServerFactory.allowQuery allows it.
        message.queries = [None]

        receivedMessages = []

        def fakeHandler(message, protocol, address):
            receivedMessages.append((message, protocol, address))

        protocol = NoopProtocol()
        factory = server.DNSServerFactory(None)
        setattr(factory, methodName, fakeHandler)
        factory.messageReceived(message, protocol)
        self.assertEqual(receivedMessages, [(message, protocol, None)])

    def test_queryMessageReceived(self):
        """
        L{DNSServerFactory.messageReceived} passes messages with an opcode of
        C{OP_QUERY} on to L{DNSServerFactory.handleQuery}.
        """
        self._messageReceivedTest("handleQuery", dns.Message(opCode=dns.OP_QUERY))

    def test_inverseQueryMessageReceived(self):
        """
        L{DNSServerFactory.messageReceived} passes messages with an opcode of
        C{OP_INVERSE} on to L{DNSServerFactory.handleInverseQuery}.
        """
        self._messageReceivedTest(
            "handleInverseQuery", dns.Message(opCode=dns.OP_INVERSE)
        )

    def test_statusMessageReceived(self):
        """
        L{DNSServerFactory.messageReceived} passes messages with an opcode of
        C{OP_STATUS} on to L{DNSServerFactory.handleStatus}.
        """
        self._messageReceivedTest("handleStatus", dns.Message(opCode=dns.OP_STATUS))

    def test_notifyMessageReceived(self):
        """
        L{DNSServerFactory.messageReceived} passes messages with an opcode of
        C{OP_NOTIFY} on to L{DNSServerFactory.handleNotify}.
        """
        self._messageReceivedTest("handleNotify", dns.Message(opCode=dns.OP_NOTIFY))

    def test_updateMessageReceived(self):
        """
        L{DNSServerFactory.messageReceived} passes messages with an opcode of
        C{OP_UPDATE} on to L{DNSServerFactory.handleOther}.

        This may change if the implementation ever covers update messages.
        """
        self._messageReceivedTest("handleOther", dns.Message(opCode=dns.OP_UPDATE))

    def test_connectionTracking(self):
        """
        The C{connectionMade} and C{connectionLost} methods of
        L{DNSServerFactory} cooperate to keep track of all L{DNSProtocol}
        objects created by a factory which are connected.
        """
        protoA, protoB = object(), object()
        factory = server.DNSServerFactory()
        factory.connectionMade(protoA)
        self.assertEqual(factory.connections, [protoA])
        factory.connectionMade(protoB)
        self.assertEqual(factory.connections, [protoA, protoB])
        factory.connectionLost(protoA)
        self.assertEqual(factory.connections, [protoB])
        factory.connectionLost(protoB)
        self.assertEqual(factory.connections, [])

    def test_handleQuery(self):
        """
        L{server.DNSServerFactory.handleQuery} takes the first query from the
        supplied message and dispatches it to
        L{server.DNSServerFactory.resolver.query}.
        """
        m = dns.Message()
        m.addQuery(b"one.example.com")
        m.addQuery(b"two.example.com")
        f = server.DNSServerFactory()
        f.resolver = RaisingResolver()

        e = self.assertRaises(
            RaisingResolver.QueryArguments,
            f.handleQuery,
            message=m,
            protocol=NoopProtocol(),
            address=None,
        )
        (query,), kwargs = e.args
        self.assertEqual(query, m.queries[0])

    def test_handleQueryCallback(self):
        """
        L{server.DNSServerFactory.handleQuery} adds
        L{server.DNSServerFactory.gotResolverResponse} as a callback to the
        deferred returned by L{server.DNSServerFactory.resolver.query}. It is
        called with the query response, the original protocol, message and
        origin address.
        """
        f = server.DNSServerFactory()

        d = defer.Deferred()

        class FakeResolver:
            def query(self, *args, **kwargs):
                return d

        f.resolver = FakeResolver()

        gotResolverResponseArgs = []

        def fakeGotResolverResponse(*args, **kwargs):
            gotResolverResponseArgs.append((args, kwargs))

        f.gotResolverResponse = fakeGotResolverResponse

        m = dns.Message()
        m.addQuery(b"one.example.com")
        stubProtocol = NoopProtocol()
        dummyAddress = object()

        f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)

        dummyResponse = object()
        d.callback(dummyResponse)

        self.assertEqual(
            gotResolverResponseArgs,
            [((dummyResponse, stubProtocol, m, dummyAddress), {})],
        )

    def test_handleQueryErrback(self):
        """
        L{server.DNSServerFactory.handleQuery} adds
        L{server.DNSServerFactory.gotResolverError} as an errback to the
        deferred returned by L{server.DNSServerFactory.resolver.query}. It is
        called with the query failure, the original protocol, message and
        origin address.
        """
        f = server.DNSServerFactory()

        d = defer.Deferred()

        class FakeResolver:
            def query(self, *args, **kwargs):
                return d

        f.resolver = FakeResolver()

        gotResolverErrorArgs = []

        def fakeGotResolverError(*args, **kwargs):
            gotResolverErrorArgs.append((args, kwargs))

        f.gotResolverError = fakeGotResolverError

        m = dns.Message()
        m.addQuery(b"one.example.com")
        stubProtocol = NoopProtocol()
        dummyAddress = object()

        f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)

        stubFailure = failure.Failure(Exception())
        d.errback(stubFailure)

        self.assertEqual(
            gotResolverErrorArgs, [((stubFailure, stubProtocol, m, dummyAddress), {})]
        )

    def test_gotResolverResponse(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} accepts a tuple of
        resource record lists and triggers a response message containing those
        resource record lists.
        """
        f = server.DNSServerFactory()
        answers = []
        authority = []
        additional = []
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.gotResolverResponse,
            (answers, authority, additional),
            protocol=RaisingProtocol(),
            message=dns.Message(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertIs(message.answers, answers)
        self.assertIs(message.authority, authority)
        self.assertIs(message.additional, additional)

    def test_gotResolverResponseCallsResponseFromMessage(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} calls
        L{server.DNSServerFactory._responseFromMessage} to generate a response.
        """
        factory = NoResponseDNSServerFactory()
        factory._responseFromMessage = raiser

        request = dns.Message()
        request.timeReceived = 1

        e = self.assertRaises(
            RaisedArguments,
            factory.gotResolverResponse,
            ([], [], []),
            protocol=None,
            message=request,
            address=None,
        )
        self.assertEqual(
            (
                (),
                dict(
                    message=request,
                    rCode=dns.OK,
                    answers=[],
                    authority=[],
                    additional=[],
                ),
            ),
            (e.args, e.kwargs),
        )

    def test_responseFromMessageNewMessage(self):
        """
        L{server.DNSServerFactory._responseFromMessage} generates a new response
        message object rather than returning the request message itself.
        """
        factory = server.DNSServerFactory()
        request = dns.Message(answer=False, recAv=False)
        # Compare the response object itself; wrapping it in a tuple would make
        # the identity assertion below trivially true.
        response = factory._responseFromMessage(message=request)

        self.assertIsNot(request, response)

    def test_responseFromMessageRecursionAvailable(self):
        """
        L{server.DNSServerFactory._responseFromMessage} generates a response
        message whose C{recAv} attribute is L{True} if
        L{server.DNSServerFactory.canRecurse} is L{True}.
        """
        factory = server.DNSServerFactory()
        factory.canRecurse = True
        response1 = factory._responseFromMessage(message=dns.Message(recAv=False))
        factory.canRecurse = False
        response2 = factory._responseFromMessage(message=dns.Message(recAv=True))
        self.assertEqual((True, False), (response1.recAv, response2.recAv))

    def test_responseFromMessageTimeReceived(self):
        """
        L{server.DNSServerFactory._responseFromMessage} generates a response
        message whose C{timeReceived} attribute has the same value as that found
        on the request.
        """
        factory = server.DNSServerFactory()
        request = dns.Message()
        request.timeReceived = 1234
        response = factory._responseFromMessage(message=request)

        self.assertEqual(request.timeReceived, response.timeReceived)

    def test_responseFromMessageMaxSize(self):
        """
        L{server.DNSServerFactory._responseFromMessage} generates a response
        message whose C{maxSize} attribute has the same value as that found
        on the request.
        """
        factory = server.DNSServerFactory()
        request = dns.Message()
        request.maxSize = 0
        response = factory._responseFromMessage(message=request)

        self.assertEqual(request.maxSize, response.maxSize)

    def test_messageFactory(self):
        """
        L{server.DNSServerFactory} has a C{_messageFactory} attribute which is
        L{dns.Message} by default.
        """
        self.assertIs(dns.Message, server.DNSServerFactory._messageFactory)

    def test_responseFromMessageCallsMessageFactory(self):
        """
        L{server.DNSServerFactory._responseFromMessage} calls
        C{dns._responseFromMessage} to generate a response message from the
        request message. It supplies the request message and other keyword
        arguments which should be passed to the response message initialiser.
        """
        factory = server.DNSServerFactory()
        self.patch(dns, "_responseFromMessage", raiser)

        request = dns.Message()
        e = self.assertRaises(
            RaisedArguments, factory._responseFromMessage, message=request, rCode=dns.OK
        )
        self.assertEqual(
            (
                (),
                dict(
                    responseConstructor=factory._messageFactory,
                    message=request,
                    rCode=dns.OK,
                    recAv=factory.canRecurse,
                    auth=False,
                ),
            ),
            (e.args, e.kwargs),
        )

    def test_responseFromMessageAuthoritativeMessage(self):
        """
        L{server.DNSServerFactory._responseFromMessage} marks the response
        message as authoritative if any of the answer records are authoritative.
        """
        factory = server.DNSServerFactory()
        response1 = factory._responseFromMessage(
            message=dns.Message(), answers=[dns.RRHeader(auth=True)]
        )
        response2 = factory._responseFromMessage(
            message=dns.Message(), answers=[dns.RRHeader(auth=False)]
        )
        self.assertEqual(
            (True, False),
            (response1.auth, response2.auth),
        )

    def test_gotResolverResponseLogging(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} logs the total number of
        records in the response if C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        answers = [dns.RRHeader()]
        authority = [dns.RRHeader()]
        additional = [dns.RRHeader()]

        assertLogMessage(
            self,
            ["Lookup found 3 records"],
            f.gotResolverResponse,
            (answers, authority, additional),
            protocol=NoopProtocol(),
            message=dns.Message(),
            address=None,
        )

    def test_gotResolverResponseCaching(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} caches the response if at
        least one cache was provided in the constructor.
        """
        f = NoResponseDNSServerFactory(caches=[RaisingCache()])

        m = dns.Message()
        m.addQuery(b"example.com")
        expectedAnswers = [dns.RRHeader()]
        expectedAuthority = []
        expectedAdditional = []

        e = self.assertRaises(
            RaisingCache.CacheResultArguments,
            f.gotResolverResponse,
            (expectedAnswers, expectedAuthority, expectedAdditional),
            protocol=NoopProtocol(),
            message=m,
            address=None,
        )
        (query, (answers, authority, additional)), kwargs = e.args

        self.assertEqual(query.name.name, b"example.com")
        self.assertIs(answers, expectedAnswers)
        self.assertIs(authority, expectedAuthority)
        self.assertIs(additional, expectedAdditional)

    def test_gotResolverErrorCallsResponseFromMessage(self):
        """
        L{server.DNSServerFactory.gotResolverError} calls
        L{server.DNSServerFactory._responseFromMessage} to generate a response.
        """
        factory = NoResponseDNSServerFactory()
        factory._responseFromMessage = raiser

        request = dns.Message()
        request.timeReceived = 1

        e = self.assertRaises(
            RaisedArguments,
            factory.gotResolverError,
            failure.Failure(error.DomainError()),
            protocol=None,
            message=request,
            address=None,
        )
        self.assertEqual(
            ((), dict(message=request, rCode=dns.ENAME)), (e.args, e.kwargs)
        )

    def _assertMessageRcodeForError(self, responseError, expectedMessageCode):
        """
        L{server.DNSServerFactory.gotResolverError} accepts a
        L{failure.Failure} and triggers a response message whose rCode
        corresponds to the DNS error contained in the C{Failure}.

        @param responseError: The L{Exception} instance which is expected to
            trigger C{expectedMessageCode} when it is supplied to
            C{gotResolverError}
        @type responseError: L{Exception}

        @param expectedMessageCode: The C{rCode} which is expected in the
            message returned by C{gotResolverError} in response to
            C{responseError}.
        @type expectedMessageCode: L{int}
        """
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.gotResolverError,
            failure.Failure(responseError),
            protocol=RaisingProtocol(),
            message=dns.Message(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, expectedMessageCode)

    def test_gotResolverErrorDomainError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ENAME} if supplied with a
        L{error.DomainError}.
        """
        self._assertMessageRcodeForError(error.DomainError(), dns.ENAME)

    def test_gotResolverErrorAuthoritativeDomainError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ENAME} if supplied with a
        L{error.AuthoritativeDomainError}.
        """
        self._assertMessageRcodeForError(error.AuthoritativeDomainError(), dns.ENAME)

    def test_gotResolverErrorOtherError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ESERVER} if supplied with another type of
        error and logs the error.
        """
        self._assertMessageRcodeForError(KeyError(), dns.ESERVER)
        e = self.flushLoggedErrors(KeyError)
        self.assertEqual(len(e), 1)

    def test_gotResolverErrorLogging(self):
        """
        L{server.DNSServerFactory.gotResolverError} logs a message if
        C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Lookup failed"],
            f.gotResolverError,
            failure.Failure(error.DomainError()),
            protocol=NoopProtocol(),
            message=dns.Message(),
            address=None,
        )

    def test_gotResolverErrorResetsResponseAttributes(self):
        """
        L{server.DNSServerFactory.gotResolverError} does not allow request
        attributes to leak into the response, i.e. it sends a response with AD
        and CD set to 0 and empty response record sections.
        """
        factory = server.DNSServerFactory()
        responses = []
        factory.sendReply = lambda protocol, response, address: responses.append(
            response
        )
        request = dns.Message(authenticData=True, checkingDisabled=True)
        request.answers = [object(), object()]
        request.authority = [object(), object()]
        request.additional = [object(), object()]
        factory.gotResolverError(
            failure.Failure(error.DomainError()),
            protocol=None,
            message=request,
            address=None,
        )

        self.assertEqual([dns.Message(rCode=3, answer=True)], responses)

    def test_gotResolverResponseResetsResponseAttributes(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} does not allow request
        attributes to leak into the response, i.e. it sends a response with AD
        and CD set to 0 and none of the records in the request answer sections
        are copied to the response.
        """
        factory = server.DNSServerFactory()
        responses = []
        factory.sendReply = lambda protocol, response, address: responses.append(
            response
        )
        request = dns.Message(authenticData=True, checkingDisabled=True)
        request.answers = [object(), object()]
        request.authority = [object(), object()]
        request.additional = [object(), object()]

        factory.gotResolverResponse(
            ([], [], []), protocol=None, message=request, address=None
        )

        self.assertEqual([dns.Message(rCode=0, answer=True)], responses)

    def test_sendReplyWithAddress(self):
        """
        If L{server.DNSServerFactory.sendReply} is supplied with a protocol
        *and* an address tuple it will supply that address to
        C{protocol.writeMessage}.
        """
        m = dns.Message()
        dummyAddress = object()
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.sendReply,
            protocol=RaisingProtocol(),
            message=m,
            address=dummyAddress,
        )
        args, kwargs = e.args
        self.assertEqual(args, (m, dummyAddress))
        self.assertEqual(kwargs, {})

    def test_sendReplyWithoutAddress(self):
        """
        If L{server.DNSServerFactory.sendReply} is supplied with a protocol but
        no address tuple it will supply only a message to
        C{protocol.writeMessage}.
        """
        m = dns.Message()
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.sendReply,
            protocol=RaisingProtocol(),
            message=m,
            address=None,
        )
        args, kwargs = e.args
        self.assertEqual(args, (m,))
        self.assertEqual(kwargs, {})

    def test_sendReplyLoggingNoAnswers(self):
        """
        L{server.DNSServerFactory.sendReply} logs a "no answers" message if the
        supplied message has no answers.
        """
        self.patch(server.time, "time", lambda: 86402)
        m = dns.Message()
        m.timeReceived = 86401
        f = server.DNSServerFactory(verbose=2)
        assertLogMessage(
            self,
            ["Replying with no answers", "Processed query in 1.000 seconds"],
            f.sendReply,
            protocol=NoopProtocol(),
            message=m,
            address=None,
        )

    def test_sendReplyLoggingWithAnswers(self):
        """
        L{server.DNSServerFactory.sendReply} logs a message for the answer,
        authority and additional sections if the supplied message has records
        in any of those sections and C{verbose} is C{2}.
        """
        self.patch(server.time, "time", lambda: 86402)
        m = dns.Message()
        m.answers.append(dns.RRHeader(payload=dns.Record_A("127.0.0.1")))
        m.authority.append(dns.RRHeader(payload=dns.Record_A("127.0.0.1")))
        m.additional.append(dns.RRHeader(payload=dns.Record_A("127.0.0.1")))
        m.timeReceived = 86401
        f = server.DNSServerFactory(verbose=2)
        assertLogMessage(
            self,
            [
                "Answers are <A address=127.0.0.1 ttl=None>",
                "Authority is <A address=127.0.0.1 ttl=None>",
                "Additional is <A address=127.0.0.1 ttl=None>",
                "Processed query in 1.000 seconds",
            ],
            f.sendReply,
            protocol=NoopProtocol(),
            message=m,
            address=None,
        )

    def test_handleInverseQuery(self):
        """
        L{server.DNSServerFactory.handleInverseQuery} triggers the sending of a
        response message with C{rCode} set to L{dns.ENOTIMP}.
        """
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.handleInverseQuery,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, dns.ENOTIMP)

    def test_handleInverseQueryLogging(self):
        """
        L{server.DNSServerFactory.handleInverseQuery} logs the message origin
        address if C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Inverse query from ('::1', 53)"],
            f.handleInverseQuery,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=("::1", 53),
        )

    def test_handleStatus(self):
        """
        L{server.DNSServerFactory.handleStatus} triggers the sending of a
        response message with C{rCode} set to L{dns.ENOTIMP}.
        """
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.handleStatus,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, dns.ENOTIMP)

    def test_handleStatusLogging(self):
        """
        L{server.DNSServerFactory.handleStatus} logs the message origin address
        if C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Status request from ('::1', 53)"],
            f.handleStatus,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=("::1", 53),
        )

    def test_handleNotify(self):
        """
        L{server.DNSServerFactory.handleNotify} triggers the sending of a
        response message with C{rCode} set to L{dns.ENOTIMP}.
        """
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.handleNotify,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, dns.ENOTIMP)

    def test_handleNotifyLogging(self):
        """
        L{server.DNSServerFactory.handleNotify} logs the message origin address
        if C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Notify message from ('::1', 53)"],
            f.handleNotify,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=("::1", 53),
        )

    def test_handleOther(self):
        """
        L{server.DNSServerFactory.handleOther} triggers the sending of a
        response message with C{rCode} set to L{dns.ENOTIMP}.
        """
        f = server.DNSServerFactory()
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.handleOther,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None,
        )
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, dns.ENOTIMP)

    def test_handleOtherLogging(self):
        """
        L{server.DNSServerFactory.handleOther} logs the message origin address
        if C{verbose > 0}.
        """
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Unknown op code (0) from ('::1', 53)"],
            f.handleOther,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=("::1", 53),
        )