|
|
@@ -172,6 +172,22 @@ def _addresses_equal(af, a1, a2):
|
|
|
return n1 == n2 and a1[1:] == a2[1:]
|
|
|
|
|
|
|
|
|
+def _matches_destination(af, from_address, destination, ignore_unexpected):
|
|
|
+ # Check that from_address is appropriate for a response to a query
|
|
|
+ # sent to destination.
|
|
|
+ if not destination:
|
|
|
+ return True
|
|
|
+ if _addresses_equal(af, from_address, destination) or (
|
|
|
+ dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
|
|
|
+ ):
|
|
|
+ return True
|
|
|
+ elif ignore_unexpected:
|
|
|
+ return False
|
|
|
+ raise UnexpectedSource('got a response from '
|
|
|
+ '%s instead of %s' % (from_address,
|
|
|
+ destination))
|
|
|
+
|
|
|
+
|
|
|
def _destination_and_source(af, where, port, source, source_port):
|
|
|
# Apply defaults and compute destination and source tuples
|
|
|
# suitable for use in connect(), sendto(), or bind().
|
|
|
@@ -222,7 +238,9 @@ def send_udp(sock, what, destination, expiration=None):
|
|
|
|
|
|
def receive_udp(sock, destination, expiration=None,
|
|
|
ignore_unexpected=False, one_rr_per_rrset=False,
|
|
|
- keyring=None, request_mac=b'', ignore_trailing=False):
|
|
|
+ keyring=None, request_mac=b'', ignore_trailing=False,
|
|
|
+ ignore_errors=False,
|
|
|
+ query=None):
|
|
|
"""Read a DNS message from a UDP socket.
|
|
|
|
|
|
*sock*, a ``socket``.
|
|
|
@@ -247,6 +265,14 @@ def receive_udp(sock, destination, expiration=None,
|
|
|
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
|
|
|
junk at end of the received message.
|
|
|
|
|
|
+ *ignore_errors*, a ``bool``. If various format errors or response
|
|
|
+ mismatches occur, ignore them and keep listening for a valid response.
|
|
|
+ The default is ``False``.
|
|
|
+
|
|
|
+ *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
|
|
|
+ *ignore_errors* is ``True``, check that the received message is a response
|
|
|
+ to this query, and if not keep listening for a valid response.
|
|
|
+
|
|
|
Raises if the message is malformed, if network errors occur, of if
|
|
|
there is a timeout.
|
|
|
|
|
|
@@ -257,22 +283,45 @@ def receive_udp(sock, destination, expiration=None,
|
|
|
while 1:
|
|
|
_wait_for_readable(sock, expiration)
|
|
|
(wire, from_address) = sock.recvfrom(65535)
|
|
|
- if _addresses_equal(sock.family, from_address, destination) or \
|
|
|
- (dns.inet.is_multicast(destination[0]) and
|
|
|
- from_address[1:] == destination[1:]):
|
|
|
- break
|
|
|
- if not ignore_unexpected:
|
|
|
- raise UnexpectedSource('got a response from '
|
|
|
- '%s instead of %s' % (from_address,
|
|
|
- destination))
|
|
|
- received_time = time.time()
|
|
|
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
|
|
|
- one_rr_per_rrset=one_rr_per_rrset,
|
|
|
- ignore_trailing=ignore_trailing)
|
|
|
- return (r, received_time)
|
|
|
+ if not _matches_destination(
|
|
|
+ sock.family, from_address, destination, ignore_unexpected
|
|
|
+ ):
|
|
|
+ continue
|
|
|
+ received_time = time.time()
|
|
|
+ try:
|
|
|
+ r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
|
|
|
+ one_rr_per_rrset=one_rr_per_rrset,
|
|
|
+ ignore_trailing=ignore_trailing)
|
|
|
+ except dns.message.Truncated as e:
|
|
|
+ # If we got Truncated and not FORMERR, we at least got the header with TC
|
|
|
+ # set, and very likely the question section, so we'll re-raise if the
|
|
|
+ # message seems to be a response as we need to know when truncation happens.
|
|
|
+ # We need to check that it seems to be a response as we don't want a random
|
|
|
+ # injected message with TC set to cause us to bail out.
|
|
|
+ if (
|
|
|
+ ignore_errors
|
|
|
+ and query is not None
|
|
|
+ and not query.is_response(e.message())
|
|
|
+ ):
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+ except Exception:
|
|
|
+ if ignore_errors:
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+ if ignore_errors and query is not None and not query.is_response(r):
|
|
|
+ continue
|
|
|
+ if destination:
|
|
|
+ return (r, received_time)
|
|
|
+ else:
|
|
|
+ return (r, received_time, from_address)
|
|
|
+
|
|
|
|
|
|
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
|
|
|
- ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False):
|
|
|
+ ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
|
|
|
+ ignore_errors=False):
|
|
|
"""Return the response obtained after sending a query via UDP.
|
|
|
|
|
|
*q*, a ``dns.message.Message``, the query to send
|
|
|
@@ -305,6 +354,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
|
|
|
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
|
|
|
junk at end of the received message.
|
|
|
|
|
|
+ *ignore_errors*, a ``bool``. If various format errors or response
|
|
|
+ mismatches occur, ignore them and keep listening for a valid response.
|
|
|
+ The default is ``False``.
|
|
|
+
|
|
|
Returns a ``dns.message.Message``.
|
|
|
"""
|
|
|
|
|
|
@@ -322,7 +375,8 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
|
|
|
(_, sent_time) = send_udp(s, wire, destination, expiration)
|
|
|
(r, received_time) = receive_udp(s, destination, expiration,
|
|
|
ignore_unexpected, one_rr_per_rrset,
|
|
|
- q.keyring, q.mac, ignore_trailing)
|
|
|
+ q.keyring, q.mac, ignore_trailing,
|
|
|
+ ignore_errors, q)
|
|
|
finally:
|
|
|
if sent_time is None or received_time is None:
|
|
|
response_time = 0
|
|
|
@@ -330,7 +384,9 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
|
|
|
response_time = received_time - sent_time
|
|
|
s.close()
|
|
|
r.time = response_time
|
|
|
- if not q.is_response(r):
|
|
|
+ # We don't need to check q.is_response() if we are in ignore_errors mode
|
|
|
+ # as receive_udp() will have checked it.
|
|
|
+ if not (ignore_errors or q.is_response(r)):
|
|
|
raise BadResponse
|
|
|
return r
|
|
|
|