tests.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. from __future__ import print_function
  2. import os
  3. import sys
  4. import unittest
  5. import io
  6. from xml.sax.saxutils import XMLGenerator
  7. from xml.sax import SAXParseException
  8. from pyexpat import ExpatError
  9. from defusedxml import cElementTree, ElementTree, minidom, pulldom, sax, xmlrpc
  10. from defusedxml import defuse_stdlib
  11. from defusedxml import (DTDForbidden, EntitiesForbidden,
  12. ExternalReferenceForbidden, NotSupportedError)
  13. from defusedxml.common import PY3
  14. try:
  15. import gzip
  16. except ImportError:
  17. gzip = None
  18. try:
  19. from defusedxml import lxml
  20. from lxml.etree import XMLSyntaxError
  21. LXML3 = lxml.LXML3
  22. except ImportError:
  23. lxml = None
  24. XMLSyntaxError = None
  25. LXML3 = False
  # Absolute directory of this test module; the test data lives beneath it.
  HERE = os.path.dirname(os.path.abspath(__file__))

  # Keep the tests off the network: point every proxy variable at a local
  # address on port 9 (the "discard" port according to Debian's rules).
  for _proxy_var in ("http_proxy", "https_proxy", "ftp_proxy"):
      os.environ[_proxy_var] = "http://127.0.9.1:9"
  del _proxy_var
  class DefusedTestCase(unittest.TestCase):
      """Shared fixtures for the defusedxml test cases.

      Provides absolute paths to the sample documents under ``xmltestdata``
      and a helper that reads one of them in the form the parsers expect.
      """

      # Read sample files as bytes on Python 2 and as text on Python 3;
      # subclasses whose parsers need bytes override this with True.
      content_binary = not PY3

      xml_dtd = os.path.join(HERE, "xmltestdata", "dtd.xml")
      xml_external = os.path.join(HERE, "xmltestdata", "external.xml")
      xml_external_file = os.path.join(HERE, "xmltestdata", "external_file.xml")
      xml_quadratic = os.path.join(HERE, "xmltestdata", "quadratic.xml")
      xml_simple = os.path.join(HERE, "xmltestdata", "simple.xml")
      xml_simple_ns = os.path.join(HERE, "xmltestdata", "simple-ns.xml")
      xml_bomb = os.path.join(HERE, "xmltestdata", "xmlbomb.xml")
      xml_bomb2 = os.path.join(HERE, "xmltestdata", "xmlbomb2.xml")
      xml_cyclic = os.path.join(HERE, "xmltestdata", "cyclic.xml")

      def get_content(self, xmlfile):
          """Return the contents of *xmlfile*.

          Returns bytes when ``content_binary`` is true. Otherwise the file is
          decoded explicitly as UTF-8, so the result does not depend on the
          platform's locale encoding (the previous bare text-mode open did).
          """
          if self.content_binary:
              with io.open(xmlfile, "rb") as f:
                  return f.read()
          with io.open(xmlfile, "r", encoding="utf-8") as f:
              return f.read()
  class BaseTests(DefusedTestCase):
      """Parser-independent tests shared by the concrete test cases.

      Subclasses point ``module`` at a defused parser module and implement
      ``parse(xmlfile, **kwargs)`` and ``parseString(xmlstring, **kwargs)``;
      they may also implement ``iterparse(source, **kwargs)``. The class
      attributes below adjust the expected exceptions per parser.
      """

      # Defused module under test; None for this abstract base class or for
      # an optional backend whose import failed.
      module = None
      # When true, parsing xml_dtd must raise ``external_ref_exception``.
      dtd_external_ref = False
      # Exception expected when the parser meets an external reference.
      external_ref_exception = ExternalReferenceForbidden
      # Exception expected for a cyclic entity definition.
      cyclic_error = None
      # Optional iterparse helper; a falsy value skips the iterparse checks.
      iterparse = None

      def setUp(self):
          super(BaseTests, self).setUp()
          # Discovery runners (``python -m unittest``, pytest) also collect
          # this abstract base and subclasses whose optional backend failed
          # to import. Without a module they would fail with AttributeError
          # or TypeError rather than report a meaningful result, so skip.
          if self.module is None:
              self.skipTest("no parser module configured for %s"
                            % type(self).__name__)

      def test_simple_parse(self):
          self.parse(self.xml_simple)
          self.parseString(self.get_content(self.xml_simple))
          if self.iterparse:
              self.iterparse(self.xml_simple)

      def test_simple_parse_ns(self):
          self.parse(self.xml_simple_ns)
          self.parseString(self.get_content(self.xml_simple_ns))
          if self.iterparse:
              self.iterparse(self.xml_simple_ns)

      def test_entities_forbidden(self):
          self.assertRaises(EntitiesForbidden, self.parse, self.xml_bomb)
          self.assertRaises(EntitiesForbidden, self.parse, self.xml_quadratic)
          self.assertRaises(EntitiesForbidden, self.parse, self.xml_external)
          self.assertRaises(EntitiesForbidden, self.parseString,
                            self.get_content(self.xml_bomb))
          self.assertRaises(EntitiesForbidden, self.parseString,
                            self.get_content(self.xml_quadratic))
          self.assertRaises(EntitiesForbidden, self.parseString,
                            self.get_content(self.xml_external))
          if self.iterparse:
              self.assertRaises(EntitiesForbidden, self.iterparse,
                                self.xml_bomb)
              self.assertRaises(EntitiesForbidden, self.iterparse,
                                self.xml_quadratic)
              self.assertRaises(EntitiesForbidden, self.iterparse,
                                self.xml_external)

      def test_entity_cycle(self):
          # Entities must be allowed so the cycle reaches the parser itself.
          self.assertRaises(self.cyclic_error, self.parse, self.xml_cyclic,
                            forbid_entities=False)

      def test_dtd_forbidden(self):
          self.assertRaises(DTDForbidden, self.parse, self.xml_bomb,
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parse, self.xml_quadratic,
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parse, self.xml_external,
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parse, self.xml_dtd,
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parseString,
                            self.get_content(self.xml_bomb),
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parseString,
                            self.get_content(self.xml_quadratic),
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parseString,
                            self.get_content(self.xml_external),
                            forbid_dtd=True)
          self.assertRaises(DTDForbidden, self.parseString,
                            self.get_content(self.xml_dtd),
                            forbid_dtd=True)
          if self.iterparse:
              self.assertRaises(DTDForbidden, self.iterparse,
                                self.xml_bomb, forbid_dtd=True)
              self.assertRaises(DTDForbidden, self.iterparse,
                                self.xml_quadratic, forbid_dtd=True)
              self.assertRaises(DTDForbidden, self.iterparse,
                                self.xml_external, forbid_dtd=True)
              self.assertRaises(DTDForbidden, self.iterparse,
                                self.xml_dtd, forbid_dtd=True)

      def test_dtd_with_external_ref(self):
          if self.dtd_external_ref:
              self.assertRaises(self.external_ref_exception, self.parse,
                                self.xml_dtd)
          else:
              self.parse(self.xml_dtd)

      def test_external_ref(self):
          self.assertRaises(self.external_ref_exception, self.parse,
                            self.xml_external, forbid_entities=False)

      def test_external_file_ref(self):
          # Substitute this directory for the /PATH/TO placeholder so the
          # document references a file that actually exists locally.
          content = self.get_content(self.xml_external_file)
          if isinstance(content, bytes):
              here = HERE.encode(sys.getfilesystemencoding())
              content = content.replace(b"/PATH/TO", here)
          else:
              content = content.replace("/PATH/TO", HERE)
          self.assertRaises(self.external_ref_exception, self.parseString,
                            content, forbid_entities=False)

      def test_allow_expansion(self):
          self.parse(self.xml_bomb2, forbid_entities=False)
          self.parseString(self.get_content(self.xml_bomb2),
                           forbid_entities=False)
  class TestDefusedElementTree(BaseTests):
      """Run the shared suite against ``defusedxml.ElementTree``."""

      module = ElementTree
      # ElementTree does not look up external references, so the inherited
      # external_ref_exception expectation is deliberately left unchanged.
      cyclic_error = ElementTree.ParseError

      def parse(self, xmlfile, **kwargs):
          """Parse *xmlfile* and serialise its root element."""
          root = self.module.parse(xmlfile, **kwargs).getroot()
          return self.module.tostring(root)

      def parseString(self, xmlstring, **kwargs):
          """Parse *xmlstring* and serialise the resulting element."""
          element = self.module.fromstring(xmlstring, **kwargs)
          return self.module.tostring(element)

      def iterparse(self, source, **kwargs):
          """Drain ``iterparse`` over *source* into a list of events."""
          events = self.module.iterparse(source, **kwargs)
          return list(events)
  class TestDefusedcElementTree(TestDefusedElementTree):
      """Repeat the ElementTree tests with ``defusedxml.cElementTree``."""

      module = cElementTree
  class TestDefusedMinidom(BaseTests):
      """Run the shared suite against ``defusedxml.minidom``."""

      module = minidom
      cyclic_error = ExpatError
      # No iterparse variant is exercised for minidom.
      iterparse = None

      def parse(self, xmlfile, **kwargs):
          """Parse *xmlfile* into a DOM and return it serialised as XML."""
          return self.module.parse(xmlfile, **kwargs).toxml()

      def parseString(self, xmlstring, **kwargs):
          """Parse *xmlstring* into a DOM and return it serialised as XML."""
          return self.module.parseString(xmlstring, **kwargs).toxml()
  class TestDefusedPulldom(BaseTests):
      """Run the shared suite against ``defusedxml.pulldom``."""

      module = pulldom
      cyclic_error = SAXParseException
      # Parsing xml_dtd is expected to raise the external-reference error.
      dtd_external_ref = True

      def parse(self, xmlfile, **kwargs):
          """Consume the pulldom event stream for *xmlfile* into a list."""
          return list(self.module.parse(xmlfile, **kwargs))

      def parseString(self, xmlstring, **kwargs):
          """Consume the pulldom event stream for *xmlstring* into a list."""
          return list(self.module.parseString(xmlstring, **kwargs))
  class TestDefusedSax(BaseTests):
      """Run the shared suite against ``defusedxml.sax``."""

      module = sax
      cyclic_error = SAXParseException
      # The SAX wrappers are always fed bytes, on every Python version.
      content_binary = True
      # Parsing xml_dtd is expected to raise the external-reference error.
      dtd_external_ref = True

      @staticmethod
      def _output_buffer():
          # Match the buffer to the Python version: a text buffer on
          # Python 3, a bytes buffer on Python 2.
          return io.StringIO() if PY3 else io.BytesIO()

      def parse(self, xmlfile, **kwargs):
          """SAX-parse *xmlfile*, re-serialising the events via XMLGenerator."""
          out = self._output_buffer()
          self.module.parse(xmlfile, XMLGenerator(out), **kwargs)
          return out.getvalue()

      def parseString(self, xmlstring, **kwargs):
          """SAX-parse *xmlstring*, re-serialising the events via XMLGenerator."""
          out = self._output_buffer()
          self.module.parseString(xmlstring, XMLGenerator(out), **kwargs)
          return out.getvalue()

      def test_exceptions(self):
          """Check str() and repr() of each defusedxml exception type."""
          cases = [
              (EntitiesForbidden, self.xml_bomb, {},
               "EntitiesForbidden(name='a', system_id=None, public_id=None)"),
              (ExternalReferenceForbidden, self.xml_external,
               {"forbid_entities": False},
               "ExternalReferenceForbidden"
               "(system_id='http://www.w3schools.com/xml/note.xml', public_id=None)"),
              (DTDForbidden, self.xml_bomb, {"forbid_dtd": True},
               "DTDForbidden(name='xmlbomb', system_id=None, public_id=None)"),
          ]
          for exc_type, xmlfile, kwargs, expected in cases:
              with self.assertRaises(exc_type) as ctx:
                  self.parse(xmlfile, **kwargs)
              self.assertEqual(str(ctx.exception), expected)
              self.assertEqual(repr(ctx.exception), expected)
  class TestDefusedLxml(BaseTests):
      """Run the shared suite, plus lxml-specific checks, on defusedxml.lxml."""

      module = lxml
      cyclic_error = XMLSyntaxError
      # The lxml wrappers are always fed bytes, on every Python version.
      content_binary = True

      # Skip reason used whenever lxml itself rejects a document with
      # XMLSyntaxError (previously misspelled "entityt" in four places).
      _LOOP_SKIP_MSG = "lxml detects entity reference loop"

      def _parse_tree(self, xmlfile, **kwargs):
          """Parse *xmlfile* with the module under test and return the tree.

          Skips the current test if lxml raises XMLSyntaxError.
          """
          try:
              return self.module.parse(xmlfile, **kwargs)
          except XMLSyntaxError:
              self.skipTest(self._LOOP_SKIP_MSG)

      def parse(self, xmlfile, **kwargs):
          """Parse *xmlfile* and serialise the resulting tree."""
          return self.module.tostring(self._parse_tree(xmlfile, **kwargs))

      def parseString(self, xmlstring, **kwargs):
          """Parse *xmlstring* and serialise the resulting element.

          Skips the current test if lxml raises XMLSyntaxError.
          """
          try:
              root = self.module.fromstring(xmlstring, **kwargs)
          except XMLSyntaxError:
              self.skipTest(self._LOOP_SKIP_MSG)
          return self.module.tostring(root)

      if not LXML3:
          # Without lxml 3 the wrappers are expected to raise
          # NotSupportedError, and the external-reference tests are disabled.
          def test_entities_forbidden(self):
              self.assertRaises(NotSupportedError, self.parse, self.xml_bomb)

          def test_dtd_with_external_ref(self):
              self.assertRaises(NotSupportedError, self.parse, self.xml_dtd)

          def test_external_ref(self):
              pass

          def test_external_file_ref(self):
              pass

      def test_restricted_element1(self):
          tree = self._parse_tree(self.xml_bomb, forbid_dtd=False,
                                  forbid_entities=False)
          root = tree.getroot()
          self.assertEqual(root.text, None)
          self.assertEqual(list(root), [])
          self.assertEqual(root.getchildren(), [])
          self.assertEqual(list(root.iter()), [root])
          self.assertEqual(list(root.iterchildren()), [])
          self.assertEqual(list(root.iterdescendants()), [])
          self.assertEqual(list(root.itersiblings()), [])
          self.assertEqual(list(root.getiterator()), [root])
          self.assertEqual(root.getnext(), None)

      def test_restricted_element2(self):
          tree = self._parse_tree(self.xml_bomb2, forbid_dtd=False,
                                  forbid_entities=False)
          root = tree.getroot()
          bomb, tag = root
          self.assertEqual(root.text, "text")
          self.assertEqual(list(root), [bomb, tag])
          self.assertEqual(root.getchildren(), [bomb, tag])
          self.assertEqual(list(root.iter()), [root, bomb, tag])
          self.assertEqual(list(root.iterchildren()), [bomb, tag])
          self.assertEqual(list(root.iterdescendants()), [bomb, tag])
          self.assertEqual(list(root.itersiblings()), [])
          self.assertEqual(list(root.getiterator()), [root, bomb, tag])
          self.assertEqual(root.getnext(), None)
          self.assertEqual(root.getprevious(), None)
          self.assertEqual(list(bomb.itersiblings()), [tag])
          self.assertEqual(bomb.getnext(), tag)
          self.assertEqual(bomb.getprevious(), None)
          self.assertEqual(tag.getnext(), None)
          self.assertEqual(tag.getprevious(), bomb)

      def test_xpath_injection(self):
          # Demonstrate the XPath injection vulnerability and its fix.
          xml = """<root><tag id="one" /><tag id="two"/></root>"""
          expr = "one' or @id='two"
          root = self.module.fromstring(xml)

          # Insecure: interpolating user input into the expression lets it
          # rewrite the predicate and match every element.
          xp = "tag[@id='%s']" % expr
          elements = root.xpath(xp)
          self.assertEqual(len(elements), 2)
          self.assertEqual(elements, list(root))

          # Safe: pass the value as an XPath variable instead.
          xp = "tag[@id=$idname]"
          elements = root.xpath(xp, idname=expr)
          self.assertEqual(len(elements), 0)
          self.assertEqual(elements, [])
          elements = root.xpath(xp, idname="one")
          self.assertEqual(len(elements), 1)
          self.assertEqual(elements, list(root)[:1])
  class XmlRpcTarget(object):
      """Minimal parser target that records callbacks as markup text.

      ``start``/``data``/``end`` append fragments; ``str()`` returns all of
      them joined in arrival order.
      """

      def __init__(self):
          # Text fragments, in the order the callbacks arrived.
          self._data = []

      def __str__(self):
          return "".join(self._data)

      def xml(self, encoding, standalone):
          # XML declaration callback: nothing is recorded.
          return None

      def start(self, tag, attrs):
          # Attributes are ignored; only the opening tag is recorded.
          self._data.append("<%s>" % tag)

      def data(self, text):
          self._data.append(text)

      def end(self, tag):
          self._data.append("</%s>" % tag)
  class TestXmlRpc(DefusedTestCase):
      """Exercise the defused expat parser from ``defusedxml.xmlrpc``."""

      module = xmlrpc

      def _feed_file(self, parser, target, xmlfile):
          # Feed the whole file in one chunk, flush the parser, and hand
          # back the target that collected the callbacks.
          parser.feed(self.get_content(xmlfile))
          parser.close()
          return target

      def parse(self, xmlfile, **kwargs):
          """Parse *xmlfile* with DefusedExpatParser and return the target."""
          target = XmlRpcTarget()
          return self._feed_file(self.module.DefusedExpatParser(target, **kwargs),
                                 target, xmlfile)

      def parse_unpatched(self, xmlfile):
          """Parse *xmlfile* with the module's ExpatParser and return the target."""
          target = XmlRpcTarget()
          return self._feed_file(self.module.ExpatParser(target),
                                 target, xmlfile)

      def test_xmlrpc(self):
          self.assertRaises(EntitiesForbidden, self.parse, self.xml_bomb)
          self.assertRaises(EntitiesForbidden, self.parse, self.xml_quadratic)
          self.parse(self.xml_dtd)
          self.assertRaises(DTDForbidden, self.parse, self.xml_dtd,
                            forbid_dtd=True)

      # NOTE: parse_unpatched() is not exercised by any test here; it is
      # kept for manual comparison against the module's ExpatParser.

      def test_monkeypatch(self):
          try:
              xmlrpc.monkey_patch()
          finally:
              xmlrpc.unmonkey_patch()
  @unittest.skipIf(gzip is None, "gzip module not available")
  class TestDefusedGzip(DefusedTestCase):
      """Tests for the size-limited gzip helpers in ``defusedxml.xmlrpc``."""

      def get_gzipped(self, length):
          """Return a rewound BytesIO holding *length* ``b"d"`` bytes, gzipped."""
          f = io.BytesIO()
          # Closing the GzipFile flushes the trailer but leaves *f* open.
          with gzip.GzipFile(mode="wb", fileobj=f) as gzf:
              gzf.write(b"d" * length)
          f.seek(0)
          return f

      def decode_response(self, response, limit=None, readlength=1024):
          """Read *response* through DefusedGzipDecodedResponse in chunks of
          *readlength* bytes and return the concatenated payload."""
          dec = xmlrpc.DefusedGzipDecodedResponse(response, limit)
          acc = []
          while True:
              data = dec.read(readlength)
              if not data:
                  break
              acc.append(data)
          return b"".join(acc)

      def test_defused_gzip_decode(self):
          data = self.get_gzipped(4096).getvalue()
          expected = b"d" * 4096
          # No limit, a negative limit and an exact limit all succeed.
          self.assertEqual(xmlrpc.defused_gzip_decode(data), expected)
          self.assertEqual(xmlrpc.defused_gzip_decode(data, -1), expected)
          self.assertEqual(xmlrpc.defused_gzip_decode(data, 4096), expected)
          # Limits below the decompressed size are rejected.
          with self.assertRaises(ValueError):
              xmlrpc.defused_gzip_decode(data, 4095)
          with self.assertRaises(ValueError):
              xmlrpc.defused_gzip_decode(data, 0)

      def test_defused_gzip_response(self):
          clen = len(self.get_gzipped(4096).getvalue())
          data = self.decode_response(self.get_gzipped(4096))
          self.assertEqual(data, b"d" * 4096)

          # Fixtures are built outside assertRaises so that only the call
          # under test can satisfy the expected ValueError.
          # Limit below the compressed length: rejected on construction.
          response = self.get_gzipped(4096)
          with self.assertRaises(ValueError):
              xmlrpc.DefusedGzipDecodedResponse(response, clen - 1)

          # Limit below the decompressed size (4096 bytes).
          response = self.get_gzipped(4096)
          with self.assertRaises(ValueError):
              self.decode_response(response, 4095)

          # Same, when a single read asks for more than the limit.
          response = self.get_gzipped(4096)
          with self.assertRaises(ValueError):
              self.decode_response(response, 4095, 8192)
  def test_main():
      """Build a suite containing every test case that can run here.

      The lxml and gzip cases are added only when their optional import
      succeeded at module load.
      """
      # unittest.makeSuite() was deprecated in Python 3.11 and removed in
      # 3.13; TestLoader.loadTestsFromTestCase() is the supported
      # equivalent and also exists on Python 2.7.
      loader = unittest.defaultTestLoader
      cases = [
          TestDefusedcElementTree,
          TestDefusedElementTree,
          TestDefusedMinidom,
          TestDefusedPulldom,
          TestDefusedSax,
          TestXmlRpc,
      ]
      if lxml is not None:
          cases.append(TestDefusedLxml)
      if gzip is not None:
          cases.append(TestDefusedGzip)
      suite = unittest.TestSuite()
      for case in cases:
          suite.addTests(loader.loadTestsFromTestCase(case))
      return suite
  if __name__ == "__main__":
      suite = test_main()
      runner = unittest.TextTestRunner(verbosity=1)
      result = runner.run(suite)
      # TODO: test that defuse_stdlib() actually works.
      defuse_stdlib()
      # Exit status 0 on success, 1 on any failure or error.
      sys.exit(0 if result.wasSuccessful() else 1)