# compdoc.py (21 KB)
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2012 Stephen John Machin, Lingfo Pty Ltd
# This module is part of the xlrd package, which is released under a
# BSD-style licence.
# No part of the content of this file was derived from the works of
# David Giffin.
"""
Implements the minimal functionality required
to extract a "Workbook" or "Book" stream (as one big string)
from an OLE2 Compound Document file.
"""
  12. from __future__ import print_function
  13. import array
  14. import sys
  15. from struct import unpack
  16. from .timemachine import *
  17. #: Magic cookie that should appear in the first 8 bytes of the file.
  18. SIGNATURE = b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1"
  19. EOCSID = -2
  20. FREESID = -1
  21. SATSID = -3
  22. MSATSID = -4
  23. EVILSID = -5
  24. class CompDocError(Exception):
  25. pass
  26. class DirNode(object):
  27. def __init__(self, DID, dent, DEBUG=0, logfile=sys.stdout):
  28. # dent is the 128-byte directory entry
  29. self.DID = DID
  30. self.logfile = logfile
  31. (cbufsize, self.etype, self.colour, self.left_DID, self.right_DID,
  32. self.root_DID) = \
  33. unpack('<HBBiii', dent[64:80])
  34. (self.first_SID, self.tot_size) = \
  35. unpack('<ii', dent[116:124])
  36. if cbufsize == 0:
  37. self.name = UNICODE_LITERAL('')
  38. else:
  39. self.name = unicode(dent[0:cbufsize-2], 'utf_16_le') # omit the trailing U+0000
  40. self.children = [] # filled in later
  41. self.parent = -1 # indicates orphan; fixed up later
  42. self.tsinfo = unpack('<IIII', dent[100:116])
  43. if DEBUG:
  44. self.dump(DEBUG)
  45. def dump(self, DEBUG=1):
  46. fprintf(
  47. self.logfile,
  48. "DID=%d name=%r etype=%d DIDs(left=%d right=%d root=%d parent=%d kids=%r) first_SID=%d tot_size=%d\n",
  49. self.DID, self.name, self.etype, self.left_DID,
  50. self.right_DID, self.root_DID, self.parent, self.children, self.first_SID, self.tot_size
  51. )
  52. if DEBUG == 2:
  53. # cre_lo, cre_hi, mod_lo, mod_hi = tsinfo
  54. print("timestamp info", self.tsinfo, file=self.logfile)
  55. def _build_family_tree(dirlist, parent_DID, child_DID):
  56. if child_DID < 0: return
  57. _build_family_tree(dirlist, parent_DID, dirlist[child_DID].left_DID)
  58. dirlist[parent_DID].children.append(child_DID)
  59. dirlist[child_DID].parent = parent_DID
  60. _build_family_tree(dirlist, parent_DID, dirlist[child_DID].right_DID)
  61. if dirlist[child_DID].etype == 1: # storage
  62. _build_family_tree(dirlist, child_DID, dirlist[child_DID].root_DID)
  63. class CompDoc(object):
  64. """
  65. Compound document handler.
  66. :param mem:
  67. The raw contents of the file, as a string, or as an :class:`mmap.mmap`
  68. object. The only operation it needs to support is slicing.
  69. """
  70. def __init__(self, mem, logfile=sys.stdout, DEBUG=0, ignore_workbook_corruption=False):
  71. self.logfile = logfile
  72. self.ignore_workbook_corruption = ignore_workbook_corruption
  73. self.DEBUG = DEBUG
  74. if mem[0:8] != SIGNATURE:
  75. raise CompDocError('Not an OLE2 compound document')
  76. if mem[28:30] != b'\xFE\xFF':
  77. raise CompDocError('Expected "little-endian" marker, found %r' % mem[28:30])
  78. revision, version = unpack('<HH', mem[24:28])
  79. if DEBUG:
  80. print("\nCompDoc format: version=0x%04x revision=0x%04x" % (version, revision), file=logfile)
  81. self.mem = mem
  82. ssz, sssz = unpack('<HH', mem[30:34])
  83. if ssz > 20: # allows for 2**20 bytes i.e. 1MB
  84. print("WARNING: sector size (2**%d) is preposterous; assuming 512 and continuing ..."
  85. % ssz, file=logfile)
  86. ssz = 9
  87. if sssz > ssz:
  88. print("WARNING: short stream sector size (2**%d) is preposterous; assuming 64 and continuing ..."
  89. % sssz, file=logfile)
  90. sssz = 6
  91. self.sec_size = sec_size = 1 << ssz
  92. self.short_sec_size = 1 << sssz
  93. if self.sec_size != 512 or self.short_sec_size != 64:
  94. print("@@@@ sec_size=%d short_sec_size=%d" % (self.sec_size, self.short_sec_size), file=logfile)
  95. (
  96. SAT_tot_secs, self.dir_first_sec_sid, _unused, self.min_size_std_stream,
  97. SSAT_first_sec_sid, SSAT_tot_secs,
  98. MSATX_first_sec_sid, MSATX_tot_secs,
  99. ) = unpack('<iiiiiiii', mem[44:76])
  100. mem_data_len = len(mem) - 512
  101. mem_data_secs, left_over = divmod(mem_data_len, sec_size)
  102. if left_over:
  103. #### raise CompDocError("Not a whole number of sectors")
  104. mem_data_secs += 1
  105. print("WARNING *** file size (%d) not 512 + multiple of sector size (%d)"
  106. % (len(mem), sec_size), file=logfile)
  107. self.mem_data_secs = mem_data_secs # use for checking later
  108. self.mem_data_len = mem_data_len
  109. seen = self.seen = array.array('B', [0]) * mem_data_secs
  110. if DEBUG:
  111. print('sec sizes', ssz, sssz, sec_size, self.short_sec_size, file=logfile)
  112. print("mem data: %d bytes == %d sectors" % (mem_data_len, mem_data_secs), file=logfile)
  113. print("SAT_tot_secs=%d, dir_first_sec_sid=%d, min_size_std_stream=%d"
  114. % (SAT_tot_secs, self.dir_first_sec_sid, self.min_size_std_stream,), file=logfile)
  115. print("SSAT_first_sec_sid=%d, SSAT_tot_secs=%d" % (SSAT_first_sec_sid, SSAT_tot_secs,), file=logfile)
  116. print("MSATX_first_sec_sid=%d, MSATX_tot_secs=%d" % (MSATX_first_sec_sid, MSATX_tot_secs,), file=logfile)
  117. nent = sec_size // 4 # number of SID entries in a sector
  118. fmt = "<%di" % nent
  119. trunc_warned = 0
  120. #
  121. # === build the MSAT ===
  122. #
  123. MSAT = list(unpack('<109i', mem[76:512]))
  124. SAT_sectors_reqd = (mem_data_secs + nent - 1) // nent
  125. expected_MSATX_sectors = max(0, (SAT_sectors_reqd - 109 + nent - 2) // (nent - 1))
  126. actual_MSATX_sectors = 0
  127. if MSATX_tot_secs == 0 and MSATX_first_sec_sid in (EOCSID, FREESID, 0):
  128. # Strictly, if there is no MSAT extension, then MSATX_first_sec_sid
  129. # should be set to EOCSID ... FREESID and 0 have been met in the wild.
  130. pass # Presuming no extension
  131. else:
  132. sid = MSATX_first_sec_sid
  133. while sid not in (EOCSID, FREESID, MSATSID):
  134. # Above should be only EOCSID according to MS & OOo docs
  135. # but Excel doesn't complain about FREESID. Zero is a valid
  136. # sector number, not a sentinel.
  137. if DEBUG > 1:
  138. print('MSATX: sid=%d (0x%08X)' % (sid, sid), file=logfile)
  139. if sid >= mem_data_secs:
  140. msg = "MSAT extension: accessing sector %d but only %d in file" % (sid, mem_data_secs)
  141. if DEBUG > 1:
  142. print(msg, file=logfile)
  143. break
  144. raise CompDocError(msg)
  145. elif sid < 0:
  146. raise CompDocError("MSAT extension: invalid sector id: %d" % sid)
  147. if seen[sid]:
  148. raise CompDocError("MSAT corruption: seen[%d] == %d" % (sid, seen[sid]))
  149. seen[sid] = 1
  150. actual_MSATX_sectors += 1
  151. if DEBUG and actual_MSATX_sectors > expected_MSATX_sectors:
  152. print("[1]===>>>", mem_data_secs, nent, SAT_sectors_reqd, expected_MSATX_sectors, actual_MSATX_sectors, file=logfile)
  153. offset = 512 + sec_size * sid
  154. MSAT.extend(unpack(fmt, mem[offset:offset+sec_size]))
  155. sid = MSAT.pop() # last sector id is sid of next sector in the chain
  156. if DEBUG and actual_MSATX_sectors != expected_MSATX_sectors:
  157. print("[2]===>>>", mem_data_secs, nent, SAT_sectors_reqd, expected_MSATX_sectors, actual_MSATX_sectors, file=logfile)
  158. if DEBUG:
  159. print("MSAT: len =", len(MSAT), file=logfile)
  160. dump_list(MSAT, 10, logfile)
  161. #
  162. # === build the SAT ===
  163. #
  164. self.SAT = []
  165. actual_SAT_sectors = 0
  166. dump_again = 0
  167. for msidx in xrange(len(MSAT)):
  168. msid = MSAT[msidx]
  169. if msid in (FREESID, EOCSID):
  170. # Specification: the MSAT array may be padded with trailing FREESID entries.
  171. # Toleration: a FREESID or EOCSID entry anywhere in the MSAT array will be ignored.
  172. continue
  173. if msid >= mem_data_secs:
  174. if not trunc_warned:
  175. print("WARNING *** File is truncated, or OLE2 MSAT is corrupt!!", file=logfile)
  176. print("INFO: Trying to access sector %d but only %d available"
  177. % (msid, mem_data_secs), file=logfile)
  178. trunc_warned = 1
  179. MSAT[msidx] = EVILSID
  180. dump_again = 1
  181. continue
  182. elif msid < -2:
  183. raise CompDocError("MSAT: invalid sector id: %d" % msid)
  184. if seen[msid]:
  185. raise CompDocError("MSAT extension corruption: seen[%d] == %d" % (msid, seen[msid]))
  186. seen[msid] = 2
  187. actual_SAT_sectors += 1
  188. if DEBUG and actual_SAT_sectors > SAT_sectors_reqd:
  189. print("[3]===>>>", mem_data_secs, nent, SAT_sectors_reqd, expected_MSATX_sectors, actual_MSATX_sectors, actual_SAT_sectors, msid, file=logfile)
  190. offset = 512 + sec_size * msid
  191. self.SAT.extend(unpack(fmt, mem[offset:offset+sec_size]))
  192. if DEBUG:
  193. print("SAT: len =", len(self.SAT), file=logfile)
  194. dump_list(self.SAT, 10, logfile)
  195. # print >> logfile, "SAT ",
  196. # for i, s in enumerate(self.SAT):
  197. # print >> logfile, "entry: %4d offset: %6d, next entry: %4d" % (i, 512 + sec_size * i, s)
  198. # print >> logfile, "%d:%d " % (i, s),
  199. print(file=logfile)
  200. if DEBUG and dump_again:
  201. print("MSAT: len =", len(MSAT), file=logfile)
  202. dump_list(MSAT, 10, logfile)
  203. for satx in xrange(mem_data_secs, len(self.SAT)):
  204. self.SAT[satx] = EVILSID
  205. print("SAT: len =", len(self.SAT), file=logfile)
  206. dump_list(self.SAT, 10, logfile)
  207. #
  208. # === build the directory ===
  209. #
  210. dbytes = self._get_stream(
  211. self.mem, 512, self.SAT, self.sec_size, self.dir_first_sec_sid,
  212. name="directory", seen_id=3)
  213. dirlist = []
  214. did = -1
  215. for pos in xrange(0, len(dbytes), 128):
  216. did += 1
  217. dirlist.append(DirNode(did, dbytes[pos:pos+128], 0, logfile))
  218. self.dirlist = dirlist
  219. _build_family_tree(dirlist, 0, dirlist[0].root_DID) # and stand well back ...
  220. if DEBUG:
  221. for d in dirlist:
  222. d.dump(DEBUG)
  223. #
  224. # === get the SSCS ===
  225. #
  226. sscs_dir = self.dirlist[0]
  227. assert sscs_dir.etype == 5 # root entry
  228. if sscs_dir.first_SID < 0 or sscs_dir.tot_size == 0:
  229. # Problem reported by Frank Hoffsuemmer: some software was
  230. # writing -1 instead of -2 (EOCSID) for the first_SID
  231. # when the SCCS was empty. Not having EOCSID caused assertion
  232. # failure in _get_stream.
  233. # Solution: avoid calling _get_stream in any case when the
  234. # SCSS appears to be empty.
  235. self.SSCS = ""
  236. else:
  237. self.SSCS = self._get_stream(
  238. self.mem, 512, self.SAT, sec_size, sscs_dir.first_SID,
  239. sscs_dir.tot_size, name="SSCS", seen_id=4)
  240. # if DEBUG: print >> logfile, "SSCS", repr(self.SSCS)
  241. #
  242. # === build the SSAT ===
  243. #
  244. self.SSAT = []
  245. if SSAT_tot_secs > 0 and sscs_dir.tot_size == 0:
  246. print("WARNING *** OLE2 inconsistency: SSCS size is 0 but SSAT size is non-zero", file=logfile)
  247. if sscs_dir.tot_size > 0:
  248. sid = SSAT_first_sec_sid
  249. nsecs = SSAT_tot_secs
  250. while sid >= 0 and nsecs > 0:
  251. if seen[sid]:
  252. raise CompDocError("SSAT corruption: seen[%d] == %d" % (sid, seen[sid]))
  253. seen[sid] = 5
  254. nsecs -= 1
  255. start_pos = 512 + sid * sec_size
  256. news = list(unpack(fmt, mem[start_pos:start_pos+sec_size]))
  257. self.SSAT.extend(news)
  258. sid = self.SAT[sid]
  259. if DEBUG: print("SSAT last sid %d; remaining sectors %d" % (sid, nsecs), file=logfile)
  260. assert nsecs == 0 and sid == EOCSID
  261. if DEBUG:
  262. print("SSAT", file=logfile)
  263. dump_list(self.SSAT, 10, logfile)
  264. if DEBUG:
  265. print("seen", file=logfile)
  266. dump_list(seen, 20, logfile)
  267. def _get_stream(self, mem, base, sat, sec_size, start_sid, size=None, name='', seen_id=None):
  268. # print >> self.logfile, "_get_stream", base, sec_size, start_sid, size
  269. sectors = []
  270. s = start_sid
  271. if size is None:
  272. # nothing to check against
  273. while s >= 0:
  274. if seen_id is not None:
  275. if self.seen[s]:
  276. raise CompDocError("%s corruption: seen[%d] == %d" % (name, s, self.seen[s]))
  277. self.seen[s] = seen_id
  278. start_pos = base + s * sec_size
  279. sectors.append(mem[start_pos:start_pos+sec_size])
  280. try:
  281. s = sat[s]
  282. except IndexError:
  283. raise CompDocError(
  284. "OLE2 stream %r: sector allocation table invalid entry (%d)" %
  285. (name, s)
  286. )
  287. assert s == EOCSID
  288. else:
  289. todo = size
  290. while s >= 0:
  291. if seen_id is not None:
  292. if self.seen[s]:
  293. raise CompDocError("%s corruption: seen[%d] == %d" % (name, s, self.seen[s]))
  294. self.seen[s] = seen_id
  295. start_pos = base + s * sec_size
  296. grab = sec_size
  297. if grab > todo:
  298. grab = todo
  299. todo -= grab
  300. sectors.append(mem[start_pos:start_pos+grab])
  301. try:
  302. s = sat[s]
  303. except IndexError:
  304. raise CompDocError(
  305. "OLE2 stream %r: sector allocation table invalid entry (%d)" %
  306. (name, s)
  307. )
  308. assert s == EOCSID
  309. if todo != 0:
  310. fprintf(self.logfile,
  311. "WARNING *** OLE2 stream %r: expected size %d, actual size %d\n",
  312. name, size, size - todo)
  313. return b''.join(sectors)
  314. def _dir_search(self, path, storage_DID=0):
  315. # Return matching DirNode instance, or None
  316. head = path[0]
  317. tail = path[1:]
  318. dl = self.dirlist
  319. for child in dl[storage_DID].children:
  320. if dl[child].name.lower() == head.lower():
  321. et = dl[child].etype
  322. if et == 2:
  323. return dl[child]
  324. if et == 1:
  325. if not tail:
  326. raise CompDocError("Requested component is a 'storage'")
  327. return self._dir_search(tail, child)
  328. dl[child].dump(1)
  329. raise CompDocError("Requested stream is not a 'user stream'")
  330. return None
  331. def get_named_stream(self, qname):
  332. """
  333. Interrogate the compound document's directory; return the stream as a
  334. string if found, otherwise return ``None``.
  335. :param qname:
  336. Name of the desired stream e.g. ``'Workbook'``.
  337. Should be in Unicode or convertible thereto.
  338. """
  339. d = self._dir_search(qname.split("/"))
  340. if d is None:
  341. return None
  342. if d.tot_size >= self.min_size_std_stream:
  343. return self._get_stream(
  344. self.mem, 512, self.SAT, self.sec_size, d.first_SID,
  345. d.tot_size, name=qname, seen_id=d.DID+6)
  346. else:
  347. return self._get_stream(
  348. self.SSCS, 0, self.SSAT, self.short_sec_size, d.first_SID,
  349. d.tot_size, name=qname + " (from SSCS)", seen_id=None)
  350. def locate_named_stream(self, qname):
  351. """
  352. Interrogate the compound document's directory.
  353. If the named stream is not found, ``(None, 0, 0)`` will be returned.
  354. If the named stream is found and is contiguous within the original
  355. byte sequence (``mem``) used when the document was opened,
  356. then ``(mem, offset_to_start_of_stream, length_of_stream)`` is returned.
  357. Otherwise a new string is built from the fragments and
  358. ``(new_string, 0, length_of_stream)`` is returned.
  359. :param qname:
  360. Name of the desired stream e.g. ``'Workbook'``.
  361. Should be in Unicode or convertible thereto.
  362. """
  363. d = self._dir_search(qname.split("/"))
  364. if d is None:
  365. return (None, 0, 0)
  366. if d.tot_size > self.mem_data_len:
  367. raise CompDocError("%r stream length (%d bytes) > file data size (%d bytes)"
  368. % (qname, d.tot_size, self.mem_data_len))
  369. if d.tot_size >= self.min_size_std_stream:
  370. result = self._locate_stream(
  371. self.mem, 512, self.SAT, self.sec_size, d.first_SID,
  372. d.tot_size, qname, d.DID+6)
  373. if self.DEBUG:
  374. print("\nseen", file=self.logfile)
  375. dump_list(self.seen, 20, self.logfile)
  376. return result
  377. else:
  378. return (
  379. self._get_stream(
  380. self.SSCS, 0, self.SSAT, self.short_sec_size, d.first_SID,
  381. d.tot_size, qname + " (from SSCS)", None),
  382. 0,
  383. d.tot_size,
  384. )
  385. def _locate_stream(self, mem, base, sat, sec_size, start_sid, expected_stream_size, qname, seen_id):
  386. # print >> self.logfile, "_locate_stream", base, sec_size, start_sid, expected_stream_size
  387. s = start_sid
  388. if s < 0:
  389. raise CompDocError("_locate_stream: start_sid (%d) is -ve" % start_sid)
  390. p = -99 # dummy previous SID
  391. start_pos = -9999
  392. end_pos = -8888
  393. slices = []
  394. tot_found = 0
  395. found_limit = (expected_stream_size + sec_size - 1) // sec_size
  396. while s >= 0:
  397. if self.seen[s]:
  398. if not self.ignore_workbook_corruption:
  399. print("_locate_stream(%s): seen" % qname, file=self.logfile); dump_list(self.seen, 20, self.logfile)
  400. raise CompDocError("%s corruption: seen[%d] == %d" % (qname, s, self.seen[s]))
  401. self.seen[s] = seen_id
  402. tot_found += 1
  403. if tot_found > found_limit:
  404. # Note: expected size rounded up to higher sector
  405. raise CompDocError(
  406. "%s: size exceeds expected %d bytes; corrupt?"
  407. % (qname, found_limit * sec_size)
  408. )
  409. if s == p+1:
  410. # contiguous sectors
  411. end_pos += sec_size
  412. else:
  413. # start new slice
  414. if p >= 0:
  415. # not first time
  416. slices.append((start_pos, end_pos))
  417. start_pos = base + s * sec_size
  418. end_pos = start_pos + sec_size
  419. p = s
  420. s = sat[s]
  421. assert s == EOCSID
  422. assert tot_found == found_limit
  423. # print >> self.logfile, "_locate_stream(%s): seen" % qname; dump_list(self.seen, 20, self.logfile)
  424. if not slices:
  425. # The stream is contiguous ... just what we like!
  426. return (mem, start_pos, expected_stream_size)
  427. slices.append((start_pos, end_pos))
  428. # print >> self.logfile, "+++>>> %d fragments" % len(slices)
  429. return (b''.join(mem[start_pos:end_pos] for start_pos, end_pos in slices), 0, expected_stream_size)
  430. # ==========================================================================================
  431. def x_dump_line(alist, stride, f, dpos, equal=0):
  432. print("%5d%s" % (dpos, " ="[equal]), end=' ', file=f)
  433. for value in alist[dpos:dpos + stride]:
  434. print(str(value), end=' ', file=f)
  435. print(file=f)
  436. def dump_list(alist, stride, f=sys.stdout):
  437. def _dump_line(dpos, equal=0):
  438. print("%5d%s" % (dpos, " ="[equal]), end=' ', file=f)
  439. for value in alist[dpos:dpos + stride]:
  440. print(str(value), end=' ', file=f)
  441. print(file=f)
  442. pos = None
  443. oldpos = None
  444. for pos in xrange(0, len(alist), stride):
  445. if oldpos is None:
  446. _dump_line(pos)
  447. oldpos = pos
  448. elif alist[pos:pos+stride] != alist[oldpos:oldpos+stride]:
  449. if pos - oldpos > stride:
  450. _dump_line(pos - stride, equal=1)
  451. _dump_line(pos)
  452. oldpos = pos
  453. if oldpos is not None and pos is not None and pos != oldpos:
  454. _dump_line(pos, equal=1)