# test-case-name: twisted.names.test.test_dns # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. """ Tests for twisted.names.dns. """ import struct from io import BytesIO from typing import cast from zope.interface.verify import verifyClass from twisted.internet import address, task from twisted.internet.error import CannotListenError, ConnectionDone from twisted.names import dns from twisted.python.failure import Failure from twisted.python.util import FancyEqMixin, FancyStrMixin from twisted.test import proto_helpers from twisted.test.testutils import ComparisonTestsMixin from twisted.trial import unittest RECORD_TYPES = [ dns.Record_NS, dns.Record_MD, dns.Record_MF, dns.Record_CNAME, dns.Record_MB, dns.Record_MG, dns.Record_MR, dns.Record_PTR, dns.Record_DNAME, dns.Record_A, dns.Record_SOA, dns.Record_NULL, dns.Record_WKS, dns.Record_SRV, dns.Record_AFSDB, dns.Record_RP, dns.Record_HINFO, dns.Record_MINFO, dns.Record_MX, dns.Record_TXT, dns.Record_AAAA, dns.Record_A6, dns.Record_NAPTR, dns.Record_SSHFP, dns.Record_TSIG, dns.UnknownRecord, ] class DomainStringTests(unittest.SynchronousTestCase): def test_bytes(self): """ L{dns.domainString} returns L{bytes} unchanged. """ self.assertEqual( b"twistedmatrix.com", dns.domainString(b"twistedmatrix.com"), ) def test_native(self): """ L{dns.domainString} converts a native string to L{bytes} if necessary. """ self.assertEqual(b"example.com", dns.domainString("example.com")) def test_text(self): """ L{dns.domainString} always converts a unicode string to L{bytes}. """ self.assertEqual(b"foo.example", dns.domainString("foo.example")) def test_idna(self): """ L{dns.domainString} encodes Unicode using IDNA. """ self.assertEqual(b"xn--fwg.test", dns.domainString("\u203D.test")) def test_nonsense(self): """ L{dns.domainString} encodes Unicode using IDNA. """ self.assertRaises(TypeError, dns.domainString, 9000) self.assertRaises(TypeError, dns.domainString, dns.Name("bar.example")) class Ord2ByteTests(unittest.TestCase): """ Tests for L{dns._ord2bytes}. """ def test_ord2byte(self): """ L{dns._ord2byte} accepts an integer and returns a byte string of length one with an ordinal value equal to the given integer. """ self.assertEqual(b"\x10", dns._ord2bytes(0x10)) class Str2TimeTests(unittest.TestCase): """ Tests for L{dns.str2name}. """ def test_nonString(self): """ When passed a non-string object, L{dns.str2name} returns it unmodified. """ time = object() self.assertIs(time, dns.str2time(time)) def test_seconds(self): """ Passed a string giving a number of seconds, L{dns.str2time} returns the number of seconds represented. For example, C{"10S"} represents C{10} seconds. """ self.assertEqual(10, dns.str2time("10S")) def test_minutes(self): """ Like C{test_seconds}, but for the C{"M"} suffix which multiplies the time value by C{60} (the number of seconds in a minute!). """ self.assertEqual(2 * 60, dns.str2time("2M")) def test_hours(self): """ Like C{test_seconds}, but for the C{"H"} suffix which multiplies the time value by C{3600}, the number of seconds in an hour. """ self.assertEqual(3 * 3600, dns.str2time("3H")) def test_days(self): """ Like L{test_seconds}, but for the C{"D"} suffix which multiplies the time value by C{86400}, the number of seconds in a day. """ self.assertEqual(4 * 86400, dns.str2time("4D")) def test_weeks(self): """ Like L{test_seconds}, but for the C{"W"} suffix which multiplies the time value by C{604800}, the number of seconds in a week. """ self.assertEqual(5 * 604800, dns.str2time("5W")) def test_years(self): """ Like L{test_seconds}, but for the C{"Y"} suffix which multiplies the time value by C{31536000}, the number of seconds in a year. """ self.assertEqual(6 * 31536000, dns.str2time("6Y")) def test_invalidPrefix(self): """ If a non-integer prefix is given, L{dns.str2time} raises L{ValueError}. """ self.assertRaises(ValueError, dns.str2time, "fooS") class NameTests(unittest.TestCase): """ Tests for L{Name}, the representation of a single domain name with support for encoding into and decoding from DNS message format. """ def test_nonStringName(self): """ When constructed with a name which is neither C{bytes} nor C{str}, L{Name} raises L{TypeError}. """ self.assertRaises(TypeError, dns.Name, 123) self.assertRaises(TypeError, dns.Name, object()) self.assertRaises(TypeError, dns.Name, []) def test_unicodeName(self): """ L{dns.Name} automatically encodes unicode domain name using C{idna} encoding. """ name = dns.Name("\u00e9chec.example.org") self.assertIsInstance(name.name, bytes) self.assertEqual(b"xn--chec-9oa.example.org", name.name) def test_decode(self): """ L{Name.decode} populates the L{Name} instance with name information read from the file-like object passed to it. """ n = dns.Name() n.decode(BytesIO(b"\x07example\x03com\x00")) self.assertEqual(n.name, b"example.com") def test_encode(self): """ L{Name.encode} encodes its name information and writes it to the file-like object passed to it. """ name = dns.Name(b"foo.example.com") stream = BytesIO() name.encode(stream) self.assertEqual(stream.getvalue(), b"\x03foo\x07example\x03com\x00") def test_encodeWithCompression(self): """ If a compression dictionary is passed to it, L{Name.encode} uses offset information from it to encode its name with references to existing labels in the stream instead of including another copy of them in the output. It also updates the compression dictionary with the location of the name it writes to the stream. """ name = dns.Name(b"foo.example.com") compression = {b"example.com": 0x17} # Some bytes already encoded into the stream for this message previous = b"some prefix to change .tell()" stream = BytesIO() stream.write(previous) # The position at which the encoded form of this new name will appear in # the stream. expected = len(previous) + dns.Message.headerSize name.encode(stream, compression) self.assertEqual(b"\x03foo\xc0\x17", stream.getvalue()[len(previous) :]) self.assertEqual( {b"example.com": 0x17, b"foo.example.com": expected}, compression ) def test_unknown(self): """ A resource record of unknown type and class is parsed into an L{UnknownRecord} instance with its data preserved, and an L{UnknownRecord} instance is serialized to a string equal to the one it was parsed from. """ wire = ( b"\x01\x00" # Message ID b"\x00" # answer bit, opCode nibble, auth bit, trunc bit, recursive # bit b"\x00" # recursion bit, empty bit, authenticData bit, # checkingDisabled bit, response code nibble b"\x00\x01" # number of queries b"\x00\x01" # number of answers b"\x00\x00" # number of authorities b"\x00\x01" # number of additionals # query b"\x03foo\x03bar\x00" # foo.bar b"\xde\xad" # type=0xdead b"\xbe\xef" # cls=0xbeef # 1st answer b"\xc0\x0c" # foo.bar - compressed b"\xde\xad" # type=0xdead b"\xbe\xef" # cls=0xbeef b"\x00\x00\x01\x01" # ttl=257 b"\x00\x08somedata" # some payload data # 1st additional b"\x03baz\x03ban\x00" # baz.ban b"\x00\x01" # type=A b"\x00\x01" # cls=IN b"\x00\x00\x01\x01" # ttl=257 b"\x00\x04" # len=4 b"\x01\x02\x03\x04" # 1.2.3.4 ) msg = dns.Message() msg.fromStr(wire) self.assertEqual( msg.queries, [ dns.Query(b"foo.bar", type=0xDEAD, cls=0xBEEF), ], ) self.assertEqual( msg.answers, [ dns.RRHeader( b"foo.bar", type=0xDEAD, cls=0xBEEF, ttl=257, payload=dns.UnknownRecord(b"somedata", ttl=257), ), ], ) self.assertEqual( msg.additional, [ dns.RRHeader( b"baz.ban", type=dns.A, cls=dns.IN, ttl=257, payload=dns.Record_A("1.2.3.4", ttl=257), ), ], ) enc = msg.toStr() self.assertEqual(enc, wire) def test_decodeWithCompression(self): """ If the leading byte of an encoded label (in bytes read from a stream passed to L{Name.decode}) has its two high bits set, the next byte is treated as a pointer to another label in the stream and that label is included in the name being decoded. """ # Slightly modified version of the example from RFC 1035, section 4.1.4. stream = BytesIO( b"x" * 20 + b"\x01f\x03isi\x04arpa\x00" b"\x03foo\xc0\x14" b"\x03bar\xc0\x20" ) stream.seek(20) name = dns.Name() name.decode(stream) # Verify we found the first name in the stream and that the stream # position is left at the first byte after the decoded name. self.assertEqual(b"f.isi.arpa", name.name) self.assertEqual(32, stream.tell()) # Get the second name from the stream and make the same assertions. name.decode(stream) self.assertEqual(name.name, b"foo.f.isi.arpa") self.assertEqual(38, stream.tell()) # Get the third and final name name.decode(stream) self.assertEqual(name.name, b"bar.foo.f.isi.arpa") self.assertEqual(44, stream.tell()) def test_rejectCompressionLoop(self): """ L{Name.decode} raises L{ValueError} if the stream passed to it includes a compression pointer which forms a loop, causing the name to be undecodable. """ name = dns.Name() stream = BytesIO(b"\xc0\x00") self.assertRaises(ValueError, name.decode, stream) def test_equality(self): """ L{Name} instances are equal as long as they have the same value for L{Name.name}, regardless of the case. """ name1 = dns.Name(b"foo.bar") name2 = dns.Name(b"foo.bar") self.assertEqual(name1, name2) name3 = dns.Name(b"fOO.bar") self.assertEqual(name1, name3) def test_inequality(self): """ L{Name} instances are not equal as long as they have different L{Name.name} attributes. """ name1 = dns.Name(b"foo.bar") name2 = dns.Name(b"bar.foo") self.assertNotEqual(name1, name2) class RoundtripDNSTests(unittest.TestCase): """ Encoding and then decoding various objects. """ names = [b"example.org", b"go-away.fish.tv", b"23strikesback.net"] def test_name(self): for n in self.names: # encode the name f = BytesIO() dns.Name(n).encode(f) # decode the name f.seek(0, 0) result = dns.Name() result.decode(f) self.assertEqual(result.name, n) def test_query(self): """ L{dns.Query.encode} returns a byte string representing the fields of the query which can be decoded into a new L{dns.Query} instance using L{dns.Query.decode}. """ for n in self.names: for dnstype in range(1, 17): for dnscls in range(1, 5): # encode the query f = BytesIO() dns.Query(n, dnstype, dnscls).encode(f) # decode the result f.seek(0, 0) result = dns.Query() result.decode(f) self.assertEqual(result.name.name, n) self.assertEqual(result.type, dnstype) self.assertEqual(result.cls, dnscls) def test_resourceRecordHeader(self): """ L{dns.RRHeader.encode} encodes the record header's information and writes it to the file-like object passed to it and L{dns.RRHeader.decode} reads from a file-like object to re-construct a L{dns.RRHeader} instance. """ # encode the RR f = BytesIO() dns.RRHeader(b"test.org", 3, 4, 17).encode(f) # decode the result f.seek(0, 0) result = dns.RRHeader() result.decode(f) self.assertEqual(result.name, dns.Name(b"test.org")) self.assertEqual(result.type, 3) self.assertEqual(result.cls, 4) self.assertEqual(result.ttl, 17) def test_resourceRecordHeaderTypeMismatch(self): """ L{RRHeader()} raises L{ValueError} when the given type and the type of the payload don't match. """ with self.assertRaisesRegex(ValueError, r"Payload type \(AAAA\) .* type \(A\)"): dns.RRHeader(type=dns.A, payload=dns.Record_AAAA()) def test_resources(self): """ L{dns.SimpleRecord.encode} encodes the record's name information and writes it to the file-like object passed to it and L{dns.SimpleRecord.decode} reads from a file-like object to re-construct a L{dns.SimpleRecord} instance. """ names = ( b"this.are.test.name", b"will.compress.will.this.will.name.will.hopefully", b"test.CASE.preSErVatIOn.YeAH", b"a.s.h.o.r.t.c.a.s.e.t.o.t.e.s.t", b"singleton", ) for s in names: f = BytesIO() dns.SimpleRecord(s).encode(f) f.seek(0, 0) result = dns.SimpleRecord() result.decode(f) self.assertEqual(result.name, dns.Name(s)) def test_hashable(self): """ Instances of all record types are hashable. """ for k in RECORD_TYPES: k1, k2 = k(), k() hk1 = hash(k1) hk2 = hash(k2) self.assertEqual(hk1, hk2, f"{hk1} != {hk2} (for {k})") def test_Charstr(self): """ Test L{dns.Charstr} encode and decode. """ for n in self.names: # encode the name f = BytesIO() dns.Charstr(n).encode(f) # decode the name f.seek(0, 0) result = dns.Charstr() result.decode(f) self.assertEqual(result.string, n) def _recordRoundtripTest(self, record): """ Assert that encoding C{record} and then decoding the resulting bytes creates a record which compares equal to C{record}. @type record: L{dns.IEncodable} @param record: A record instance to encode """ stream = BytesIO() record.encode(stream) length = stream.tell() stream.seek(0, 0) replica = record.__class__() replica.decode(stream, length) self.assertEqual(record, replica) def assertEncodedFormat(self, expectedEncoding, record): """ Assert that encoding C{record} produces the expected bytes. @type record: L{dns.IEncodable} @param record: A record instance to encode @type expectedEncoding: L{bytes} @param expectedEncoding: The value which C{record.encode()} should produce. """ stream = BytesIO() record.encode(stream) self.assertEqual(stream.getvalue(), expectedEncoding) def test_SOA(self): """ The byte stream written by L{dns.Record_SOA.encode} can be used by L{dns.Record_SOA.decode} to reconstruct the state of the original L{dns.Record_SOA} instance. """ self._recordRoundtripTest( dns.Record_SOA( mname=b"foo", rname=b"bar", serial=12, refresh=34, retry=56, expire=78, minimum=90, ) ) def test_A(self): """ The byte stream written by L{dns.Record_A.encode} can be used by L{dns.Record_A.decode} to reconstruct the state of the original L{dns.Record_A} instance. """ self._recordRoundtripTest(dns.Record_A("1.2.3.4")) def test_NULL(self): """ The byte stream written by L{dns.Record_NULL.encode} can be used by L{dns.Record_NULL.decode} to reconstruct the state of the original L{dns.Record_NULL} instance. """ self._recordRoundtripTest(dns.Record_NULL(b"foo bar")) def test_WKS(self): """ The byte stream written by L{dns.Record_WKS.encode} can be used by L{dns.Record_WKS.decode} to reconstruct the state of the original L{dns.Record_WKS} instance. """ self._recordRoundtripTest(dns.Record_WKS("1.2.3.4", 3, b"xyz")) def test_AAAA(self): """ The byte stream written by L{dns.Record_AAAA.encode} can be used by L{dns.Record_AAAA.decode} to reconstruct the state of the original L{dns.Record_AAAA} instance. """ self._recordRoundtripTest(dns.Record_AAAA("::1")) def test_A6(self): """ The byte stream written by L{dns.Record_A6.encode} can be used by L{dns.Record_A6.decode} to reconstruct the state of the original L{dns.Record_A6} instance. """ self._recordRoundtripTest(dns.Record_A6(8, "::1:2", b"foo")) def test_SRV(self): """ The byte stream written by L{dns.Record_SRV.encode} can be used by L{dns.Record_SRV.decode} to reconstruct the state of the original L{dns.Record_SRV} instance. """ self._recordRoundtripTest( dns.Record_SRV(priority=1, weight=2, port=3, target=b"example.com") ) def test_SSHFP(self): """ The byte stream written by L{dns.Record_SSHFP.encode} can be used by L{dns.Record_SSHFP.decode} to reconstruct the state of the original L{dns.Record_SSHFP} instance. """ fp = ( b"\xda\x39\xa3\xee\x5e\x6b\x4b\x0d" + b"\x32\x55\xbf\xef\x95\x60\x18\x90\xaf\xd8\x07\x09" ) rr = dns.Record_SSHFP( algorithm=dns.Record_SSHFP.ALGORITHM_DSS, fingerprintType=dns.Record_SSHFP.FINGERPRINT_TYPE_SHA1, fingerprint=fp, ) self._recordRoundtripTest(rr) self.assertEncodedFormat(b"\x02\x01" + fp, rr) def test_NAPTR(self): """ Test L{dns.Record_NAPTR} encode and decode. """ naptrs = [ (100, 10, b"u", b"sip+E2U", b"!^.*$!sip:information@domain.tld!", b""), (100, 50, b"s", b"http+I2L+I2C+I2R", b"", b"_http._tcp.gatech.edu"), ] for (order, preference, flags, service, regexp, replacement) in naptrs: rin = dns.Record_NAPTR( order, preference, flags, service, regexp, replacement ) e = BytesIO() rin.encode(e) e.seek(0, 0) rout = dns.Record_NAPTR() rout.decode(e) self.assertEqual(rin.order, rout.order) self.assertEqual(rin.preference, rout.preference) self.assertEqual(rin.flags, rout.flags) self.assertEqual(rin.service, rout.service) self.assertEqual(rin.regexp, rout.regexp) self.assertEqual(rin.replacement.name, rout.replacement.name) self.assertEqual(rin.ttl, rout.ttl) def test_AFSDB(self): """ The byte stream written by L{dns.Record_AFSDB.encode} can be used by L{dns.Record_AFSDB.decode} to reconstruct the state of the original L{dns.Record_AFSDB} instance. """ self._recordRoundtripTest(dns.Record_AFSDB(subtype=3, hostname=b"example.com")) def test_RP(self): """ The byte stream written by L{dns.Record_RP.encode} can be used by L{dns.Record_RP.decode} to reconstruct the state of the original L{dns.Record_RP} instance. """ self._recordRoundtripTest( dns.Record_RP(mbox=b"alice.example.com", txt=b"example.com") ) def test_HINFO(self): """ The byte stream written by L{dns.Record_HINFO.encode} can be used by L{dns.Record_HINFO.decode} to reconstruct the state of the original L{dns.Record_HINFO} instance. """ self._recordRoundtripTest(dns.Record_HINFO(cpu=b"fast", os=b"great")) def test_MINFO(self): """ The byte stream written by L{dns.Record_MINFO.encode} can be used by L{dns.Record_MINFO.decode} to reconstruct the state of the original L{dns.Record_MINFO} instance. """ self._recordRoundtripTest(dns.Record_MINFO(rmailbx=b"foo", emailbx=b"bar")) def test_MX(self): """ The byte stream written by L{dns.Record_MX.encode} can be used by L{dns.Record_MX.decode} to reconstruct the state of the original L{dns.Record_MX} instance. """ self._recordRoundtripTest(dns.Record_MX(preference=1, name=b"example.com")) def test_TSIG(self): """ The byte stream written by L{dns.Record_TSIG.encode} can be used by L{dns.Record_TSIG.decode} to reconstruct the state of the original L{dns.Record_TSIG} instance. """ mac = b"\x00\x01\x02\x03\x10\x11\x12\x13" b"\x20\x21\x22\x23\x30\x31\x32\x33" rr = dns.Record_TSIG( algorithm="hmac-md5.sig-alg.reg.int", timeSigned=1515548975, originalID=42, fudge=5, MAC=mac, ) self._recordRoundtripTest(rr) rdata = ( b"\x08hmac-md5\x07sig-alg\x03reg\x03int\x00" b"\x00\x00\x5a\x55\x71\x2f\x00\x05\x00\x10" + mac + b"\x00\x2A\x00\x00\x00\x00" ) self.assertEncodedFormat(rdata, rr) rr = dns.Record_TSIG( algorithm="hmac-sha256", timeSigned=4511798055, # More than 32 bits originalID=65535, error=dns.EBADTIME, otherData=b"\x80\x00\x00\x00\x00\x08", MAC=mac, ) self._recordRoundtripTest(rr) rdata = ( b"\x0Bhmac-sha256\x00" b"\x00\x01\x0c\xec\x93\x27\x00\x05\x00\x10" + mac + b"\xff\xff\x00\x12\x00\x06" b"\x80\x00\x00\x00\x00\x08" ) self.assertEncodedFormat(rdata, rr) def test_TXT(self): """ The byte stream written by L{dns.Record_TXT.encode} can be used by L{dns.Record_TXT.decode} to reconstruct the state of the original L{dns.Record_TXT} instance. """ self._recordRoundtripTest(dns.Record_TXT(b"foo", b"bar")) MESSAGE_AUTHENTIC_DATA_BYTES = ( b"\x00\x00" # ID b"\x00" # b"\x20" # RA, Z, AD=1, CD, RCODE b"\x00\x00" # Query count b"\x00\x00" # Answer count b"\x00\x00" # Authority count b"\x00\x00" # Additional count ) MESSAGE_CHECKING_DISABLED_BYTES = ( b"\x00\x00" # ID b"\x00" # b"\x10" # RA, Z, AD, CD=1, RCODE b"\x00\x00" # Query count b"\x00\x00" # Answer count b"\x00\x00" # Authority count b"\x00\x00" # Additional count ) class MessageTests(unittest.SynchronousTestCase): """ Tests for L{twisted.names.dns.Message}. """ def test_authenticDataDefault(self): """ L{dns.Message.authenticData} has default value 0. """ self.assertEqual(dns.Message().authenticData, 0) def test_authenticDataOverride(self): """ L{dns.Message.__init__} accepts a C{authenticData} argument which is assigned to L{dns.Message.authenticData}. """ self.assertEqual(dns.Message(authenticData=1).authenticData, 1) def test_authenticDataEncode(self): """ L{dns.Message.toStr} encodes L{dns.Message.authenticData} into byte4 of the byte string. """ self.assertEqual( dns.Message(authenticData=1).toStr(), MESSAGE_AUTHENTIC_DATA_BYTES ) def test_authenticDataDecode(self): """ L{dns.Message.fromStr} decodes byte4 and assigns bit3 to L{dns.Message.authenticData}. """ m = dns.Message() m.fromStr(MESSAGE_AUTHENTIC_DATA_BYTES) self.assertEqual(m.authenticData, 1) def test_checkingDisabledDefault(self): """ L{dns.Message.checkingDisabled} has default value 0. """ self.assertEqual(dns.Message().checkingDisabled, 0) def test_checkingDisabledOverride(self): """ L{dns.Message.__init__} accepts a C{checkingDisabled} argument which is assigned to L{dns.Message.checkingDisabled}. """ self.assertEqual(dns.Message(checkingDisabled=1).checkingDisabled, 1) def test_checkingDisabledEncode(self): """ L{dns.Message.toStr} encodes L{dns.Message.checkingDisabled} into byte4 of the byte string. """ self.assertEqual( dns.Message(checkingDisabled=1).toStr(), MESSAGE_CHECKING_DISABLED_BYTES ) def test_checkingDisabledDecode(self): """ L{dns.Message.fromStr} decodes byte4 and assigns bit4 to L{dns.Message.checkingDisabled}. """ m = dns.Message() m.fromStr(MESSAGE_CHECKING_DISABLED_BYTES) self.assertEqual(m.checkingDisabled, 1) def test_reprDefaults(self): """ L{dns.Message.__repr__} omits field values and sections which are identical to their defaults. The id field value is always shown. """ self.assertEqual("", repr(dns.Message())) def test_reprFlagsIfSet(self): """ L{dns.Message.__repr__} displays flags if they are L{True}. """ m = dns.Message( answer=True, auth=True, trunc=True, recDes=True, recAv=True, authenticData=True, checkingDisabled=True, ) self.assertEqual( "", repr(m), ) def test_reprNonDefautFields(self): """ L{dns.Message.__repr__} displays field values if they differ from their defaults. """ m = dns.Message(id=10, opCode=20, rCode=30, maxSize=40) self.assertEqual( "", repr(m), ) def test_reprNonDefaultSections(self): """ L{dns.Message.__repr__} displays sections which differ from their defaults. """ m = dns.Message() m.queries = [1, 2, 3] m.answers = [4, 5, 6] m.authority = [7, 8, 9] m.additional = [10, 11, 12] self.assertEqual( "", repr(m), ) def test_emptyMessage(self): """ Test that a message which has been truncated causes an EOFError to be raised when it is parsed. """ msg = dns.Message() self.assertRaises(EOFError, msg.fromStr, b"") def test_emptyQuery(self): """ Test that bytes representing an empty query message can be decoded as such. """ msg = dns.Message() msg.fromStr( b"\x01\x00" # Message ID b"\x00" # answer bit, opCode nibble, auth bit, trunc bit, recursive bit b"\x00" # recursion bit, empty bit, authenticData bit, # checkingDisabled bit, response code nibble b"\x00\x00" # number of queries b"\x00\x00" # number of answers b"\x00\x00" # number of authorities b"\x00\x00" # number of additionals ) self.assertEqual(msg.id, 256) self.assertFalse(msg.answer, "Message was not supposed to be an answer.") self.assertEqual(msg.opCode, dns.OP_QUERY) self.assertFalse(msg.auth, "Message was not supposed to be authoritative.") self.assertFalse(msg.trunc, "Message was not supposed to be truncated.") self.assertEqual(msg.queries, []) self.assertEqual(msg.answers, []) self.assertEqual(msg.authority, []) self.assertEqual(msg.additional, []) def test_NULL(self): """ A I{NULL} record with an arbitrary payload can be encoded and decoded as part of a L{dns.Message}. """ bytes = b"".join([dns._ord2bytes(i) for i in range(256)]) rec = dns.Record_NULL(bytes) rr = dns.RRHeader(b"testname", dns.NULL, payload=rec) msg1 = dns.Message() msg1.answers.append(rr) s = BytesIO() msg1.encode(s) s.seek(0, 0) msg2 = dns.Message() msg2.decode(s) self.assertIsInstance(msg2.answers[0].payload, dns.Record_NULL) self.assertEqual(msg2.answers[0].payload.payload, bytes) def test_lookupRecordTypeDefault(self): """ L{Message.lookupRecordType} returns C{dns.UnknownRecord} if it is called with an integer which doesn't correspond to any known record type. """ # 65280 is the first value in the range reserved for private # use, so it shouldn't ever conflict with an officially # allocated value. self.assertIs(dns.Message().lookupRecordType(65280), dns.UnknownRecord) def test_nonAuthoritativeMessage(self): """ The L{RRHeader} instances created by L{Message} from a non-authoritative message are marked as not authoritative. """ buf = BytesIO() answer = dns.RRHeader(payload=dns.Record_A("1.2.3.4", ttl=0)) answer.encode(buf) message = dns.Message() message.fromStr( b"\x01\x00" # Message ID # answer bit, opCode nibble, auth bit, trunc bit, recursive bit b"\x00" # recursion bit, empty bit, authenticData bit, # checkingDisabled bit, response code nibble b"\x00" b"\x00\x00" # number of queries b"\x00\x01" # number of answers b"\x00\x00" # number of authorities b"\x00\x00" + buf.getvalue() # number of additionals ) self.assertEqual(message.answers, [answer]) self.assertFalse(message.answers[0].auth) def test_authoritativeMessage(self): """ The L{RRHeader} instances created by L{Message} from an authoritative message are marked as authoritative. """ buf = BytesIO() answer = dns.RRHeader(payload=dns.Record_A("1.2.3.4", ttl=0)) answer.encode(buf) message = dns.Message() message.fromStr( b"\x01\x00" # Message ID # answer bit, opCode nibble, auth bit, trunc bit, recursive bit b"\x04" # recursion bit, empty bit, authenticData bit, # checkingDisabled bit, response code nibble b"\x00" b"\x00\x00" # number of queries b"\x00\x01" # number of answers b"\x00\x00" # number of authorities b"\x00\x00" + buf.getvalue() # number of additionals ) answer.auth = True self.assertEqual(message.answers, [answer]) self.assertTrue(message.answers[0].auth) class MessageComparisonTests(ComparisonTestsMixin, unittest.SynchronousTestCase): """ Tests for the rich comparison of L{dns.Message} instances. """ def messageFactory(self, *args, **kwargs): """ Create a L{dns.Message}. The L{dns.Message} constructor doesn't accept C{queries}, C{answers}, C{authority}, C{additional} arguments, so we extract them from the kwargs supplied to this factory function and assign them to the message. @param args: Positional arguments. @param kwargs: Keyword arguments. @return: A L{dns.Message} instance. """ queries = kwargs.pop("queries", []) answers = kwargs.pop("answers", []) authority = kwargs.pop("authority", []) additional = kwargs.pop("additional", []) m = dns.Message(**kwargs) if queries: m.queries = queries if answers: m.answers = answers if authority: m.authority = authority if additional: m.additional = additional return m def test_id(self): """ Two L{dns.Message} instances compare equal if they have the same id value. """ self.assertNormalEqualityImplementation( self.messageFactory(id=10), self.messageFactory(id=10), self.messageFactory(id=20), ) def test_answer(self): """ Two L{dns.Message} instances compare equal if they have the same answer flag. """ self.assertNormalEqualityImplementation( self.messageFactory(answer=1), self.messageFactory(answer=1), self.messageFactory(answer=0), ) def test_opCode(self): """ Two L{dns.Message} instances compare equal if they have the same opCode value. """ self.assertNormalEqualityImplementation( self.messageFactory(opCode=10), self.messageFactory(opCode=10), self.messageFactory(opCode=20), ) def test_recDes(self): """ Two L{dns.Message} instances compare equal if they have the same recDes flag. """ self.assertNormalEqualityImplementation( self.messageFactory(recDes=1), self.messageFactory(recDes=1), self.messageFactory(recDes=0), ) def test_recAv(self): """ Two L{dns.Message} instances compare equal if they have the same recAv flag. """ self.assertNormalEqualityImplementation( self.messageFactory(recAv=1), self.messageFactory(recAv=1), self.messageFactory(recAv=0), ) def test_auth(self): """ Two L{dns.Message} instances compare equal if they have the same auth flag. """ self.assertNormalEqualityImplementation( self.messageFactory(auth=1), self.messageFactory(auth=1), self.messageFactory(auth=0), ) def test_rCode(self): """ Two L{dns.Message} instances compare equal if they have the same rCode value. """ self.assertNormalEqualityImplementation( self.messageFactory(rCode=10), self.messageFactory(rCode=10), self.messageFactory(rCode=20), ) def test_trunc(self): """ Two L{dns.Message} instances compare equal if they have the same trunc flag. """ self.assertNormalEqualityImplementation( self.messageFactory(trunc=1), self.messageFactory(trunc=1), self.messageFactory(trunc=0), ) def test_maxSize(self): """ Two L{dns.Message} instances compare equal if they have the same maxSize value. """ self.assertNormalEqualityImplementation( self.messageFactory(maxSize=10), self.messageFactory(maxSize=10), self.messageFactory(maxSize=20), ) def test_authenticData(self): """ Two L{dns.Message} instances compare equal if they have the same authenticData flag. """ self.assertNormalEqualityImplementation( self.messageFactory(authenticData=1), self.messageFactory(authenticData=1), self.messageFactory(authenticData=0), ) def test_checkingDisabled(self): """ Two L{dns.Message} instances compare equal if they have the same checkingDisabled flag. """ self.assertNormalEqualityImplementation( self.messageFactory(checkingDisabled=1), self.messageFactory(checkingDisabled=1), self.messageFactory(checkingDisabled=0), ) def test_queries(self): """ Two L{dns.Message} instances compare equal if they have the same queries. """ self.assertNormalEqualityImplementation( self.messageFactory(queries=[dns.Query(b"example.com")]), self.messageFactory(queries=[dns.Query(b"example.com")]), self.messageFactory(queries=[dns.Query(b"example.org")]), ) def test_answers(self): """ Two L{dns.Message} instances compare equal if they have the same answers. """ self.assertNormalEqualityImplementation( self.messageFactory( answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ), self.messageFactory( answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ), self.messageFactory( answers=[dns.RRHeader(b"example.org", payload=dns.Record_A("4.3.2.1"))] ), ) def test_authority(self): """ Two L{dns.Message} instances compare equal if they have the same authority records. """ self.assertNormalEqualityImplementation( self.messageFactory( authority=[ dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA()) ] ), self.messageFactory( authority=[ dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA()) ] ), self.messageFactory( authority=[ dns.RRHeader(b"example.org", type=dns.SOA, payload=dns.Record_SOA()) ] ), ) def test_additional(self): """ Two L{dns.Message} instances compare equal if they have the same additional records. """ self.assertNormalEqualityImplementation( self.messageFactory( additional=[ dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")) ] ), self.messageFactory( additional=[ dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")) ] ), self.messageFactory( additional=[ dns.RRHeader(b"example.org", payload=dns.Record_A("1.2.3.4")) ] ), ) class TestController: """ Pretend to be a DNS query processor for a DNSDatagramProtocol. @ivar messages: the list of received messages. @type messages: C{list} of (msg, protocol, address) """ def __init__(self): """ Initialize the controller: create a list of messages. """ self.messages = [] def messageReceived(self, msg, proto, addr=None): """ Save the message so that it can be checked during the tests. """ self.messages.append((msg, proto, addr)) class DatagramProtocolTests(unittest.TestCase): """ Test various aspects of L{dns.DNSDatagramProtocol}. """ def setUp(self): """ Create a L{dns.DNSDatagramProtocol} with a deterministic clock. """ self.clock = task.Clock() self.controller = TestController() self.proto = dns.DNSDatagramProtocol(self.controller) transport = proto_helpers.FakeDatagramTransport() self.proto.makeConnection(transport) self.proto.callLater = self.clock.callLater def test_truncatedPacket(self): """ Test that when a short datagram is received, datagramReceived does not raise an exception while processing it. """ self.proto.datagramReceived(b"", address.IPv4Address("UDP", "127.0.0.1", 12345)) self.assertEqual(self.controller.messages, []) def test_simpleQuery(self): """ Test content received after a query. """ d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")]) self.assertEqual(len(self.proto.liveMessages.keys()), 1) m = dns.Message() m.id = next(iter(self.proto.liveMessages.keys())) m.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))] def cb(result): self.assertEqual(result.answers[0].payload.dottedQuad(), "1.2.3.4") d.addCallback(cb) self.proto.datagramReceived(m.toStr(), ("127.0.0.1", 21345)) return d def test_queryTimeout(self): """ Test that query timeouts after some seconds. """ d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")]) self.assertEqual(len(self.proto.liveMessages), 1) self.clock.advance(10) self.assertFailure(d, dns.DNSQueryTimeoutError) self.assertEqual(len(self.proto.liveMessages), 0) return d def test_writeError(self): """ Exceptions raised by the transport's write method should be turned into C{Failure}s passed to errbacks of the C{Deferred} returned by L{DNSDatagramProtocol.query}. """ def writeError(message, addr): raise RuntimeError("bar") self.proto.transport.write = writeError d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")]) return self.assertFailure(d, RuntimeError) def test_listenError(self): """ Exception L{CannotListenError} raised by C{listenUDP} should be turned into a C{Failure} passed to errback of the C{Deferred} returned by L{DNSDatagramProtocol.query}. """ def startListeningError(): raise CannotListenError(None, None, None) self.proto.startListening = startListeningError # Clean up transport so that the protocol calls startListening again self.proto.transport = None d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")]) return self.assertFailure(d, CannotListenError) def test_receiveMessageNotInLiveMessages(self): """ When receiving a message whose id is not in L{DNSDatagramProtocol.liveMessages} or L{DNSDatagramProtocol.resends}, the message will be received by L{DNSDatagramProtocol.controller}. """ message = dns.Message() message.id = 1 message.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))] self.proto.datagramReceived(message.toStr(), ("127.0.0.1", 21345)) self.assertEqual(self.controller.messages[-1][0].toStr(), message.toStr()) class TestTCPController(TestController): """ Pretend to be a DNS query processor for a DNSProtocol. @ivar connections: A list of L{DNSProtocol} instances which have notified this controller that they are connected and have not yet notified it that their connection has been lost. """ def __init__(self): TestController.__init__(self) self.connections = [] def connectionMade(self, proto): self.connections.append(proto) def connectionLost(self, proto): self.connections.remove(proto) class DNSProtocolTests(unittest.TestCase): """ Test various aspects of L{dns.DNSProtocol}. """ def setUp(self): """ Create a L{dns.DNSProtocol} with a deterministic clock. """ self.clock = task.Clock() self.controller = TestTCPController() self.proto = dns.DNSProtocol(self.controller) self.proto.makeConnection(proto_helpers.StringTransport()) self.proto.callLater = self.clock.callLater def test_connectionTracking(self): """ L{dns.DNSProtocol} calls its controller's C{connectionMade} method with itself when it is connected to a transport and its controller's C{connectionLost} method when it is disconnected. """ self.assertEqual(self.controller.connections, [self.proto]) self.proto.connectionLost(Failure(ConnectionDone("Fake Connection Done"))) self.assertEqual(self.controller.connections, []) def test_queryTimeout(self): """ Test that query timeouts after some seconds. """ d = self.proto.query([dns.Query(b"foo")]) self.assertEqual(len(self.proto.liveMessages), 1) self.clock.advance(60) self.assertFailure(d, dns.DNSQueryTimeoutError) self.assertEqual(len(self.proto.liveMessages), 0) return d def test_simpleQuery(self): """ Test content received after a query. """ d = self.proto.query([dns.Query(b"foo")]) self.assertEqual(len(self.proto.liveMessages.keys()), 1) m = dns.Message() m.id = next(iter(self.proto.liveMessages.keys())) m.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))] def cb(result): self.assertEqual(result.answers[0].payload.dottedQuad(), "1.2.3.4") d.addCallback(cb) s = m.toStr() s = struct.pack("!H", len(s)) + s self.proto.dataReceived(s) return d def test_writeError(self): """ Exceptions raised by the transport's write method should be turned into C{Failure}s passed to errbacks of the C{Deferred} returned by L{DNSProtocol.query}. """ def writeError(message): raise RuntimeError("bar") self.proto.transport.write = writeError d = self.proto.query([dns.Query(b"foo")]) return self.assertFailure(d, RuntimeError) def test_receiveMessageNotInLiveMessages(self): """ When receiving a message whose id is not in L{DNSProtocol.liveMessages} the message will be received by L{DNSProtocol.controller}. """ message = dns.Message() message.id = 1 message.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))] string = message.toStr() string = struct.pack("!H", len(string)) + string self.proto.dataReceived(string) self.assertEqual(self.controller.messages[-1][0].toStr(), message.toStr()) class ReprTests(unittest.TestCase): """ Tests for the C{__repr__} implementation of record classes. """ def test_ns(self): """ The repr of a L{dns.Record_NS} instance includes the name of the nameserver and the TTL of the record. """ self.assertEqual( repr(dns.Record_NS(b"example.com", 4321)), "" ) def test_md(self): """ The repr of a L{dns.Record_MD} instance includes the name of the mail destination and the TTL of the record. """ self.assertEqual( repr(dns.Record_MD(b"example.com", 4321)), "" ) def test_mf(self): """ The repr of a L{dns.Record_MF} instance includes the name of the mail forwarder and the TTL of the record. """ self.assertEqual( repr(dns.Record_MF(b"example.com", 4321)), "" ) def test_cname(self): """ The repr of a L{dns.Record_CNAME} instance includes the name of the mail forwarder and the TTL of the record. """ self.assertEqual( repr(dns.Record_CNAME(b"example.com", 4321)), "", ) def test_mb(self): """ The repr of a L{dns.Record_MB} instance includes the name of the mailbox and the TTL of the record. """ self.assertEqual( repr(dns.Record_MB(b"example.com", 4321)), "" ) def test_mg(self): """ The repr of a L{dns.Record_MG} instance includes the name of the mail group member and the TTL of the record. """ self.assertEqual( repr(dns.Record_MG(b"example.com", 4321)), "" ) def test_mr(self): """ The repr of a L{dns.Record_MR} instance includes the name of the mail rename domain and the TTL of the record. """ self.assertEqual( repr(dns.Record_MR(b"example.com", 4321)), "" ) def test_ptr(self): """ The repr of a L{dns.Record_PTR} instance includes the name of the pointer and the TTL of the record. """ self.assertEqual( repr(dns.Record_PTR(b"example.com", 4321)), "", ) def test_dname(self): """ The repr of a L{dns.Record_DNAME} instance includes the name of the non-terminal DNS name redirection and the TTL of the record. """ self.assertEqual( repr(dns.Record_DNAME(b"example.com", 4321)), "", ) def test_a(self): """ The repr of a L{dns.Record_A} instance includes the dotted-quad string representation of the address it is for and the TTL of the record. """ self.assertEqual( repr(dns.Record_A("1.2.3.4", 567)), "" ) def test_soa(self): """ The repr of a L{dns.Record_SOA} instance includes all of the authority fields. """ self.assertEqual( repr( dns.Record_SOA( mname=b"mName", rname=b"rName", serial=123, refresh=456, retry=789, expire=10, minimum=11, ttl=12, ) ), "", ) def test_null(self): """ The repr of a L{dns.Record_NULL} instance includes the repr of its payload and the TTL of the record. """ self.assertEqual( repr(dns.Record_NULL(b"abcd", 123)), "" ) def test_wks(self): """ The repr of a L{dns.Record_WKS} instance includes the dotted-quad string representation of the address it is for, the IP protocol number it is for, and the TTL of the record. """ self.assertEqual( repr(dns.Record_WKS("2.3.4.5", 7, ttl=8)), "", ) def test_aaaa(self): """ The repr of a L{dns.Record_AAAA} instance includes the colon-separated hex string representation of the address it is for and the TTL of the record. """ self.assertEqual( repr(dns.Record_AAAA("8765::1234", ttl=10)), "", ) def test_a6(self): """ The repr of a L{dns.Record_A6} instance includes the colon-separated hex string representation of the address it is for and the TTL of the record. """ self.assertEqual( repr(dns.Record_A6(0, "1234::5678", b"foo.bar", ttl=10)), "", ) def test_srv(self): """ The repr of a L{dns.Record_SRV} instance includes the name and port of the target and the priority, weight, and TTL of the record. """ self.assertEqual( repr(dns.Record_SRV(1, 2, 3, b"example.org", 4)), "", ) def test_naptr(self): """ The repr of a L{dns.Record_NAPTR} instance includes the order, preference, flags, service, regular expression, replacement, and TTL of the record. """ record = dns.Record_NAPTR(5, 9, b"S", b"http", b"/foo/bar/i", b"baz", 3) self.assertEqual( repr(record), "", ) def test_afsdb(self): """ The repr of a L{dns.Record_AFSDB} instance includes the subtype, hostname, and TTL of the record. """ self.assertEqual( repr(dns.Record_AFSDB(3, b"example.org", 5)), "", ) def test_rp(self): """ The repr of a L{dns.Record_RP} instance includes the mbox, txt, and TTL fields of the record. """ self.assertEqual( repr(dns.Record_RP(b"alice.example.com", b"admin.example.com", 3)), "", ) def test_hinfo(self): """ The repr of a L{dns.Record_HINFO} instance includes the cpu, os, and TTL fields of the record. """ self.assertEqual( repr(dns.Record_HINFO(b"sparc", b"minix", 12)), "", ) def test_minfo(self): """ The repr of a L{dns.Record_MINFO} instance includes the rmailbx, emailbx, and TTL fields of the record. """ record = dns.Record_MINFO(b"alice.example.com", b"bob.example.com", 15) self.assertEqual( repr(record), "", ) def test_mx(self): """ The repr of a L{dns.Record_MX} instance includes the preference, name, and TTL fields of the record. """ self.assertEqual( repr(dns.Record_MX(13, b"mx.example.com", 2)), "", ) def test_txt(self): """ The repr of a L{dns.Record_TXT} instance includes the data and ttl fields of the record. """ self.assertEqual( repr(dns.Record_TXT(b"foo", b"bar", ttl=15)), "", ) def test_spf(self): """ The repr of a L{dns.Record_SPF} instance includes the data and ttl fields of the record. """ self.assertEqual( repr(dns.Record_SPF(b"foo", b"bar", ttl=15)), "", ) def test_unknown(self): """ The repr of a L{dns.UnknownRecord} instance includes the data and ttl fields of the record. """ self.assertEqual( repr(dns.UnknownRecord(b"foo\x1fbar", 12)), "", ) class EqualityTests(ComparisonTestsMixin, unittest.TestCase): """ Tests for the equality and non-equality behavior of record classes. """ def _equalityTest(self, firstValueOne, secondValueOne, valueTwo): return self.assertNormalEqualityImplementation( firstValueOne, secondValueOne, valueTwo ) def test_charstr(self): """ Two L{dns.Charstr} instances compare equal if and only if they have the same string value. """ self._equalityTest( dns.Charstr(b"abc"), dns.Charstr(b"abc"), dns.Charstr(b"def") ) def test_name(self): """ Two L{dns.Name} instances compare equal if and only if they have the same name value. """ self._equalityTest(dns.Name(b"abc"), dns.Name(b"abc"), dns.Name(b"def")) def _simpleEqualityTest(self, cls): """ Assert that instances of C{cls} with the same attributes compare equal to each other and instances with different attributes compare as not equal. @param cls: A L{dns.SimpleRecord} subclass. """ # Vary the TTL self._equalityTest( cls(b"example.com", 123), cls(b"example.com", 123), cls(b"example.com", 321) ) # Vary the name self._equalityTest( cls(b"example.com", 123), cls(b"example.com", 123), cls(b"example.org", 123) ) def test_rrheader(self): """ Two L{dns.RRHeader} instances compare equal if and only if they have the same name, type, class, time to live, payload, and authoritative bit. """ # Vary the name self._equalityTest( dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.org", payload=dns.Record_A("1.2.3.4")), ) # Vary the payload self._equalityTest( dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.5")), ) # Vary the type. Leave the payload as None so that we don't have to # provide non-equal values. self._equalityTest( dns.RRHeader(b"example.com", dns.A), dns.RRHeader(b"example.com", dns.A), dns.RRHeader(b"example.com", dns.MX), ) # Probably not likely to come up. Most people use the internet. self._equalityTest( dns.RRHeader(b"example.com", cls=dns.IN, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", cls=dns.IN, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", cls=dns.CS, payload=dns.Record_A("1.2.3.4")), ) # Vary the ttl self._equalityTest( dns.RRHeader(b"example.com", ttl=60, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", ttl=60, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", ttl=120, payload=dns.Record_A("1.2.3.4")), ) # Vary the auth bit self._equalityTest( dns.RRHeader(b"example.com", auth=1, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", auth=1, payload=dns.Record_A("1.2.3.4")), dns.RRHeader(b"example.com", auth=0, payload=dns.Record_A("1.2.3.4")), ) def test_ns(self): """ Two L{dns.Record_NS} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_NS) def test_md(self): """ Two L{dns.Record_MD} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_MD) def test_mf(self): """ Two L{dns.Record_MF} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_MF) def test_cname(self): """ Two L{dns.Record_CNAME} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_CNAME) def test_mb(self): """ Two L{dns.Record_MB} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_MB) def test_mg(self): """ Two L{dns.Record_MG} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_MG) def test_mr(self): """ Two L{dns.Record_MR} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_MR) def test_ptr(self): """ Two L{dns.Record_PTR} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_PTR) def test_dname(self): """ Two L{dns.Record_MD} instances compare equal if and only if they have the same name and TTL. """ self._simpleEqualityTest(dns.Record_DNAME) def test_a(self): """ Two L{dns.Record_A} instances compare equal if and only if they have the same address and TTL. """ # Vary the TTL self._equalityTest( dns.Record_A("1.2.3.4", 5), dns.Record_A("1.2.3.4", 5), dns.Record_A("1.2.3.4", 6), ) # Vary the address self._equalityTest( dns.Record_A("1.2.3.4", 5), dns.Record_A("1.2.3.4", 5), dns.Record_A("1.2.3.5", 5), ) def test_soa(self): """ Two L{dns.Record_SOA} instances compare equal if and only if they have the same mname, rname, serial, refresh, minimum, expire, retry, and ttl. """ # Vary the mname self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"xname", b"rname", 123, 456, 789, 10, 20, 30), ) # Vary the rname self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"xname", 123, 456, 789, 10, 20, 30), ) # Vary the serial self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 1, 456, 789, 10, 20, 30), ) # Vary the refresh self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 1, 789, 10, 20, 30), ) # Vary the minimum self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 1, 10, 20, 30), ) # Vary the expire self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 1, 20, 30), ) # Vary the retry self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 1, 30), ) # Vary the ttl self._equalityTest( dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30), dns.Record_SOA(b"mname", b"xname", 123, 456, 789, 10, 20, 1), ) def test_null(self): """ Two L{dns.Record_NULL} instances compare equal if and only if they have the same payload and ttl. """ # Vary the payload self._equalityTest( dns.Record_NULL("foo bar", 10), dns.Record_NULL("foo bar", 10), dns.Record_NULL("bar foo", 10), ) # Vary the ttl self._equalityTest( dns.Record_NULL("foo bar", 10), dns.Record_NULL("foo bar", 10), dns.Record_NULL("foo bar", 100), ) def test_wks(self): """ Two L{dns.Record_WKS} instances compare equal if and only if they have the same address, protocol, map, and ttl. """ # Vary the address self._equalityTest( dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("4.3.2.1", 1, "foo", 2), ) # Vary the protocol self._equalityTest( dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 100, "foo", 2), ) # Vary the map self._equalityTest( dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "bar", 2), ) # Vary the ttl self._equalityTest( dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "foo", 2), dns.Record_WKS("1.2.3.4", 1, "foo", 200), ) def test_aaaa(self): """ Two L{dns.Record_AAAA} instances compare equal if and only if they have the same address and ttl. """ # Vary the address self._equalityTest( dns.Record_AAAA("1::2", 1), dns.Record_AAAA("1::2", 1), dns.Record_AAAA("2::1", 1), ) # Vary the ttl self._equalityTest( dns.Record_AAAA("1::2", 1), dns.Record_AAAA("1::2", 1), dns.Record_AAAA("1::2", 10), ) def test_a6(self): """ Two L{dns.Record_A6} instances compare equal if and only if they have the same prefix, prefix length, suffix, and ttl. """ # Note, A6 is crazy, I'm not sure these values are actually legal. # Hopefully that doesn't matter for this test. -exarkun # Vary the prefix length self._equalityTest( dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(32, "::abcd", b"example.com", 10), ) # Vary the suffix self._equalityTest( dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd:0", b"example.com", 10), ) # Vary the prefix self._equalityTest( dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.org", 10), ) # Vary the ttl self._equalityTest( dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.com", 10), dns.Record_A6(16, "::abcd", b"example.com", 100), ) def test_srv(self): """ Two L{dns.Record_SRV} instances compare equal if and only if they have the same priority, weight, port, target, and ttl. """ # Vary the priority self._equalityTest( dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(100, 20, 30, b"example.com", 40), ) # Vary the weight self._equalityTest( dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 200, 30, b"example.com", 40), ) # Vary the port self._equalityTest( dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 300, b"example.com", 40), ) # Vary the target self._equalityTest( dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.org", 40), ) # Vary the ttl self._equalityTest( dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 40), dns.Record_SRV(10, 20, 30, b"example.com", 400), ) def test_sshfp(self): """ Two L{dns.Record_SSHFP} instances compare equal if and only if they have the same key type, fingerprint type, fingerprint, and ttl. """ # Vary the key type. self._equalityTest( dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(2, 2, b"happyday", 40), ) # Vary the fingerprint type. self._equalityTest( dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 1, b"happyday", 40), ) # Vary the fingerprint itself. self._equalityTest( dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happxday", 40), ) # Vary the ttl. self._equalityTest( dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happyday", 40), dns.Record_SSHFP(1, 2, b"happyday", 45), ) def test_naptr(self): """ Two L{dns.Record_NAPTR} instances compare equal if and only if they have the same order, preference, flags, service, regexp, replacement, and ttl. """ # Vary the order self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(2, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), ) # Vary the preference self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 3, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), ) # Vary the flags self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"p", b"sip+E2U", b"/foo/bar/", b"baz", 12), ) # Vary the service self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"http", b"/foo/bar/", b"baz", 12), ) # Vary the regexp self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/bar/foo/", b"baz", 12), ) # Vary the replacement self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/bar/foo/", b"quux", 12), ) # Vary the ttl self._equalityTest( dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12), dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/bar/foo/", b"baz", 5), ) def test_afsdb(self): """ Two L{dns.Record_AFSDB} instances compare equal if and only if they have the same subtype, hostname, and ttl. """ # Vary the subtype self._equalityTest( dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(2, b"example.com", 2), ) # Vary the hostname self._equalityTest( dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(1, b"example.org", 2), ) # Vary the ttl self._equalityTest( dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(1, b"example.com", 2), dns.Record_AFSDB(1, b"example.com", 3), ) def test_rp(self): """ Two L{Record_RP} instances compare equal if and only if they have the same mbox, txt, and ttl. """ # Vary the mbox self._equalityTest( dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"bob.example.com", b"alice is nice", 10), ) # Vary the txt self._equalityTest( dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"alice.example.com", b"alice is not nice", 10), ) # Vary the ttl self._equalityTest( dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"alice.example.com", b"alice is nice", 10), dns.Record_RP(b"alice.example.com", b"alice is nice", 100), ) def test_hinfo(self): """ Two L{dns.Record_HINFO} instances compare equal if and only if they have the same cpu, os, and ttl. """ # Vary the cpu self._equalityTest( dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("i386", "plan9", 10), ) # Vary the os self._equalityTest( dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("x86-64", "plan11", 10), ) # Vary the ttl self._equalityTest( dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("x86-64", "plan9", 10), dns.Record_HINFO("x86-64", "plan9", 100), ) def test_minfo(self): """ Two L{dns.Record_MINFO} instances compare equal if and only if they have the same rmailbx, emailbx, and ttl. """ # Vary the rmailbx self._equalityTest( dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"someplace", b"emailbox", 10), ) # Vary the emailbx self._equalityTest( dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"rmailbox", b"something", 10), ) # Vary the ttl self._equalityTest( dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"rmailbox", b"emailbox", 10), dns.Record_MINFO(b"rmailbox", b"emailbox", 100), ) def test_mx(self): """ Two L{dns.Record_MX} instances compare equal if and only if they have the same preference, name, and ttl. """ # Vary the preference self._equalityTest( dns.Record_MX(10, b"example.org", 20), dns.Record_MX(10, b"example.org", 20), dns.Record_MX(100, b"example.org", 20), ) # Vary the name self._equalityTest( dns.Record_MX(10, b"example.org", 20), dns.Record_MX(10, b"example.org", 20), dns.Record_MX(10, b"example.net", 20), ) # Vary the ttl self._equalityTest( dns.Record_MX(10, b"example.org", 20), dns.Record_MX(10, b"example.org", 20), dns.Record_MX(10, b"example.org", 200), ) def test_txt(self): """ Two L{dns.Record_TXT} instances compare equal if and only if they have the same data and ttl. """ # Vary the length of the data self._equalityTest( dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("foo", "bar", "baz", ttl=10), ) # Vary the value of the data self._equalityTest( dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("bar", "foo", ttl=10), ) # Vary the ttl self._equalityTest( dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("foo", "bar", ttl=10), dns.Record_TXT("foo", "bar", ttl=100), ) def test_spf(self): """ L{dns.Record_SPF} instances compare equal if and only if they have the same data and ttl. """ # Vary the length of the data self._equalityTest( dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("foo", "bar", "baz", ttl=10), ) # Vary the value of the data self._equalityTest( dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("bar", "foo", ttl=10), ) # Vary the ttl self._equalityTest( dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("foo", "bar", ttl=10), dns.Record_SPF("foo", "bar", ttl=100), ) def test_tsig(self): """ L{dns.Record_TSIG} instances compare equal if and only if they have the same RDATA (algorithm, timestamp, MAC, etc.) and ttl. """ baseargs = { "algorithm": "hmac-sha224", "timeSigned": 1515548975, "fudge": 5, "MAC": b"\x01\x02\x03\x04\x05", "originalID": 99, "error": dns.OK, "otherData": b"", "ttl": 40, } altargs = { "algorithm": "hmac-sha512", "timeSigned": 1515548875, "fudge": 0, "MAC": b"\x05\x04\x03\x02\x01", "originalID": 65437, "error": dns.EBADTIME, "otherData": b"\x00\x00", "ttl": 400, } for kw in baseargs.keys(): altered = baseargs.copy() altered[kw] = altargs[kw] self._equalityTest( dns.Record_TSIG(**altered), dns.Record_TSIG(**altered), dns.Record_TSIG(**baseargs), ) def test_unknown(self): """ L{dns.UnknownRecord} instances compare equal if and only if they have the same data and ttl. """ # Vary the length of the data self._equalityTest( dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("foobar", ttl=10), ) # Vary the value of the data self._equalityTest( dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("bar", ttl=10), ) # Vary the ttl self._equalityTest( dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("foo", ttl=10), dns.UnknownRecord("foo", ttl=100), ) class RRHeaderTests(unittest.TestCase): """ Tests for L{twisted.names.dns.RRHeader}. """ def test_negativeTTL(self): """ Attempting to create a L{dns.RRHeader} instance with a negative TTL causes L{ValueError} to be raised. """ self.assertRaises( ValueError, dns.RRHeader, "example.com", dns.A, dns.IN, -1, dns.Record_A("127.0.0.1"), ) def test_nonIntegralTTL(self): """ L{dns.RRHeader} converts TTLs to integers. """ ttlAsFloat = 123.45 header = dns.RRHeader( "example.com", dns.A, dns.IN, ttlAsFloat, dns.Record_A("127.0.0.1") ) self.assertEqual(header.ttl, int(ttlAsFloat)) def test_nonNumericTTLRaisesTypeError(self): """ Attempting to create a L{dns.RRHeader} instance with a TTL that L{int} cannot convert to an integer raises a L{TypeError}. """ self.assertRaises( ValueError, dns.RRHeader, "example.com", dns.A, dns.IN, "this is not a number", dns.Record_A("127.0.0.1"), ) class NameToLabelsTests(unittest.SynchronousTestCase): """ Tests for L{twisted.names.dns._nameToLabels}. """ def test_empty(self): """ L{dns._nameToLabels} returns a list containing a single empty label for an empty name. """ self.assertEqual(dns._nameToLabels(b""), [b""]) def test_onlyDot(self): """ L{dns._nameToLabels} returns a list containing a single empty label for a name containing only a dot. """ self.assertEqual(dns._nameToLabels(b"."), [b""]) def test_withoutTrailingDot(self): """ L{dns._nameToLabels} returns a list ending with an empty label for a name without a trailing dot. """ self.assertEqual(dns._nameToLabels(b"com"), [b"com", b""]) def test_withTrailingDot(self): """ L{dns._nameToLabels} returns a list ending with an empty label for a name with a trailing dot. """ self.assertEqual(dns._nameToLabels(b"com."), [b"com", b""]) def test_subdomain(self): """ L{dns._nameToLabels} returns a list containing entries for all labels in a subdomain name. """ self.assertEqual( dns._nameToLabels(b"foo.bar.baz.example.com."), [b"foo", b"bar", b"baz", b"example", b"com", b""], ) def test_casePreservation(self): """ L{dns._nameToLabels} preserves the case of ascii characters in labels. """ self.assertEqual(dns._nameToLabels(b"EXAMPLE.COM"), [b"EXAMPLE", b"COM", b""]) def assertIsSubdomainOf(testCase, descendant, ancestor): """ Assert that C{descendant} *is* a subdomain of C{ancestor}. @type testCase: L{unittest.SynchronousTestCase} @param testCase: The test case on which to run the assertions. @type descendant: C{str} @param descendant: The subdomain name to test. @type ancestor: C{str} @param ancestor: The superdomain name to test. """ testCase.assertTrue( dns._isSubdomainOf(descendant, ancestor), f"{descendant!r} is not a subdomain of {ancestor!r}", ) def assertIsNotSubdomainOf(testCase, descendant, ancestor): """ Assert that C{descendant} *is not* a subdomain of C{ancestor}. @type testCase: L{unittest.SynchronousTestCase} @param testCase: The test case on which to run the assertions. @type descendant: C{str} @param descendant: The subdomain name to test. @type ancestor: C{str} @param ancestor: The superdomain name to test. """ testCase.assertFalse( dns._isSubdomainOf(descendant, ancestor), f"{descendant!r} is a subdomain of {ancestor!r}", ) class IsSubdomainOfTests(unittest.SynchronousTestCase): """ Tests for L{twisted.names.dns._isSubdomainOf}. """ def test_identical(self): """ L{dns._isSubdomainOf} returns C{True} for identical domain names. """ assertIsSubdomainOf(self, b"example.com", b"example.com") def test_parent(self): """ L{dns._isSubdomainOf} returns C{True} when the first name is an immediate descendant of the second name. """ assertIsSubdomainOf(self, b"foo.example.com", b"example.com") def test_distantAncestor(self): """ L{dns._isSubdomainOf} returns C{True} when the first name is a distant descendant of the second name. """ assertIsSubdomainOf(self, b"foo.bar.baz.example.com", b"com") def test_superdomain(self): """ L{dns._isSubdomainOf} returns C{False} when the first name is an ancestor of the second name. """ assertIsNotSubdomainOf(self, b"example.com", b"foo.example.com") def test_sibling(self): """ L{dns._isSubdomainOf} returns C{False} if the first name is a sibling of the second name. """ assertIsNotSubdomainOf(self, b"foo.example.com", b"bar.example.com") def test_unrelatedCommonSuffix(self): """ L{dns._isSubdomainOf} returns C{False} even when domain names happen to share a common suffix. """ assertIsNotSubdomainOf(self, b"foo.myexample.com", b"example.com") def test_subdomainWithTrailingDot(self): """ L{dns._isSubdomainOf} returns C{True} if the first name is a subdomain of the second name but the first name has a trailing ".". """ assertIsSubdomainOf(self, b"foo.example.com.", b"example.com") def test_superdomainWithTrailingDot(self): """ L{dns._isSubdomainOf} returns C{True} if the first name is a subdomain of the second name but the second name has a trailing ".". """ assertIsSubdomainOf(self, b"foo.example.com", b"example.com.") def test_bothWithTrailingDot(self): """ L{dns._isSubdomainOf} returns C{True} if the first name is a subdomain of the second name and both names have a trailing ".". """ assertIsSubdomainOf(self, b"foo.example.com.", b"example.com.") def test_emptySubdomain(self): """ L{dns._isSubdomainOf} returns C{False} if the first name is empty and the second name is not. """ assertIsNotSubdomainOf(self, b"", b"example.com") def test_emptySuperdomain(self): """ L{dns._isSubdomainOf} returns C{True} if the second name is empty and the first name is not. """ assertIsSubdomainOf(self, b"foo.example.com", b"") def test_caseInsensitiveComparison(self): """ L{dns._isSubdomainOf} does case-insensitive comparison of name labels. """ assertIsSubdomainOf(self, b"foo.example.com", b"EXAMPLE.COM") assertIsSubdomainOf(self, b"FOO.EXAMPLE.COM", b"example.com") class OPTNonStandardAttributes: """ Generate byte and instance representations of an L{dns._OPTHeader} where all attributes are set to non-default values. For testing whether attributes have really been read from the byte string during decoding. """ @classmethod def bytes(cls, excludeName=False, excludeOptions=False): """ Return L{bytes} representing an encoded OPT record. @param excludeName: A flag that controls whether to exclude the name field. This allows a non-standard name to be prepended during the test. @type excludeName: L{bool} @param excludeOptions: A flag that controls whether to exclude the RDLEN field. This allows encoded variable options to be appended during the test. @type excludeOptions: L{bool} @return: L{bytes} representing the encoded OPT record returned by L{object}. """ rdlen = b"\x00\x00" # RDLEN 0 if excludeOptions: rdlen = b"" return ( b"\x00" # 0 root zone b"\x00\x29" # type 41 b"\x02\x00" # udpPayloadsize 512 b"\x03" # extendedRCODE 3 b"\x04" # version 4 b"\x80\x00" # DNSSEC OK 1 + Z ) + rdlen @classmethod def object(cls): """ Return a new L{dns._OPTHeader} instance. @return: A L{dns._OPTHeader} instance with attributes that match the encoded record returned by L{bytes}. """ return dns._OPTHeader( udpPayloadSize=512, extendedRCODE=3, version=4, dnssecOK=True ) class OPTHeaderTests(ComparisonTestsMixin, unittest.TestCase): """ Tests for L{twisted.names.dns._OPTHeader}. """ def test_interface(self): """ L{dns._OPTHeader} implements L{dns.IEncodable}. """ verifyClass(dns.IEncodable, dns._OPTHeader) def test_name(self): """ L{dns._OPTHeader.name} is an instance attribute whose value is fixed as the root domain """ self.assertEqual(dns._OPTHeader().name, dns.Name(b"")) def test_nameReadonly(self): """ L{dns._OPTHeader.name} is readonly. """ h = dns._OPTHeader() self.assertRaises(AttributeError, setattr, h, "name", dns.Name(b"example.com")) def test_type(self): """ L{dns._OPTHeader.type} is an instance attribute with fixed value 41. """ self.assertEqual(dns._OPTHeader().type, 41) def test_typeReadonly(self): """ L{dns._OPTHeader.type} is readonly. """ h = dns._OPTHeader() self.assertRaises(AttributeError, setattr, h, "type", dns.A) def test_udpPayloadSize(self): """ L{dns._OPTHeader.udpPayloadSize} defaults to 4096 as recommended in rfc6891 section-6.2.5. """ self.assertEqual(dns._OPTHeader().udpPayloadSize, 4096) def test_udpPayloadSizeOverride(self): """ L{dns._OPTHeader.udpPayloadSize} can be overridden in the constructor. """ self.assertEqual(dns._OPTHeader(udpPayloadSize=512).udpPayloadSize, 512) def test_extendedRCODE(self): """ L{dns._OPTHeader.extendedRCODE} defaults to 0. """ self.assertEqual(dns._OPTHeader().extendedRCODE, 0) def test_extendedRCODEOverride(self): """ L{dns._OPTHeader.extendedRCODE} can be overridden in the constructor. """ self.assertEqual(dns._OPTHeader(extendedRCODE=1).extendedRCODE, 1) def test_version(self): """ L{dns._OPTHeader.version} defaults to 0. """ self.assertEqual(dns._OPTHeader().version, 0) def test_versionOverride(self): """ L{dns._OPTHeader.version} can be overridden in the constructor. """ self.assertEqual(dns._OPTHeader(version=1).version, 1) def test_dnssecOK(self): """ L{dns._OPTHeader.dnssecOK} defaults to False. """ self.assertFalse(dns._OPTHeader().dnssecOK) def test_dnssecOKOverride(self): """ L{dns._OPTHeader.dnssecOK} can be overridden in the constructor. """ self.assertTrue(dns._OPTHeader(dnssecOK=True).dnssecOK) def test_options(self): """ L{dns._OPTHeader.options} defaults to empty list. """ self.assertEqual(dns._OPTHeader().options, []) def test_optionsOverride(self): """ L{dns._OPTHeader.options} can be overridden in the constructor. """ h = dns._OPTHeader(options=[(1, 1, b"\x00")]) self.assertEqual(h.options, [(1, 1, b"\x00")]) def test_encode(self): """ L{dns._OPTHeader.encode} packs the header fields and writes them to a file like object passed in as an argument. """ b = BytesIO() OPTNonStandardAttributes.object().encode(b) self.assertEqual(b.getvalue(), OPTNonStandardAttributes.bytes()) def test_encodeWithOptions(self): """ L{dns._OPTHeader.options} is a list of L{dns._OPTVariableOption} instances which are packed into the rdata area of the header. """ h = OPTNonStandardAttributes.object() h.options = [ dns._OPTVariableOption(1, b"foobarbaz"), dns._OPTVariableOption(2, b"qux"), ] b = BytesIO() h.encode(b) self.assertEqual( b.getvalue(), OPTNonStandardAttributes.bytes(excludeOptions=True) + ( b"\x00\x14" # RDLEN 20 b"\x00\x01" # OPTION-CODE b"\x00\x09" # OPTION-LENGTH b"foobarbaz" # OPTION-DATA b"\x00\x02" # OPTION-CODE b"\x00\x03" # OPTION-LENGTH b"qux" # OPTION-DATA ), ) def test_decode(self): """ L{dns._OPTHeader.decode} unpacks the header fields from a file like object and populates the attributes of an existing L{dns._OPTHeader} instance. """ decodedHeader = dns._OPTHeader() decodedHeader.decode(BytesIO(OPTNonStandardAttributes.bytes())) self.assertEqual(decodedHeader, OPTNonStandardAttributes.object()) def test_decodeAllExpectedBytes(self): """ L{dns._OPTHeader.decode} reads all the bytes of the record that is being decoded. """ # Check that all the input data has been consumed. b = BytesIO(OPTNonStandardAttributes.bytes()) decodedHeader = dns._OPTHeader() decodedHeader.decode(b) self.assertEqual(b.tell(), len(b.getvalue())) def test_decodeOnlyExpectedBytes(self): """ L{dns._OPTHeader.decode} reads only the bytes from the current file position to the end of the record that is being decoded. Trailing bytes are not consumed. """ b = BytesIO(OPTNonStandardAttributes.bytes() + b"xxxx") # Trailing bytes decodedHeader = dns._OPTHeader() decodedHeader.decode(b) self.assertEqual(b.tell(), len(b.getvalue()) - len(b"xxxx")) def test_decodeDiscardsName(self): """ L{dns._OPTHeader.decode} discards the name which is encoded in the supplied bytes. The name attribute of the resulting L{dns._OPTHeader} instance will always be L{dns.Name(b'')}. """ b = BytesIO( OPTNonStandardAttributes.bytes(excludeName=True) + b"\x07example\x03com\x00" ) h = dns._OPTHeader() h.decode(b) self.assertEqual(h.name, dns.Name(b"")) def test_decodeRdlengthTooShort(self): """ L{dns._OPTHeader.decode} raises an exception if the supplied RDLEN is too short. """ b = BytesIO( OPTNonStandardAttributes.bytes(excludeOptions=True) + ( b"\x00\x05" # RDLEN 5 Too short - should be 6 b"\x00\x01" # OPTION-CODE b"\x00\x02" # OPTION-LENGTH b"\x00\x00" # OPTION-DATA ) ) h = dns._OPTHeader() self.assertRaises(EOFError, h.decode, b) def test_decodeRdlengthTooLong(self): """ L{dns._OPTHeader.decode} raises an exception if the supplied RDLEN is too long. """ b = BytesIO( OPTNonStandardAttributes.bytes(excludeOptions=True) + ( b"\x00\x07" # RDLEN 7 Too long - should be 6 b"\x00\x01" # OPTION-CODE b"\x00\x02" # OPTION-LENGTH b"\x00\x00" # OPTION-DATA ) ) h = dns._OPTHeader() self.assertRaises(EOFError, h.decode, b) def test_decodeWithOptions(self): """ If the OPT bytes contain variable options, L{dns._OPTHeader.decode} will populate a list L{dns._OPTHeader.options} with L{dns._OPTVariableOption} instances. """ b = BytesIO( OPTNonStandardAttributes.bytes(excludeOptions=True) + ( b"\x00\x14" # RDLEN 20 b"\x00\x01" # OPTION-CODE b"\x00\x09" # OPTION-LENGTH b"foobarbaz" # OPTION-DATA b"\x00\x02" # OPTION-CODE b"\x00\x03" # OPTION-LENGTH b"qux" # OPTION-DATA ) ) h = dns._OPTHeader() h.decode(b) self.assertEqual( h.options, [ dns._OPTVariableOption(1, b"foobarbaz"), dns._OPTVariableOption(2, b"qux"), ], ) def test_fromRRHeader(self): """ L{_OPTHeader.fromRRHeader} accepts an L{RRHeader} instance and returns an L{_OPTHeader} instance whose attribute values have been derived from the C{cls}, C{ttl} and C{payload} attributes of the original header. """ genericHeader = dns.RRHeader( b"example.com", type=dns.OPT, cls=0xFFFF, ttl=(0xFE << 24 | 0xFD << 16 | True << 15), payload=dns.UnknownRecord(b"\xff\xff\x00\x03abc"), ) decodedOptHeader = dns._OPTHeader.fromRRHeader(genericHeader) expectedOptHeader = dns._OPTHeader( udpPayloadSize=0xFFFF, extendedRCODE=0xFE, version=0xFD, dnssecOK=True, options=[dns._OPTVariableOption(code=0xFFFF, data=b"abc")], ) self.assertEqual(decodedOptHeader, expectedOptHeader) def test_repr(self): """ L{dns._OPTHeader.__repr__} displays the name and type and all the fixed and extended header values of the OPT record. """ self.assertEqual( repr(dns._OPTHeader()), "<_OPTHeader " "name= " "type=41 " "udpPayloadSize=4096 " "extendedRCODE=0 " "version=0 " "dnssecOK=False " "options=[]>", ) def test_equalityUdpPayloadSize(self): """ Two L{OPTHeader} instances compare equal if they have the same udpPayloadSize. """ self.assertNormalEqualityImplementation( dns._OPTHeader(udpPayloadSize=512), dns._OPTHeader(udpPayloadSize=512), dns._OPTHeader(udpPayloadSize=4096), ) def test_equalityExtendedRCODE(self): """ Two L{OPTHeader} instances compare equal if they have the same extendedRCODE. """ self.assertNormalEqualityImplementation( dns._OPTHeader(extendedRCODE=1), dns._OPTHeader(extendedRCODE=1), dns._OPTHeader(extendedRCODE=2), ) def test_equalityVersion(self): """ Two L{OPTHeader} instances compare equal if they have the same version. """ self.assertNormalEqualityImplementation( dns._OPTHeader(version=1), dns._OPTHeader(version=1), dns._OPTHeader(version=2), ) def test_equalityDnssecOK(self): """ Two L{OPTHeader} instances compare equal if they have the same dnssecOK flags. """ self.assertNormalEqualityImplementation( dns._OPTHeader(dnssecOK=True), dns._OPTHeader(dnssecOK=True), dns._OPTHeader(dnssecOK=False), ) def test_equalityOptions(self): """ Two L{OPTHeader} instances compare equal if they have the same options. """ self.assertNormalEqualityImplementation( dns._OPTHeader(options=[dns._OPTVariableOption(1, b"x")]), dns._OPTHeader(options=[dns._OPTVariableOption(1, b"x")]), dns._OPTHeader(options=[dns._OPTVariableOption(2, b"y")]), ) class OPTVariableOptionTests(ComparisonTestsMixin, unittest.TestCase): """ Tests for L{dns._OPTVariableOption}. """ def test_interface(self): """ L{dns._OPTVariableOption} implements L{dns.IEncodable}. """ verifyClass(dns.IEncodable, dns._OPTVariableOption) def test_constructorArguments(self): """ L{dns._OPTVariableOption.__init__} requires code and data arguments which are saved as public instance attributes. """ h = dns._OPTVariableOption(1, b"x") self.assertEqual(h.code, 1) self.assertEqual(h.data, b"x") def test_repr(self): """ L{dns._OPTVariableOption.__repr__} displays the code and data of the option. """ self.assertEqual( repr(dns._OPTVariableOption(1, b"x")), "<_OPTVariableOption " "code=1 " "data=x" ">", ) def test_equality(self): """ Two OPTVariableOption instances compare equal if they have the same code and data values. """ self.assertNormalEqualityImplementation( dns._OPTVariableOption(1, b"x"), dns._OPTVariableOption(1, b"x"), dns._OPTVariableOption(2, b"x"), ) self.assertNormalEqualityImplementation( dns._OPTVariableOption(1, b"x"), dns._OPTVariableOption(1, b"x"), dns._OPTVariableOption(1, b"y"), ) def test_encode(self): """ L{dns._OPTVariableOption.encode} encodes the code and data instance attributes to a byte string which also includes the data length. """ o = dns._OPTVariableOption(1, b"foobar") b = BytesIO() o.encode(b) self.assertEqual( b.getvalue(), b"\x00\x01" # OPTION-CODE 1 b"\x00\x06" # OPTION-LENGTH 6 b"foobar", # OPTION-DATA ) def test_decode(self): """ L{dns._OPTVariableOption.decode} is a classmethod that decodes a byte string and returns a L{dns._OPTVariableOption} instance. """ b = BytesIO( b"\x00\x01" # OPTION-CODE 1 b"\x00\x06" # OPTION-LENGTH 6 b"foobar" # OPTION-DATA ) o = dns._OPTVariableOption() o.decode(b) self.assertEqual(o.code, 1) self.assertEqual(o.data, b"foobar") class RaisedArgs(Exception): """ An exception which can be raised by fakes to test that the fake is called with expected arguments. """ def __init__(self, args, kwargs): """ Store the positional and keyword arguments as attributes. @param args: The positional args. @param kwargs: The keyword args. """ self.args = args self.kwargs = kwargs class MessageEmpty: """ Generate byte string and constructor arguments for an empty L{dns._EDNSMessage}. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # id: 256 b"\x97" # QR: 1, OPCODE: 2, AA: 0, TC: 0, RD: 1 b"\x8f" # RA: 1, Z, RCODE: 15 b"\x00\x00" # number of queries b"\x00\x00" # number of answers b"\x00\x00" # number of authorities b"\x00\x00" # number of additionals ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, answer=True, opCode=dns.OP_STATUS, auth=True, trunc=True, recDes=True, recAv=True, rCode=15, ednsVersion=None, ) class MessageTruncated: """ An empty response message whose TR bit is set to 1. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # ID: 256 b"\x82" # QR: 1, OPCODE: 0, AA: 0, TC: 1, RD: 0 b"\x00" # RA: 0, Z, RCODE: 0 b"\x00\x00" # Number of queries b"\x00\x00" # Number of answers b"\x00\x00" # Number of authorities b"\x00\x00" # Number of additionals ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, answer=1, opCode=0, auth=0, trunc=1, recDes=0, recAv=0, rCode=0, ednsVersion=None, ) class MessageNonAuthoritative: """ A minimal non-authoritative message. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # ID 256 b"\x00" # QR: 0, OPCODE: 0, AA: 0, TC: 0, RD: 0 b"\x00" # RA: 0, Z, RCODE: 0 b"\x00\x00" # Query count b"\x00\x01" # Answer count b"\x00\x00" # Authorities count b"\x00\x00" # Additionals count # Answer b"\x00" # RR NAME (root) b"\x00\x01" # RR TYPE 1 (A) b"\x00\x01" # RR CLASS 1 (IN) b"\x00\x00\x00\x00" # RR TTL b"\x00\x04" # RDLENGTH 4 b"\x01\x02\x03\x04" # IPv4 1.2.3.4 ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, auth=0, ednsVersion=None, answers=[ dns.RRHeader(b"", payload=dns.Record_A("1.2.3.4", ttl=0), auth=False) ], ) class MessageAuthoritative: """ A minimal authoritative message. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # ID: 256 b"\x04" # QR: 0, OPCODE: 0, AA: 1, TC: 0, RD: 0 b"\x00" # RA: 0, Z, RCODE: 0 b"\x00\x00" # Query count b"\x00\x01" # Answer count b"\x00\x00" # Authorities count b"\x00\x00" # Additionals count # Answer b"\x00" # RR NAME (root) b"\x00\x01" # RR TYPE 1 (A) b"\x00\x01" # RR CLASS 1 (IN) b"\x00\x00\x00\x00" # RR TTL b"\x00\x04" # RDLENGTH 4 b"\x01\x02\x03\x04" # IPv4 1.2.3.4 ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, auth=1, ednsVersion=None, answers=[ dns.RRHeader(b"", payload=dns.Record_A("1.2.3.4", ttl=0), auth=True) ], ) class MessageComplete: """ An example of a fully populated non-edns response message. Contains name compression, answers, authority, and additional records. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # ID: 256 b"\x95" # QR: 1, OPCODE: 2, AA: 1, TC: 0, RD: 1 b"\x8f" # RA: 1, Z, RCODE: 15 b"\x00\x01" # Query count b"\x00\x01" # Answer count b"\x00\x01" # Authorities count b"\x00\x01" # Additionals count # Query begins at Byte 12 b"\x07example\x03com\x00" # QNAME b"\x00\x06" # QTYPE 6 (SOA) b"\x00\x01" # QCLASS 1 (IN) # Answers b"\xc0\x0c" # RR NAME (compression ref b12) b"\x00\x06" # RR TYPE 6 (SOA) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x27" # RDLENGTH 39 b"\x03ns1\xc0\x0c" # Mname (ns1.example.com (compression ref b15) b"\x0ahostmaster\xc0\x0c" # rname (hostmaster.example.com) b"\xff\xff\xff\xfe" # Serial b"\x7f\xff\xff\xfd" # Refresh b"\x7f\xff\xff\xfc" # Retry b"\x7f\xff\xff\xfb" # Expire b"\xff\xff\xff\xfa" # Minimum # Authority b"\xc0\x0c" # RR NAME (example.com compression ref b12) b"\x00\x02" # RR TYPE 2 (NS) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x02" # RDLENGTH b"\xc0\x29" # RDATA (ns1.example.com (compression ref b41) # Additional b"\xc0\x29" # RR NAME (ns1.example.com compression ref b41) b"\x00\x01" # RR TYPE 1 (A) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x04" # RDLENGTH b"\x05\x06\x07\x08" # RDATA 5.6.7.8 ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, answer=1, opCode=dns.OP_STATUS, auth=1, recDes=1, recAv=1, rCode=15, ednsVersion=None, queries=[dns.Query(b"example.com", dns.SOA)], answers=[ dns.RRHeader( b"example.com", type=dns.SOA, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_SOA( ttl=0xFFFFFFFF, mname=b"ns1.example.com", rname=b"hostmaster.example.com", serial=0xFFFFFFFE, refresh=0x7FFFFFFD, retry=0x7FFFFFFC, expire=0x7FFFFFFB, minimum=0xFFFFFFFA, ), ) ], authority=[ dns.RRHeader( b"example.com", type=dns.NS, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_NS("ns1.example.com", ttl=0xFFFFFFFF), ) ], additional=[ dns.RRHeader( b"ns1.example.com", type=dns.A, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_A("5.6.7.8", ttl=0xFFFFFFFF), ) ], ) class MessageEDNSQuery: """ A minimal EDNS query message. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x00\x00" # ID: 0 b"\x00" # QR: 0, OPCODE: 0, AA: 0, TC: 0, RD: 0 b"\x00" # RA: 0, Z, RCODE: 0 b"\x00\x01" # Queries count b"\x00\x00" # Anwers count b"\x00\x00" # Authority count b"\x00\x01" # Additionals count # Queries b"\x03www\x07example\x03com\x00" # QNAME b"\x00\x01" # QTYPE (A) b"\x00\x01" # QCLASS (IN) # Additional OPT record b"\x00" # NAME (.) b"\x00\x29" # TYPE (OPT 41) b"\x10\x00" # UDP Payload Size (4096) b"\x00" # Extended RCODE b"\x03" # EDNS version b"\x00\x00" # DO: False + Z b"\x00\x00" # RDLENGTH ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=0, answer=0, opCode=dns.OP_QUERY, auth=0, recDes=0, recAv=0, rCode=0, ednsVersion=3, dnssecOK=False, queries=[dns.Query(b"www.example.com", dns.A)], additional=[], ) class MessageEDNSComplete: """ An example of a fully populated edns response message. Contains name compression, answers, authority, and additional records. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x01\x00" # ID: 256 b"\x95" # QR: 1, OPCODE: 2, AA: 1, TC: 0, RD: 1 b"\xbf" # RA: 1, AD: 1, RCODE: 15 b"\x00\x01" # Query count b"\x00\x01" # Answer count b"\x00\x01" # Authorities count b"\x00\x02" # Additionals count # Query begins at Byte 12 b"\x07example\x03com\x00" # QNAME b"\x00\x06" # QTYPE 6 (SOA) b"\x00\x01" # QCLASS 1 (IN) # Answers b"\xc0\x0c" # RR NAME (compression ref b12) b"\x00\x06" # RR TYPE 6 (SOA) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x27" # RDLENGTH 39 b"\x03ns1\xc0\x0c" # mname (ns1.example.com (compression ref b15) b"\x0ahostmaster\xc0\x0c" # rname (hostmaster.example.com) b"\xff\xff\xff\xfe" # Serial b"\x7f\xff\xff\xfd" # Refresh b"\x7f\xff\xff\xfc" # Retry b"\x7f\xff\xff\xfb" # Expire b"\xff\xff\xff\xfa" # Minimum # Authority b"\xc0\x0c" # RR NAME (example.com compression ref b12) b"\x00\x02" # RR TYPE 2 (NS) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x02" # RDLENGTH b"\xc0\x29" # RDATA (ns1.example.com (compression ref b41) # Additional b"\xc0\x29" # RR NAME (ns1.example.com compression ref b41) b"\x00\x01" # RR TYPE 1 (A) b"\x00\x01" # RR CLASS 1 (IN) b"\xff\xff\xff\xff" # RR TTL b"\x00\x04" # RDLENGTH b"\x05\x06\x07\x08" # RDATA 5.6.7.8 # Additional OPT record b"\x00" # NAME (.) b"\x00\x29" # TYPE (OPT 41) b"\x04\x00" # UDP Payload Size (1024) b"\x00" # Extended RCODE b"\x03" # EDNS version b"\x80\x00" # DO: True + Z b"\x00\x00" # RDLENGTH ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=256, answer=1, opCode=dns.OP_STATUS, auth=1, trunc=0, recDes=1, recAv=1, rCode=15, ednsVersion=3, dnssecOK=True, authenticData=True, checkingDisabled=True, maxSize=1024, queries=[dns.Query(b"example.com", dns.SOA)], answers=[ dns.RRHeader( b"example.com", type=dns.SOA, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_SOA( ttl=0xFFFFFFFF, mname=b"ns1.example.com", rname=b"hostmaster.example.com", serial=0xFFFFFFFE, refresh=0x7FFFFFFD, retry=0x7FFFFFFC, expire=0x7FFFFFFB, minimum=0xFFFFFFFA, ), ) ], authority=[ dns.RRHeader( b"example.com", type=dns.NS, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_NS("ns1.example.com", ttl=0xFFFFFFFF), ) ], additional=[ dns.RRHeader( b"ns1.example.com", type=dns.A, ttl=0xFFFFFFFF, auth=True, payload=dns.Record_A("5.6.7.8", ttl=0xFFFFFFFF), ) ], ) class MessageEDNSExtendedRCODE: """ An example of an EDNS message with an extended RCODE. """ @classmethod def bytes(cls): """ Bytes which are expected when encoding an instance constructed using C{kwargs} and which are expected to result in an identical instance when decoded. @return: The L{bytes} of a wire encoded message. """ return ( b"\x00\x00" b"\x00" b"\x0c" # RA: 0, Z, RCODE: 12 b"\x00\x00" b"\x00\x00" b"\x00\x00" b"\x00\x01" # 1 additionals # Additional OPT record b"\x00" b"\x00\x29" b"\x10\x00" b"\xab" # Extended RCODE: 171 b"\x00" b"\x00\x00" b"\x00\x00" ) @classmethod def kwargs(cls): """ Keyword constructor arguments which are expected to result in an instance which returns C{bytes} when encoded. @return: A L{dict} of keyword arguments. """ return dict( id=0, answer=False, opCode=dns.OP_QUERY, auth=False, trunc=False, recDes=False, recAv=False, rCode=0xABC, # Combined OPT extended RCODE + Message RCODE ednsVersion=0, dnssecOK=False, maxSize=4096, queries=[], answers=[], authority=[], additional=[], ) class MessageComparable(FancyEqMixin, FancyStrMixin): """ A wrapper around L{dns.Message} which is comparable so that it can be tested using some of the L{dns._EDNSMessage} tests. """ showAttributes = compareAttributes = ( "id", "answer", "opCode", "auth", "trunc", "recDes", "recAv", "rCode", "queries", "answers", "authority", "additional", ) def __init__(self, original): self.original = original def __getattr__(self, key): return getattr(self.original, key) def verifyConstructorArgument( testCase, cls, argName, defaultVal, altVal, attrName=None ): """ Verify that an attribute has the expected default value and that a corresponding argument passed to a constructor is assigned to that attribute. @param testCase: The L{TestCase} whose assert methods will be called. @type testCase: L{unittest.TestCase} @param cls: The constructor under test. @type cls: L{type} @param argName: The name of the constructor argument under test. @type argName: L{str} @param defaultVal: The expected default value of C{attrName} / C{argName} @type defaultVal: L{object} @param altVal: A value which is different from the default. Used to test that supplied constructor arguments are actually assigned to the correct attribute. @type altVal: L{object} @param attrName: The name of the attribute under test if different from C{argName}. Defaults to C{argName} @type attrName: L{str} """ if attrName is None: attrName = argName actual = {} expected = {"defaultVal": defaultVal, "altVal": altVal} o = cls() actual["defaultVal"] = getattr(o, attrName) o = cls(**{argName: altVal}) actual["altVal"] = getattr(o, attrName) testCase.assertEqual(expected, actual) class ConstructorTestsMixin: """ Helper methods for verifying default attribute values and corresponding constructor arguments. """ def _verifyConstructorArgument(self, argName, defaultVal, altVal): """ Wrap L{verifyConstructorArgument} to provide simpler interface for testing Message and _EDNSMessage constructor arguments. @param argName: The name of the constructor argument. @param defaultVal: The expected default value. @param altVal: An alternative value which is expected to be assigned to a correspondingly named attribute. """ verifyConstructorArgument( testCase=self, cls=self.messageFactory, argName=argName, defaultVal=defaultVal, altVal=altVal, ) def _verifyConstructorFlag(self, argName, defaultVal): """ Wrap L{verifyConstructorArgument} to provide simpler interface for testing _EDNSMessage constructor flags. @param argName: The name of the constructor flag argument @param defaultVal: The expected default value of the flag """ assert defaultVal in (True, False) verifyConstructorArgument( testCase=self, cls=self.messageFactory, argName=argName, defaultVal=defaultVal, altVal=not defaultVal, ) class CommonConstructorTestsMixin: """ Tests for constructor arguments and their associated attributes that are common to both L{twisted.names.dns._EDNSMessage} and L{dns.Message}. TestCase classes that use this mixin must provide a C{messageFactory} method which accepts any argment supported by L{dns.Message.__init__}. TestCases must also mixin ConstructorTestsMixin which provides some custom assertions for testing constructor arguments. """ def test_id(self): """ L{dns._EDNSMessage.id} defaults to C{0} and can be overridden in the constructor. """ self._verifyConstructorArgument("id", defaultVal=0, altVal=1) def test_answer(self): """ L{dns._EDNSMessage.answer} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("answer", defaultVal=False) def test_opCode(self): """ L{dns._EDNSMessage.opCode} defaults to L{dns.OP_QUERY} and can be overridden in the constructor. """ self._verifyConstructorArgument( "opCode", defaultVal=dns.OP_QUERY, altVal=dns.OP_STATUS ) def test_auth(self): """ L{dns._EDNSMessage.auth} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("auth", defaultVal=False) def test_trunc(self): """ L{dns._EDNSMessage.trunc} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("trunc", defaultVal=False) def test_recDes(self): """ L{dns._EDNSMessage.recDes} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("recDes", defaultVal=False) def test_recAv(self): """ L{dns._EDNSMessage.recAv} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("recAv", defaultVal=False) def test_rCode(self): """ L{dns._EDNSMessage.rCode} defaults to C{0} and can be overridden in the constructor. """ self._verifyConstructorArgument("rCode", defaultVal=0, altVal=123) def test_maxSize(self): """ L{dns._EDNSMessage.maxSize} defaults to C{512} and can be overridden in the constructor. """ self._verifyConstructorArgument("maxSize", defaultVal=512, altVal=1024) def test_queries(self): """ L{dns._EDNSMessage.queries} defaults to C{[]}. """ self.assertEqual(self.messageFactory().queries, []) def test_answers(self): """ L{dns._EDNSMessage.answers} defaults to C{[]}. """ self.assertEqual(self.messageFactory().answers, []) def test_authority(self): """ L{dns._EDNSMessage.authority} defaults to C{[]}. """ self.assertEqual(self.messageFactory().authority, []) def test_additional(self): """ L{dns._EDNSMessage.additional} defaults to C{[]}. """ self.assertEqual(self.messageFactory().additional, []) class EDNSMessageConstructorTests( ConstructorTestsMixin, CommonConstructorTestsMixin, unittest.SynchronousTestCase ): """ Tests for L{twisted.names.dns._EDNSMessage} constructor arguments that are shared with L{dns.Message}. """ messageFactory = dns._EDNSMessage class MessageConstructorTests( ConstructorTestsMixin, CommonConstructorTestsMixin, unittest.SynchronousTestCase ): """ Tests for L{twisted.names.dns.Message} constructor arguments that are shared with L{dns._EDNSMessage}. """ messageFactory = dns.Message class EDNSMessageSpecificsTests(ConstructorTestsMixin, unittest.SynchronousTestCase): """ Tests for L{dns._EDNSMessage}. These tests are for L{dns._EDNSMessage} APIs which are not shared with L{dns.Message}. """ messageFactory = dns._EDNSMessage def test_ednsVersion(self): """ L{dns._EDNSMessage.ednsVersion} defaults to C{0} and can be overridden in the constructor. """ self._verifyConstructorArgument("ednsVersion", defaultVal=0, altVal=None) def test_dnssecOK(self): """ L{dns._EDNSMessage.dnssecOK} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("dnssecOK", defaultVal=False) def test_authenticData(self): """ L{dns._EDNSMessage.authenticData} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("authenticData", defaultVal=False) def test_checkingDisabled(self): """ L{dns._EDNSMessage.checkingDisabled} defaults to C{False} and can be overridden in the constructor. """ self._verifyConstructorFlag("checkingDisabled", defaultVal=False) def test_queriesOverride(self): """ L{dns._EDNSMessage.queries} can be overridden in the constructor. """ msg = self.messageFactory(queries=[dns.Query(b"example.com")]) self.assertEqual(msg.queries, [dns.Query(b"example.com")]) def test_answersOverride(self): """ L{dns._EDNSMessage.answers} can be overridden in the constructor. """ msg = self.messageFactory( answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ) self.assertEqual( msg.answers, [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ) def test_authorityOverride(self): """ L{dns._EDNSMessage.authority} can be overridden in the constructor. """ msg = self.messageFactory( authority=[ dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA()) ] ) self.assertEqual( msg.authority, [dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA())], ) def test_additionalOverride(self): """ L{dns._EDNSMessage.authority} can be overridden in the constructor. """ msg = self.messageFactory( additional=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ) self.assertEqual( msg.additional, [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))], ) def test_reprDefaults(self): """ L{dns._EDNSMessage.__repr__} omits field values and sections which are identical to their defaults. The id field value is always shown. """ self.assertEqual("<_EDNSMessage id=0>", repr(self.messageFactory())) def test_reprFlagsIfSet(self): """ L{dns._EDNSMessage.__repr__} displays flags if they are L{True}. """ m = self.messageFactory( answer=True, auth=True, trunc=True, recDes=True, recAv=True, authenticData=True, checkingDisabled=True, dnssecOK=True, ) self.assertEqual( "<_EDNSMessage " "id=0 " "flags=answer,auth,trunc,recDes,recAv,authenticData," "checkingDisabled,dnssecOK" ">", repr(m), ) def test_reprNonDefautFields(self): """ L{dns._EDNSMessage.__repr__} displays field values if they differ from their defaults. """ m = self.messageFactory(id=10, opCode=20, rCode=30, maxSize=40, ednsVersion=50) self.assertEqual( "<_EDNSMessage " "id=10 " "opCode=20 " "rCode=30 " "maxSize=40 " "ednsVersion=50" ">", repr(m), ) def test_reprNonDefaultSections(self): """ L{dns.Message.__repr__} displays sections which differ from their defaults. """ m = self.messageFactory() m.queries = [1, 2, 3] m.answers = [4, 5, 6] m.authority = [7, 8, 9] m.additional = [10, 11, 12] self.assertEqual( "<_EDNSMessage " "id=0 " "queries=[1, 2, 3] " "answers=[4, 5, 6] " "authority=[7, 8, 9] " "additional=[10, 11, 12]" ">", repr(m), ) def test_fromStrCallsMessageFactory(self): """ L{dns._EDNSMessage.fromString} calls L{dns._EDNSMessage._messageFactory} to create a new L{dns.Message} instance which is used to decode the supplied bytes. """ class FakeMessageFactory: """ Fake message factory. """ def fromStr(self, *args, **kwargs): """ Fake fromStr method which raises the arguments it was passed. @param args: positional arguments @param kwargs: keyword arguments """ raise RaisedArgs(args, kwargs) m = dns._EDNSMessage() m._messageFactory = FakeMessageFactory dummyBytes = object() e = self.assertRaises(RaisedArgs, m.fromStr, dummyBytes) self.assertEqual(((dummyBytes,), {}), (e.args, e.kwargs)) def test_fromStrCallsFromMessage(self): """ L{dns._EDNSMessage.fromString} calls L{dns._EDNSMessage._fromMessage} with a L{dns.Message} instance """ m = dns._EDNSMessage() class FakeMessageFactory: """ Fake message factory. """ def fromStr(self, bytes): """ A noop fake version of fromStr @param bytes: the bytes to be decoded """ fakeMessage = FakeMessageFactory() m._messageFactory = lambda: fakeMessage def fakeFromMessage(*args, **kwargs): raise RaisedArgs(args, kwargs) m._fromMessage = fakeFromMessage e = self.assertRaises(RaisedArgs, m.fromStr, b"") self.assertEqual(((fakeMessage,), {}), (e.args, e.kwargs)) def test_toStrCallsToMessage(self): """ L{dns._EDNSMessage.toStr} calls L{dns._EDNSMessage._toMessage} """ m = dns._EDNSMessage() def fakeToMessage(*args, **kwargs): raise RaisedArgs(args, kwargs) m._toMessage = fakeToMessage e = self.assertRaises(RaisedArgs, m.toStr) self.assertEqual(((), {}), (e.args, e.kwargs)) def test_toStrCallsToMessageToStr(self): """ L{dns._EDNSMessage.toStr} calls C{toStr} on the message returned by L{dns._EDNSMessage._toMessage}. """ m = dns._EDNSMessage() dummyBytes = object() class FakeMessage: """ Fake Message """ def toStr(self): """ Fake toStr which returns dummyBytes. @return: dummyBytes """ return dummyBytes def fakeToMessage(*args, **kwargs): return FakeMessage() m._toMessage = fakeToMessage self.assertEqual(dummyBytes, m.toStr()) class EDNSMessageEqualityTests(ComparisonTestsMixin, unittest.SynchronousTestCase): """ Tests for equality between L{dns._EDNSMessage} instances. These tests will not work with L{dns.Message} because it does not use L{twisted.python.util.FancyEqMixin}. """ messageFactory = dns._EDNSMessage def test_id(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same id. """ self.assertNormalEqualityImplementation( self.messageFactory(id=1), self.messageFactory(id=1), self.messageFactory(id=2), ) def test_answer(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same answer flag. """ self.assertNormalEqualityImplementation( self.messageFactory(answer=True), self.messageFactory(answer=True), self.messageFactory(answer=False), ) def test_opCode(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same opCode. """ self.assertNormalEqualityImplementation( self.messageFactory(opCode=dns.OP_STATUS), self.messageFactory(opCode=dns.OP_STATUS), self.messageFactory(opCode=dns.OP_INVERSE), ) def test_auth(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same auth flag. """ self.assertNormalEqualityImplementation( self.messageFactory(auth=True), self.messageFactory(auth=True), self.messageFactory(auth=False), ) def test_trunc(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same trunc flag. """ self.assertNormalEqualityImplementation( self.messageFactory(trunc=True), self.messageFactory(trunc=True), self.messageFactory(trunc=False), ) def test_recDes(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same recDes flag. """ self.assertNormalEqualityImplementation( self.messageFactory(recDes=True), self.messageFactory(recDes=True), self.messageFactory(recDes=False), ) def test_recAv(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same recAv flag. """ self.assertNormalEqualityImplementation( self.messageFactory(recAv=True), self.messageFactory(recAv=True), self.messageFactory(recAv=False), ) def test_rCode(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same rCode. """ self.assertNormalEqualityImplementation( self.messageFactory(rCode=16), self.messageFactory(rCode=16), self.messageFactory(rCode=15), ) def test_ednsVersion(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same ednsVersion. """ self.assertNormalEqualityImplementation( self.messageFactory(ednsVersion=1), self.messageFactory(ednsVersion=1), self.messageFactory(ednsVersion=None), ) def test_dnssecOK(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same dnssecOK. """ self.assertNormalEqualityImplementation( self.messageFactory(dnssecOK=True), self.messageFactory(dnssecOK=True), self.messageFactory(dnssecOK=False), ) def test_authenticData(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same authenticData flags. """ self.assertNormalEqualityImplementation( self.messageFactory(authenticData=True), self.messageFactory(authenticData=True), self.messageFactory(authenticData=False), ) def test_checkingDisabled(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same checkingDisabled flags. """ self.assertNormalEqualityImplementation( self.messageFactory(checkingDisabled=True), self.messageFactory(checkingDisabled=True), self.messageFactory(checkingDisabled=False), ) def test_maxSize(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same maxSize. """ self.assertNormalEqualityImplementation( self.messageFactory(maxSize=2048), self.messageFactory(maxSize=2048), self.messageFactory(maxSize=1024), ) def test_queries(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same queries. """ self.assertNormalEqualityImplementation( self.messageFactory(queries=[dns.Query(b"example.com")]), self.messageFactory(queries=[dns.Query(b"example.com")]), self.messageFactory(queries=[dns.Query(b"example.org")]), ) def test_answers(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same answers. """ self.assertNormalEqualityImplementation( self.messageFactory( answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ), self.messageFactory( answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))] ), self.messageFactory( answers=[dns.RRHeader(b"example.org", payload=dns.Record_A("4.3.2.1"))] ), ) def test_authority(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same authority records. """ self.assertNormalEqualityImplementation( self.messageFactory( authority=[ dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA()) ] ), self.messageFactory( authority=[ dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA()) ] ), self.messageFactory( authority=[ dns.RRHeader(b"example.org", type=dns.SOA, payload=dns.Record_SOA()) ] ), ) def test_additional(self): """ Two L{dns._EDNSMessage} instances compare equal if they have the same additional records. """ self.assertNormalEqualityImplementation( self.messageFactory( additional=[ dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")) ] ), self.messageFactory( additional=[ dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4")) ] ), self.messageFactory( additional=[ dns.RRHeader(b"example.org", payload=dns.Record_A("1.2.3.4")) ] ), ) class StandardEncodingTestsMixin: """ Tests for the encoding and decoding of various standard (not EDNS) messages. These tests should work with both L{dns._EDNSMessage} and L{dns.Message}. TestCase classes that use this mixin must provide a C{messageFactory} method which accepts any argment supported by L{dns._EDNSMessage.__init__}. EDNS specific arguments may be discarded if not supported by the message class under construction. """ def test_emptyMessageEncode(self): """ An empty message can be encoded. """ self.assertEqual( self.messageFactory(**MessageEmpty.kwargs()).toStr(), MessageEmpty.bytes() ) def test_emptyMessageDecode(self): """ An empty message byte sequence can be decoded. """ m = self.messageFactory() m.fromStr(MessageEmpty.bytes()) self.assertEqual(m, self.messageFactory(**MessageEmpty.kwargs())) def test_completeQueryEncode(self): """ A fully populated query message can be encoded. """ self.assertEqual( self.messageFactory(**MessageComplete.kwargs()).toStr(), MessageComplete.bytes(), ) def test_completeQueryDecode(self): """ A fully populated message byte string can be decoded. """ m = self.messageFactory() m.fromStr(MessageComplete.bytes()), self.assertEqual(m, self.messageFactory(**MessageComplete.kwargs())) def test_NULL(self): """ A I{NULL} record with an arbitrary payload can be encoded and decoded as part of a message. """ bytes = b"".join([dns._ord2bytes(i) for i in range(256)]) rec = dns.Record_NULL(bytes) rr = dns.RRHeader(b"testname", dns.NULL, payload=rec) msg1 = self.messageFactory() msg1.answers.append(rr) s = msg1.toStr() msg2 = self.messageFactory() msg2.fromStr(s) self.assertIsInstance(msg2.answers[0].payload, dns.Record_NULL) self.assertEqual(msg2.answers[0].payload.payload, bytes) def test_nonAuthoritativeMessageEncode(self): """ If the message C{authoritative} attribute is set to 0, the encoded bytes will have AA bit 0. """ self.assertEqual( self.messageFactory(**MessageNonAuthoritative.kwargs()).toStr(), MessageNonAuthoritative.bytes(), ) def test_nonAuthoritativeMessageDecode(self): """ The L{dns.RRHeader} instances created by a message from a non-authoritative message byte string are marked as not authoritative. """ m = self.messageFactory() m.fromStr(MessageNonAuthoritative.bytes()) self.assertEqual(m, self.messageFactory(**MessageNonAuthoritative.kwargs())) def test_authoritativeMessageEncode(self): """ If the message C{authoritative} attribute is set to 1, the encoded bytes will have AA bit 1. """ self.assertEqual( self.messageFactory(**MessageAuthoritative.kwargs()).toStr(), MessageAuthoritative.bytes(), ) def test_authoritativeMessageDecode(self): """ The message and its L{dns.RRHeader} instances created by C{decode} from an authoritative message byte string, are marked as authoritative. """ m = self.messageFactory() m.fromStr(MessageAuthoritative.bytes()) self.assertEqual(m, self.messageFactory(**MessageAuthoritative.kwargs())) def test_truncatedMessageEncode(self): """ If the message C{trunc} attribute is set to 1 the encoded bytes will have TR bit 1. """ self.assertEqual( self.messageFactory(**MessageTruncated.kwargs()).toStr(), MessageTruncated.bytes(), ) def test_truncatedMessageDecode(self): """ The message instance created by decoding a truncated message is marked as truncated. """ m = self.messageFactory() m.fromStr(MessageTruncated.bytes()) self.assertEqual(m, self.messageFactory(**MessageTruncated.kwargs())) class EDNSMessageStandardEncodingTests( StandardEncodingTestsMixin, unittest.SynchronousTestCase ): """ Tests for the encoding and decoding of various standard (non-EDNS) messages by L{dns._EDNSMessage}. """ messageFactory = dns._EDNSMessage class MessageStandardEncodingTests( StandardEncodingTestsMixin, unittest.SynchronousTestCase ): """ Tests for the encoding and decoding of various standard (non-EDNS) messages by L{dns.Message}. """ @staticmethod def messageFactory(**kwargs): """ This function adapts constructor arguments expected by _EDNSMessage.__init__ to arguments suitable for use with the Message.__init__. Also handles the fact that unlike L{dns._EDNSMessage}, L{dns.Message.__init__} does not accept queries, answers etc as arguments. Also removes any L{dns._EDNSMessage} specific arguments. @param args: The positional arguments which will be passed to L{dns.Message.__init__}. @param kwargs: The keyword arguments which will be stripped of EDNS specific arguments before being passed to L{dns.Message.__init__}. @return: An L{dns.Message} instance. """ queries = kwargs.pop("queries", []) answers = kwargs.pop("answers", []) authority = kwargs.pop("authority", []) additional = kwargs.pop("additional", []) kwargs.pop("ednsVersion", None) m = dns.Message(**kwargs) m.queries = queries m.answers = answers m.authority = authority m.additional = additional return MessageComparable(m) class EDNSMessageEDNSEncodingTests(unittest.SynchronousTestCase): """ Tests for the encoding and decoding of various EDNS messages. These test will not work with L{dns.Message}. """ messageFactory = dns._EDNSMessage def test_ednsMessageDecodeStripsOptRecords(self): """ The L(_EDNSMessage} instance created by L{dns._EDNSMessage.decode} from an EDNS query never includes OPT records in the additional section. """ m = self.messageFactory() m.fromStr(MessageEDNSQuery.bytes()) self.assertEqual(m.additional, []) def test_ednsMessageDecodeMultipleOptRecords(self): """ An L(_EDNSMessage} instance created from a byte string containing multiple I{OPT} records will discard all the C{OPT} records. C{ednsVersion} will be set to L{None}. @see: U{https://tools.ietf.org/html/rfc6891#section-6.1.1} """ m = dns.Message() m.additional = [dns._OPTHeader(version=2), dns._OPTHeader(version=3)] ednsMessage = dns._EDNSMessage() ednsMessage.fromStr(m.toStr()) self.assertIsNone(ednsMessage.ednsVersion) def test_fromMessageCopiesSections(self): """ L{dns._EDNSMessage._fromMessage} returns an L{_EDNSMessage} instance whose queries, answers, authority and additional lists are copies (not references to) the original message lists. """ standardMessage = dns.Message() standardMessage.fromStr(MessageEDNSQuery.bytes()) ednsMessage = dns._EDNSMessage._fromMessage(standardMessage) duplicates = [] for attrName in ("queries", "answers", "authority", "additional"): if getattr(standardMessage, attrName) is getattr(ednsMessage, attrName): duplicates.append(attrName) if duplicates: self.fail( "Message and _EDNSMessage shared references to the following " "section lists after decoding: %s" % (duplicates,) ) def test_toMessageCopiesSections(self): """ L{dns._EDNSMessage.toStr} makes no in place changes to the message instance. """ ednsMessage = dns._EDNSMessage(ednsVersion=1) ednsMessage.toStr() self.assertEqual(ednsMessage.additional, []) def test_optHeaderPosition(self): """ L{dns._EDNSMessage} can decode OPT records, regardless of their position in the additional records section. "The OPT RR MAY be placed anywhere within the additional data section." @see: U{https://tools.ietf.org/html/rfc6891#section-6.1.1} """ # XXX: We need an _OPTHeader.toRRHeader method. See #6779. b = BytesIO() optRecord = dns._OPTHeader(version=1) optRecord.encode(b) optRRHeader = dns.RRHeader() b.seek(0) optRRHeader.decode(b) m = dns.Message() m.additional = [optRRHeader] actualMessages = [] actualMessages.append(dns._EDNSMessage._fromMessage(m).ednsVersion) m.additional.append(dns.RRHeader(type=dns.A)) actualMessages.append(dns._EDNSMessage._fromMessage(m).ednsVersion) m.additional.insert(0, dns.RRHeader(type=dns.A)) actualMessages.append(dns._EDNSMessage._fromMessage(m).ednsVersion) self.assertEqual([1] * 3, actualMessages) def test_ednsDecode(self): """ The L(_EDNSMessage} instance created by L{dns._EDNSMessage.fromStr} derives its edns specific values (C{ednsVersion}, etc) from the supplied OPT record. """ m = self.messageFactory() m.fromStr(MessageEDNSComplete.bytes()) self.assertEqual(m, self.messageFactory(**MessageEDNSComplete.kwargs())) def test_ednsEncode(self): """ The L(_EDNSMessage} instance created by L{dns._EDNSMessage.toStr} encodes its edns specific values (C{ednsVersion}, etc) into an OPT record added to the additional section. """ self.assertEqual( self.messageFactory(**MessageEDNSComplete.kwargs()).toStr(), MessageEDNSComplete.bytes(), ) def test_extendedRcodeEncode(self): """ The L(_EDNSMessage.toStr} encodes the extended I{RCODE} (>=16) by assigning the lower 4bits to the message RCODE field and the upper 4bits to the OPT pseudo record. """ self.assertEqual( self.messageFactory(**MessageEDNSExtendedRCODE.kwargs()).toStr(), MessageEDNSExtendedRCODE.bytes(), ) def test_extendedRcodeDecode(self): """ The L(_EDNSMessage} instance created by L{dns._EDNSMessage.fromStr} derives RCODE from the supplied OPT record. """ m = self.messageFactory() m.fromStr(MessageEDNSExtendedRCODE.bytes()) self.assertEqual(m, self.messageFactory(**MessageEDNSExtendedRCODE.kwargs())) def test_extendedRcodeZero(self): """ Note that EXTENDED-RCODE value 0 indicates that an unextended RCODE is in use (values 0 through 15). https://tools.ietf.org/html/rfc6891#section-6.1.3 """ ednsMessage = self.messageFactory(rCode=15, ednsVersion=0) standardMessage = ednsMessage._toMessage() self.assertEqual( (15, 0), (standardMessage.rCode, standardMessage.additional[0].extendedRCODE), ) class ResponseFromMessageTests(unittest.SynchronousTestCase): """ Tests for L{dns._responseFromMessage}. """ def test_responseFromMessageResponseType(self): """ L{dns.Message._responseFromMessage} is a constructor function which generates a new I{answer} message from an existing L{dns.Message} like instance. """ request = dns.Message() response = dns._responseFromMessage( responseConstructor=dns.Message, message=request ) self.assertIsNot(request, response) def test_responseType(self): """ L{dns._responseFromMessage} returns a new instance of C{cls} """ class SuppliedClass: id = 1 queries = [] expectedClass = dns.Message self.assertIsInstance( dns._responseFromMessage( responseConstructor=expectedClass, message=SuppliedClass() ), expectedClass, ) def test_responseId(self): """ L{dns._responseFromMessage} copies the C{id} attribute of the original message. """ self.assertEqual( 1234, dns._responseFromMessage( responseConstructor=dns.Message, message=dns.Message(id=1234) ).id, ) def test_responseAnswer(self): """ L{dns._responseFromMessage} sets the C{answer} flag to L{True} """ request = dns.Message() response = dns._responseFromMessage( responseConstructor=dns.Message, message=request ) self.assertEqual((False, True), (request.answer, response.answer)) def test_responseQueries(self): """ L{dns._responseFromMessage} copies the C{queries} attribute of the original message. """ request = dns.Message() expectedQueries = [object(), object(), object()] request.queries = expectedQueries[:] self.assertEqual( expectedQueries, dns._responseFromMessage( responseConstructor=dns.Message, message=request ).queries, ) def test_responseKwargs(self): """ L{dns._responseFromMessage} accepts other C{kwargs} which are assigned to the new message before it is returned. """ self.assertEqual( 123, dns._responseFromMessage( responseConstructor=dns.Message, message=dns.Message(), rCode=123 ).rCode, ) class Foo: """ An example class for use in L{dns._compactRepr} tests. It follows the pattern of initialiser settable flags, fields and sections found in L{dns.Message} and L{dns._EDNSMessage}. """ def __init__( self, field1=1, field2=2, alwaysShowField="AS", flagTrue=True, flagFalse=False, section1=None, ): """ Set some flags, fields and sections as public attributes. """ self.field1 = field1 self.field2 = field2 self.alwaysShowField = alwaysShowField self.flagTrue = flagTrue self.flagFalse = flagFalse if section1 is None: section1 = [] self.section1 = section1 def __repr__(self) -> str: """ Call L{dns._compactRepr} to generate a string representation. """ return cast( str, dns._compactRepr( self, alwaysShow="alwaysShowField".split(), fieldNames="field1 field2 alwaysShowField".split(), flagNames="flagTrue flagFalse".split(), sectionNames="section1 section2".split(), ), ) class CompactReprTests(unittest.SynchronousTestCase): """ Tests for L{dns._compactRepr}. """ messageFactory = Foo def test_defaults(self): """ L{dns._compactRepr} omits field values and sections which have the default value. Flags which are True are always shown. """ self.assertEqual( "", repr(self.messageFactory()) ) def test_flagsIfSet(self): """ L{dns._compactRepr} displays flags if they have a non-default value. """ m = self.messageFactory(flagTrue=True, flagFalse=True) self.assertEqual( "", repr(m), ) def test_nonDefautFields(self): """ L{dns._compactRepr} displays field values if they differ from their defaults. """ m = self.messageFactory(field1=10, field2=20) self.assertEqual( "", repr(m), ) def test_nonDefaultSections(self): """ L{dns._compactRepr} displays sections which differ from their defaults. """ m = self.messageFactory() m.section1 = [1, 1, 1] m.section2 = [2, 2, 2] self.assertEqual( "", repr(m), )