query.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
  2. #
  3. # Permission to use, copy, modify, and distribute this software and its
  4. # documentation for any purpose with or without fee is hereby granted,
  5. # provided that the above copyright notice and this permission notice
  6. # appear in all copies.
  7. #
  8. # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
  9. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
  11. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
  14. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. """Talk to a DNS server."""
from __future__ import generators

import errno
import select
import socket
import struct
import sys
import time

import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsig
from ._compat import long, string_types
  30. if sys.version_info > (3,):
  31. select_error = OSError
  32. else:
  33. select_error = select.error
  34. # Function used to create a socket. Can be overridden if needed in special
  35. # situations.
  36. socket_factory = socket.socket
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port.

    udp() raises this when a datagram arrives from a source that does not
    match the query's destination and ignore_unexpected is False.
    """


class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked.

    udp() and tcp() raise this when the parsed reply is not a response to
    the query that was sent (q.is_response(r) is false).
    """
  41. def _compute_expiration(timeout):
  42. if timeout is None:
  43. return None
  44. else:
  45. return time.time() + timeout
  46. def _poll_for(fd, readable, writable, error, timeout):
  47. """Poll polling backend.
  48. @param fd: File descriptor
  49. @type fd: int
  50. @param readable: Whether to wait for readability
  51. @type readable: bool
  52. @param writable: Whether to wait for writability
  53. @type writable: bool
  54. @param timeout: Deadline timeout (expiration time, in seconds)
  55. @type timeout: float
  56. @return True on success, False on timeout
  57. """
  58. event_mask = 0
  59. if readable:
  60. event_mask |= select.POLLIN
  61. if writable:
  62. event_mask |= select.POLLOUT
  63. if error:
  64. event_mask |= select.POLLERR
  65. pollable = select.poll()
  66. pollable.register(fd, event_mask)
  67. if timeout:
  68. event_list = pollable.poll(long(timeout * 1000))
  69. else:
  70. event_list = pollable.poll()
  71. return bool(event_list)
  72. def _select_for(fd, readable, writable, error, timeout):
  73. """Select polling backend.
  74. @param fd: File descriptor
  75. @type fd: int
  76. @param readable: Whether to wait for readability
  77. @type readable: bool
  78. @param writable: Whether to wait for writability
  79. @type writable: bool
  80. @param timeout: Deadline timeout (expiration time, in seconds)
  81. @type timeout: float
  82. @return True on success, False on timeout
  83. """
  84. rset, wset, xset = [], [], []
  85. if readable:
  86. rset = [fd]
  87. if writable:
  88. wset = [fd]
  89. if error:
  90. xset = [fd]
  91. if timeout is None:
  92. (rcount, wcount, xcount) = select.select(rset, wset, xset)
  93. else:
  94. (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
  95. return bool((rcount or wcount or xcount))
  96. def _wait_for(fd, readable, writable, error, expiration):
  97. done = False
  98. while not done:
  99. if expiration is None:
  100. timeout = None
  101. else:
  102. timeout = expiration - time.time()
  103. if timeout <= 0.0:
  104. raise dns.exception.Timeout
  105. try:
  106. if not _polling_backend(fd, readable, writable, error, timeout):
  107. raise dns.exception.Timeout
  108. except select_error as e:
  109. if e.args[0] != errno.EINTR:
  110. raise e
  111. done = True
  112. def _set_polling_backend(fn):
  113. """
  114. Internal API. Do not use.
  115. """
  116. global _polling_backend
  117. _polling_backend = fn
  118. if hasattr(select, 'poll'):
  119. # Prefer poll() on platforms that support it because it has no
  120. # limits on the maximum value of a file descriptor (plus it will
  121. # be more efficient for high values).
  122. _polling_backend = _poll_for
  123. else:
  124. _polling_backend = _select_for
  125. def _wait_for_readable(s, expiration):
  126. _wait_for(s, True, False, True, expiration)
  127. def _wait_for_writable(s, expiration):
  128. _wait_for(s, False, True, True, expiration)
  129. def _addresses_equal(af, a1, a2):
  130. # Convert the first value of the tuple, which is a textual format
  131. # address into binary form, so that we are not confused by different
  132. # textual representations of the same address
  133. n1 = dns.inet.inet_pton(af, a1[0])
  134. n2 = dns.inet.inet_pton(af, a2[0])
  135. return n1 == n2 and a1[1:] == a2[1:]
  136. def _destination_and_source(af, where, port, source, source_port):
  137. # Apply defaults and compute destination and source tuples
  138. # suitable for use in connect(), sendto(), or bind().
  139. if af is None:
  140. try:
  141. af = dns.inet.af_for_address(where)
  142. except Exception:
  143. af = dns.inet.AF_INET
  144. if af == dns.inet.AF_INET:
  145. destination = (where, port)
  146. if source is not None or source_port != 0:
  147. if source is None:
  148. source = '0.0.0.0'
  149. source = (source, source_port)
  150. elif af == dns.inet.AF_INET6:
  151. destination = (where, port, 0, 0)
  152. if source is not None or source_port != 0:
  153. if source is None:
  154. source = '::'
  155. source = (source, source_port, 0, 0)
  156. return (af, destination, source)
  157. def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
  158. ignore_unexpected=False, one_rr_per_rrset=False):
  159. """Return the response obtained after sending a query via UDP.
  160. @param q: the query
  161. @type q: dns.message.Message
  162. @param where: where to send the message
  163. @type where: string containing an IPv4 or IPv6 address
  164. @param timeout: The number of seconds to wait before the query times out.
  165. If None, the default, wait forever.
  166. @type timeout: float
  167. @param port: The port to which to send the message. The default is 53.
  168. @type port: int
  169. @param af: the address family to use. The default is None, which
  170. causes the address family to use to be inferred from the form of where.
  171. If the inference attempt fails, AF_INET is used.
  172. @type af: int
  173. @rtype: dns.message.Message object
  174. @param source: source address. The default is the wildcard address.
  175. @type source: string
  176. @param source_port: The port from which to send the message.
  177. The default is 0.
  178. @type source_port: int
  179. @param ignore_unexpected: If True, ignore responses from unexpected
  180. sources. The default is False.
  181. @type ignore_unexpected: bool
  182. @param one_rr_per_rrset: Put each RR into its own RRset
  183. @type one_rr_per_rrset: bool
  184. """
  185. wire = q.to_wire()
  186. (af, destination, source) = _destination_and_source(af, where, port,
  187. source, source_port)
  188. s = socket_factory(af, socket.SOCK_DGRAM, 0)
  189. begin_time = None
  190. try:
  191. expiration = _compute_expiration(timeout)
  192. s.setblocking(0)
  193. if source is not None:
  194. s.bind(source)
  195. _wait_for_writable(s, expiration)
  196. begin_time = time.time()
  197. s.sendto(wire, destination)
  198. while 1:
  199. _wait_for_readable(s, expiration)
  200. (wire, from_address) = s.recvfrom(65535)
  201. if _addresses_equal(af, from_address, destination) or \
  202. (dns.inet.is_multicast(where) and
  203. from_address[1:] == destination[1:]):
  204. break
  205. if not ignore_unexpected:
  206. raise UnexpectedSource('got a response from '
  207. '%s instead of %s' % (from_address,
  208. destination))
  209. finally:
  210. if begin_time is None:
  211. response_time = 0
  212. else:
  213. response_time = time.time() - begin_time
  214. s.close()
  215. r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
  216. one_rr_per_rrset=one_rr_per_rrset)
  217. r.time = response_time
  218. if not q.is_response(r):
  219. raise BadResponse
  220. return r
  221. def _net_read(sock, count, expiration):
  222. """Read the specified number of bytes from sock. Keep trying until we
  223. either get the desired amount, or we hit EOF.
  224. A Timeout exception will be raised if the operation is not completed
  225. by the expiration time.
  226. """
  227. s = b''
  228. while count > 0:
  229. _wait_for_readable(sock, expiration)
  230. n = sock.recv(count)
  231. if n == b'':
  232. raise EOFError
  233. count = count - len(n)
  234. s = s + n
  235. return s
  236. def _net_write(sock, data, expiration):
  237. """Write the specified data to the socket.
  238. A Timeout exception will be raised if the operation is not completed
  239. by the expiration time.
  240. """
  241. current = 0
  242. l = len(data)
  243. while current < l:
  244. _wait_for_writable(sock, expiration)
  245. current += sock.send(data[current:])
  246. def _connect(s, address):
  247. try:
  248. s.connect(address)
  249. except socket.error:
  250. (ty, v) = sys.exc_info()[:2]
  251. if hasattr(v, 'errno'):
  252. v_err = v.errno
  253. else:
  254. v_err = v[0]
  255. if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]:
  256. raise v
  257. def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
  258. one_rr_per_rrset=False):
  259. """Return the response obtained after sending a query via TCP.
  260. @param q: the query
  261. @type q: dns.message.Message object
  262. @param where: where to send the message
  263. @type where: string containing an IPv4 or IPv6 address
  264. @param timeout: The number of seconds to wait before the query times out.
  265. If None, the default, wait forever.
  266. @type timeout: float
  267. @param port: The port to which to send the message. The default is 53.
  268. @type port: int
  269. @param af: the address family to use. The default is None, which
  270. causes the address family to use to be inferred from the form of where.
  271. If the inference attempt fails, AF_INET is used.
  272. @type af: int
  273. @rtype: dns.message.Message object
  274. @param source: source address. The default is the wildcard address.
  275. @type source: string
  276. @param source_port: The port from which to send the message.
  277. The default is 0.
  278. @type source_port: int
  279. @param one_rr_per_rrset: Put each RR into its own RRset
  280. @type one_rr_per_rrset: bool
  281. """
  282. wire = q.to_wire()
  283. (af, destination, source) = _destination_and_source(af, where, port,
  284. source, source_port)
  285. s = socket_factory(af, socket.SOCK_STREAM, 0)
  286. begin_time = None
  287. try:
  288. expiration = _compute_expiration(timeout)
  289. s.setblocking(0)
  290. begin_time = time.time()
  291. if source is not None:
  292. s.bind(source)
  293. _connect(s, destination)
  294. l = len(wire)
  295. # copying the wire into tcpmsg is inefficient, but lets us
  296. # avoid writev() or doing a short write that would get pushed
  297. # onto the net
  298. tcpmsg = struct.pack("!H", l) + wire
  299. _net_write(s, tcpmsg, expiration)
  300. ldata = _net_read(s, 2, expiration)
  301. (l,) = struct.unpack("!H", ldata)
  302. wire = _net_read(s, l, expiration)
  303. finally:
  304. if begin_time is None:
  305. response_time = 0
  306. else:
  307. response_time = time.time() - begin_time
  308. s.close()
  309. r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
  310. one_rr_per_rrset=one_rr_per_rrset)
  311. r.time = response_time
  312. if not q.is_response(r):
  313. raise BadResponse
  314. return r
  315. def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
  316. timeout=None, port=53, keyring=None, keyname=None, relativize=True,
  317. af=None, lifetime=None, source=None, source_port=0, serial=0,
  318. use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
  319. """Return a generator for the responses to a zone transfer.
  320. @param where: where to send the message
  321. @type where: string containing an IPv4 or IPv6 address
  322. @param zone: The name of the zone to transfer
  323. @type zone: dns.name.Name object or string
  324. @param rdtype: The type of zone transfer. The default is
  325. dns.rdatatype.AXFR.
  326. @type rdtype: int or string
  327. @param rdclass: The class of the zone transfer. The default is
  328. dns.rdataclass.IN.
  329. @type rdclass: int or string
  330. @param timeout: The number of seconds to wait for each response message.
  331. If None, the default, wait forever.
  332. @type timeout: float
  333. @param port: The port to which to send the message. The default is 53.
  334. @type port: int
  335. @param keyring: The TSIG keyring to use
  336. @type keyring: dict
  337. @param keyname: The name of the TSIG key to use
  338. @type keyname: dns.name.Name object or string
  339. @param relativize: If True, all names in the zone will be relativized to
  340. the zone origin. It is essential that the relativize setting matches
  341. the one specified to dns.zone.from_xfr().
  342. @type relativize: bool
  343. @param af: the address family to use. The default is None, which
  344. causes the address family to use to be inferred from the form of where.
  345. If the inference attempt fails, AF_INET is used.
  346. @type af: int
  347. @param lifetime: The total number of seconds to spend doing the transfer.
  348. If None, the default, then there is no limit on the time the transfer may
  349. take.
  350. @type lifetime: float
  351. @rtype: generator of dns.message.Message objects.
  352. @param source: source address. The default is the wildcard address.
  353. @type source: string
  354. @param source_port: The port from which to send the message.
  355. The default is 0.
  356. @type source_port: int
  357. @param serial: The SOA serial number to use as the base for an IXFR diff
  358. sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
  359. @type serial: int
  360. @param use_udp: Use UDP (only meaningful for IXFR)
  361. @type use_udp: bool
  362. @param keyalgorithm: The TSIG algorithm to use; defaults to
  363. dns.tsig.default_algorithm
  364. @type keyalgorithm: string
  365. """
  366. if isinstance(zone, string_types):
  367. zone = dns.name.from_text(zone)
  368. if isinstance(rdtype, string_types):
  369. rdtype = dns.rdatatype.from_text(rdtype)
  370. q = dns.message.make_query(zone, rdtype, rdclass)
  371. if rdtype == dns.rdatatype.IXFR:
  372. rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
  373. '. . %u 0 0 0 0' % serial)
  374. q.authority.append(rrset)
  375. if keyring is not None:
  376. q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
  377. wire = q.to_wire()
  378. (af, destination, source) = _destination_and_source(af, where, port,
  379. source, source_port)
  380. if use_udp:
  381. if rdtype != dns.rdatatype.IXFR:
  382. raise ValueError('cannot do a UDP AXFR')
  383. s = socket_factory(af, socket.SOCK_DGRAM, 0)
  384. else:
  385. s = socket_factory(af, socket.SOCK_STREAM, 0)
  386. s.setblocking(0)
  387. if source is not None:
  388. s.bind(source)
  389. expiration = _compute_expiration(lifetime)
  390. _connect(s, destination)
  391. l = len(wire)
  392. if use_udp:
  393. _wait_for_writable(s, expiration)
  394. s.send(wire)
  395. else:
  396. tcpmsg = struct.pack("!H", l) + wire
  397. _net_write(s, tcpmsg, expiration)
  398. done = False
  399. delete_mode = True
  400. expecting_SOA = False
  401. soa_rrset = None
  402. if relativize:
  403. origin = zone
  404. oname = dns.name.empty
  405. else:
  406. origin = None
  407. oname = zone
  408. tsig_ctx = None
  409. first = True
  410. while not done:
  411. mexpiration = _compute_expiration(timeout)
  412. if mexpiration is None or mexpiration > expiration:
  413. mexpiration = expiration
  414. if use_udp:
  415. _wait_for_readable(s, expiration)
  416. (wire, from_address) = s.recvfrom(65535)
  417. else:
  418. ldata = _net_read(s, 2, mexpiration)
  419. (l,) = struct.unpack("!H", ldata)
  420. wire = _net_read(s, l, mexpiration)
  421. is_ixfr = (rdtype == dns.rdatatype.IXFR)
  422. r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
  423. xfr=True, origin=origin, tsig_ctx=tsig_ctx,
  424. multi=True, first=first,
  425. one_rr_per_rrset=is_ixfr)
  426. tsig_ctx = r.tsig_ctx
  427. first = False
  428. answer_index = 0
  429. if soa_rrset is None:
  430. if not r.answer or r.answer[0].name != oname:
  431. raise dns.exception.FormError(
  432. "No answer or RRset not for qname")
  433. rrset = r.answer[0]
  434. if rrset.rdtype != dns.rdatatype.SOA:
  435. raise dns.exception.FormError("first RRset is not an SOA")
  436. answer_index = 1
  437. soa_rrset = rrset.copy()
  438. if rdtype == dns.rdatatype.IXFR:
  439. if soa_rrset[0].serial <= serial:
  440. #
  441. # We're already up-to-date.
  442. #
  443. done = True
  444. else:
  445. expecting_SOA = True
  446. #
  447. # Process SOAs in the answer section (other than the initial
  448. # SOA in the first message).
  449. #
  450. for rrset in r.answer[answer_index:]:
  451. if done:
  452. raise dns.exception.FormError("answers after final SOA")
  453. if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
  454. if expecting_SOA:
  455. if rrset[0].serial != serial:
  456. raise dns.exception.FormError(
  457. "IXFR base serial mismatch")
  458. expecting_SOA = False
  459. elif rdtype == dns.rdatatype.IXFR:
  460. delete_mode = not delete_mode
  461. #
  462. # If this SOA RRset is equal to the first we saw then we're
  463. # finished. If this is an IXFR we also check that we're seeing
  464. # the record in the expected part of the response.
  465. #
  466. if rrset == soa_rrset and \
  467. (rdtype == dns.rdatatype.AXFR or
  468. (rdtype == dns.rdatatype.IXFR and delete_mode)):
  469. done = True
  470. elif expecting_SOA:
  471. #
  472. # We made an IXFR request and are expecting another
  473. # SOA RR, but saw something else, so this must be an
  474. # AXFR response.
  475. #
  476. rdtype = dns.rdatatype.AXFR
  477. expecting_SOA = False
  478. if done and q.keyring and not r.had_tsig:
  479. raise dns.exception.FormError("missing TSIG")
  480. yield r
  481. s.close()