websocket.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. import base64
  2. import codecs
  3. import collections
  4. import errno
  5. from random import Random
  6. from socket import error as SocketError
  7. import string
  8. import struct
  9. import sys
  10. import time
  11. import zlib
  12. try:
  13. from hashlib import md5, sha1
  14. except ImportError: # pragma NO COVER
  15. from md5 import md5
  16. from sha import sha as sha1
  17. from eventlet import semaphore
  18. from eventlet import wsgi
  19. from eventlet.green import socket
  20. from eventlet.support import get_errno
  21. import six
  22. # Python 2's utf8 decoding is more lenient than we'd like
  23. # In order to pass autobahn's testsuite we need stricter validation
  24. # if available...
  25. for _mod in ('wsaccel.utf8validator', 'autobahn.utf8validator'):
  26. # autobahn has it's own python-based validator. in newest versions
  27. # this prefers to use wsaccel, a cython based implementation, if available.
  28. # wsaccel may also be installed w/out autobahn, or with a earlier version.
  29. try:
  30. utf8validator = __import__(_mod, {}, {}, [''])
  31. except ImportError:
  32. utf8validator = None
  33. else:
  34. break
  35. ACCEPTABLE_CLIENT_ERRORS = set((errno.ECONNRESET, errno.EPIPE))
  36. __all__ = ["WebSocketWSGI", "WebSocket"]
  37. PROTOCOL_GUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  38. VALID_CLOSE_STATUS = set(
  39. list(range(1000, 1004)) +
  40. list(range(1007, 1012)) +
  41. # 3000-3999: reserved for use by libraries, frameworks,
  42. # and applications
  43. list(range(3000, 4000)) +
  44. # 4000-4999: reserved for private use and thus can't
  45. # be registered
  46. list(range(4000, 5000))
  47. )
  48. class BadRequest(Exception):
  49. def __init__(self, status='400 Bad Request', body=None, headers=None):
  50. super(Exception, self).__init__()
  51. self.status = status
  52. self.body = body
  53. self.headers = headers
  54. class WebSocketWSGI(object):
  55. """Wraps a websocket handler function in a WSGI application.
  56. Use it like this::
  57. @websocket.WebSocketWSGI
  58. def my_handler(ws):
  59. from_browser = ws.wait()
  60. ws.send("from server")
  61. The single argument to the function will be an instance of
  62. :class:`WebSocket`. To close the socket, simply return from the
  63. function. Note that the server will log the websocket request at
  64. the time of closure.
  65. """
  66. def __init__(self, handler):
  67. self.handler = handler
  68. self.protocol_version = None
  69. self.support_legacy_versions = True
  70. self.supported_protocols = []
  71. self.origin_checker = None
  72. @classmethod
  73. def configured(cls,
  74. handler=None,
  75. supported_protocols=None,
  76. origin_checker=None,
  77. support_legacy_versions=False):
  78. def decorator(handler):
  79. inst = cls(handler)
  80. inst.support_legacy_versions = support_legacy_versions
  81. inst.origin_checker = origin_checker
  82. if supported_protocols:
  83. inst.supported_protocols = supported_protocols
  84. return inst
  85. if handler is None:
  86. return decorator
  87. return decorator(handler)
  88. def __call__(self, environ, start_response):
  89. http_connection_parts = [
  90. part.strip()
  91. for part in environ.get('HTTP_CONNECTION', '').lower().split(',')]
  92. if not ('upgrade' in http_connection_parts and
  93. environ.get('HTTP_UPGRADE', '').lower() == 'websocket'):
  94. # need to check a few more things here for true compliance
  95. start_response('400 Bad Request', [('Connection', 'close')])
  96. return []
  97. try:
  98. if 'HTTP_SEC_WEBSOCKET_VERSION' in environ:
  99. ws = self._handle_hybi_request(environ)
  100. elif self.support_legacy_versions:
  101. ws = self._handle_legacy_request(environ)
  102. else:
  103. raise BadRequest()
  104. except BadRequest as e:
  105. status = e.status
  106. body = e.body or b''
  107. headers = e.headers or []
  108. start_response(status,
  109. [('Connection', 'close'), ] + headers)
  110. return [body]
  111. try:
  112. self.handler(ws)
  113. except socket.error as e:
  114. if get_errno(e) not in ACCEPTABLE_CLIENT_ERRORS:
  115. raise
  116. # Make sure we send the closing frame
  117. ws._send_closing_frame(True)
  118. # use this undocumented feature of eventlet.wsgi to ensure that it
  119. # doesn't barf on the fact that we didn't call start_response
  120. return wsgi.ALREADY_HANDLED
  121. def _handle_legacy_request(self, environ):
  122. if 'eventlet.input' in environ:
  123. sock = environ['eventlet.input'].get_socket()
  124. elif 'gunicorn.socket' in environ:
  125. sock = environ['gunicorn.socket']
  126. else:
  127. raise Exception('No eventlet.input or gunicorn.socket present in environ.')
  128. if 'HTTP_SEC_WEBSOCKET_KEY1' in environ:
  129. self.protocol_version = 76
  130. if 'HTTP_SEC_WEBSOCKET_KEY2' not in environ:
  131. raise BadRequest()
  132. else:
  133. self.protocol_version = 75
  134. if self.protocol_version == 76:
  135. key1 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY1'])
  136. key2 = self._extract_number(environ['HTTP_SEC_WEBSOCKET_KEY2'])
  137. # There's no content-length header in the request, but it has 8
  138. # bytes of data.
  139. environ['wsgi.input'].content_length = 8
  140. key3 = environ['wsgi.input'].read(8)
  141. key = struct.pack(">II", key1, key2) + key3
  142. response = md5(key).digest()
  143. # Start building the response
  144. scheme = 'ws'
  145. if environ.get('wsgi.url_scheme') == 'https':
  146. scheme = 'wss'
  147. location = '%s://%s%s%s' % (
  148. scheme,
  149. environ.get('HTTP_HOST'),
  150. environ.get('SCRIPT_NAME'),
  151. environ.get('PATH_INFO')
  152. )
  153. qs = environ.get('QUERY_STRING')
  154. if qs is not None:
  155. location += '?' + qs
  156. if self.protocol_version == 75:
  157. handshake_reply = (
  158. b"HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
  159. b"Upgrade: WebSocket\r\n"
  160. b"Connection: Upgrade\r\n"
  161. b"WebSocket-Origin: " + six.b(environ.get('HTTP_ORIGIN')) + b"\r\n"
  162. b"WebSocket-Location: " + six.b(location) + b"\r\n\r\n"
  163. )
  164. elif self.protocol_version == 76:
  165. handshake_reply = (
  166. b"HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
  167. b"Upgrade: WebSocket\r\n"
  168. b"Connection: Upgrade\r\n"
  169. b"Sec-WebSocket-Origin: " + six.b(environ.get('HTTP_ORIGIN')) + b"\r\n"
  170. b"Sec-WebSocket-Protocol: " +
  171. six.b(environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', 'default')) + b"\r\n"
  172. b"Sec-WebSocket-Location: " + six.b(location) + b"\r\n"
  173. b"\r\n" + response
  174. )
  175. else: # pragma NO COVER
  176. raise ValueError("Unknown WebSocket protocol version.")
  177. sock.sendall(handshake_reply)
  178. return WebSocket(sock, environ, self.protocol_version)
  179. def _parse_extension_header(self, header):
  180. if header is None:
  181. return None
  182. res = {}
  183. for ext in header.split(","):
  184. parts = ext.split(";")
  185. config = {}
  186. for part in parts[1:]:
  187. key_val = part.split("=")
  188. if len(key_val) == 1:
  189. config[key_val[0].strip().lower()] = True
  190. else:
  191. config[key_val[0].strip().lower()] = key_val[1].strip().strip('"').lower()
  192. res.setdefault(parts[0].strip().lower(), []).append(config)
  193. return res
  194. def _negotiate_permessage_deflate(self, extensions):
  195. if not extensions:
  196. return None
  197. deflate = extensions.get("permessage-deflate")
  198. if deflate is None:
  199. return None
  200. for config in deflate:
  201. # We'll evaluate each config in the client's preferred order and pick
  202. # the first that we can support.
  203. want_config = {
  204. # These are bool options, we can support both
  205. "server_no_context_takeover": config.get("server_no_context_takeover", False),
  206. "client_no_context_takeover": config.get("client_no_context_takeover", False)
  207. }
  208. # These are either bool OR int options. True means the client can accept a value
  209. # for the option, a number means the client wants that specific value.
  210. max_wbits = min(zlib.MAX_WBITS, 15)
  211. mwb = config.get("server_max_window_bits")
  212. if mwb is not None:
  213. if mwb is True:
  214. want_config["server_max_window_bits"] = max_wbits
  215. else:
  216. want_config["server_max_window_bits"] = \
  217. int(config.get("server_max_window_bits", max_wbits))
  218. if not (8 <= want_config["server_max_window_bits"] <= 15):
  219. continue
  220. mwb = config.get("client_max_window_bits")
  221. if mwb is not None:
  222. if mwb is True:
  223. want_config["client_max_window_bits"] = max_wbits
  224. else:
  225. want_config["client_max_window_bits"] = \
  226. int(config.get("client_max_window_bits", max_wbits))
  227. if not (8 <= want_config["client_max_window_bits"] <= 15):
  228. continue
  229. return want_config
  230. return None
  231. def _format_extension_header(self, parsed_extensions):
  232. if not parsed_extensions:
  233. return None
  234. parts = []
  235. for name, config in parsed_extensions.items():
  236. ext_parts = [six.b(name)]
  237. for key, value in config.items():
  238. if value is False:
  239. pass
  240. elif value is True:
  241. ext_parts.append(six.b(key))
  242. else:
  243. ext_parts.append(six.b("%s=%s" % (key, str(value))))
  244. parts.append(b"; ".join(ext_parts))
  245. return b", ".join(parts)
  246. def _handle_hybi_request(self, environ):
  247. if 'eventlet.input' in environ:
  248. sock = environ['eventlet.input'].get_socket()
  249. elif 'gunicorn.socket' in environ:
  250. sock = environ['gunicorn.socket']
  251. else:
  252. raise Exception('No eventlet.input or gunicorn.socket present in environ.')
  253. hybi_version = environ['HTTP_SEC_WEBSOCKET_VERSION']
  254. if hybi_version not in ('8', '13', ):
  255. raise BadRequest(status='426 Upgrade Required',
  256. headers=[('Sec-WebSocket-Version', '8, 13')])
  257. self.protocol_version = int(hybi_version)
  258. if 'HTTP_SEC_WEBSOCKET_KEY' not in environ:
  259. # That's bad.
  260. raise BadRequest()
  261. origin = environ.get(
  262. 'HTTP_ORIGIN',
  263. (environ.get('HTTP_SEC_WEBSOCKET_ORIGIN', '')
  264. if self.protocol_version <= 8 else ''))
  265. if self.origin_checker is not None:
  266. if not self.origin_checker(environ.get('HTTP_HOST'), origin):
  267. raise BadRequest(status='403 Forbidden')
  268. protocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL', None)
  269. negotiated_protocol = None
  270. if protocols:
  271. for p in (i.strip() for i in protocols.split(',')):
  272. if p in self.supported_protocols:
  273. negotiated_protocol = p
  274. break
  275. key = environ['HTTP_SEC_WEBSOCKET_KEY']
  276. response = base64.b64encode(sha1(six.b(key) + PROTOCOL_GUID).digest())
  277. handshake_reply = [b"HTTP/1.1 101 Switching Protocols",
  278. b"Upgrade: websocket",
  279. b"Connection: Upgrade",
  280. b"Sec-WebSocket-Accept: " + response]
  281. if negotiated_protocol:
  282. handshake_reply.append(b"Sec-WebSocket-Protocol: " + six.b(negotiated_protocol))
  283. parsed_extensions = {}
  284. extensions = self._parse_extension_header(environ.get("HTTP_SEC_WEBSOCKET_EXTENSIONS"))
  285. deflate = self._negotiate_permessage_deflate(extensions)
  286. if deflate is not None:
  287. parsed_extensions["permessage-deflate"] = deflate
  288. formatted_ext = self._format_extension_header(parsed_extensions)
  289. if formatted_ext is not None:
  290. handshake_reply.append(b"Sec-WebSocket-Extensions: " + formatted_ext)
  291. sock.sendall(b'\r\n'.join(handshake_reply) + b'\r\n\r\n')
  292. return RFC6455WebSocket(sock, environ, self.protocol_version,
  293. protocol=negotiated_protocol,
  294. extensions=parsed_extensions)
  295. def _extract_number(self, value):
  296. """
  297. Utility function which, given a string like 'g98sd 5[]221@1', will
  298. return 9852211. Used to parse the Sec-WebSocket-Key headers.
  299. """
  300. out = ""
  301. spaces = 0
  302. for char in value:
  303. if char in string.digits:
  304. out += char
  305. elif char == " ":
  306. spaces += 1
  307. return int(out) // spaces
  308. class WebSocket(object):
  309. """A websocket object that handles the details of
  310. serialization/deserialization to the socket.
  311. The primary way to interact with a :class:`WebSocket` object is to
  312. call :meth:`send` and :meth:`wait` in order to pass messages back
  313. and forth with the browser. Also available are the following
  314. properties:
  315. path
  316. The path value of the request. This is the same as the WSGI PATH_INFO variable,
  317. but more convenient.
  318. protocol
  319. The value of the Websocket-Protocol header.
  320. origin
  321. The value of the 'Origin' header.
  322. environ
  323. The full WSGI environment for this request.
  324. """
  325. def __init__(self, sock, environ, version=76):
  326. """
  327. :param socket: The eventlet socket
  328. :type socket: :class:`eventlet.greenio.GreenSocket`
  329. :param environ: The wsgi environment
  330. :param version: The WebSocket spec version to follow (default is 76)
  331. """
  332. self.log = environ.get('wsgi.errors', sys.stderr)
  333. self.log_context = 'server={shost}/{spath} client={caddr}:{cport}'.format(
  334. shost=environ.get('HTTP_HOST'),
  335. spath=environ.get('SCRIPT_NAME', '') + environ.get('PATH_INFO', ''),
  336. caddr=environ.get('REMOTE_ADDR'), cport=environ.get('REMOTE_PORT'),
  337. )
  338. self.socket = sock
  339. self.origin = environ.get('HTTP_ORIGIN')
  340. self.protocol = environ.get('HTTP_WEBSOCKET_PROTOCOL')
  341. self.path = environ.get('PATH_INFO')
  342. self.environ = environ
  343. self.version = version
  344. self.websocket_closed = False
  345. self._buf = b""
  346. self._msgs = collections.deque()
  347. self._sendlock = semaphore.Semaphore()
  348. def _pack_message(self, message):
  349. """Pack the message inside ``00`` and ``FF``
  350. As per the dataframing section (5.3) for the websocket spec
  351. """
  352. if isinstance(message, six.text_type):
  353. message = message.encode('utf-8')
  354. elif not isinstance(message, six.binary_type):
  355. message = six.b(str(message))
  356. packed = b"\x00" + message + b"\xFF"
  357. return packed
  358. def _parse_messages(self):
  359. """ Parses for messages in the buffer *buf*. It is assumed that
  360. the buffer contains the start character for a message, but that it
  361. may contain only part of the rest of the message.
  362. Returns an array of messages, and the buffer remainder that
  363. didn't contain any full messages."""
  364. msgs = []
  365. end_idx = 0
  366. buf = self._buf
  367. while buf:
  368. frame_type = six.indexbytes(buf, 0)
  369. if frame_type == 0:
  370. # Normal message.
  371. end_idx = buf.find(b"\xFF")
  372. if end_idx == -1: # pragma NO COVER
  373. break
  374. msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
  375. buf = buf[end_idx + 1:]
  376. elif frame_type == 255:
  377. # Closing handshake.
  378. assert six.indexbytes(buf, 1) == 0, "Unexpected closing handshake: %r" % buf
  379. self.websocket_closed = True
  380. break
  381. else:
  382. raise ValueError("Don't understand how to parse this type of message: %r" % buf)
  383. self._buf = buf
  384. return msgs
  385. def send(self, message):
  386. """Send a message to the browser.
  387. *message* should be convertable to a string; unicode objects should be
  388. encodable as utf-8. Raises socket.error with errno of 32
  389. (broken pipe) if the socket has already been closed by the client."""
  390. packed = self._pack_message(message)
  391. # if two greenthreads are trying to send at the same time
  392. # on the same socket, sendlock prevents interleaving and corruption
  393. self._sendlock.acquire()
  394. try:
  395. self.socket.sendall(packed)
  396. finally:
  397. self._sendlock.release()
  398. def wait(self):
  399. """Waits for and deserializes messages.
  400. Returns a single message; the oldest not yet processed. If the client
  401. has already closed the connection, returns None. This is different
  402. from normal socket behavior because the empty string is a valid
  403. websocket message."""
  404. while not self._msgs:
  405. # Websocket might be closed already.
  406. if self.websocket_closed:
  407. return None
  408. # no parsed messages, must mean buf needs more data
  409. delta = self.socket.recv(8096)
  410. if delta == b'':
  411. return None
  412. self._buf += delta
  413. msgs = self._parse_messages()
  414. self._msgs.extend(msgs)
  415. return self._msgs.popleft()
  416. def _send_closing_frame(self, ignore_send_errors=False):
  417. """Sends the closing frame to the client, if required."""
  418. if self.version == 76 and not self.websocket_closed:
  419. try:
  420. self.socket.sendall(b"\xff\x00")
  421. except SocketError:
  422. # Sometimes, like when the remote side cuts off the connection,
  423. # we don't care about this.
  424. if not ignore_send_errors: # pragma NO COVER
  425. raise
  426. self.websocket_closed = True
  427. def close(self):
  428. """Forcibly close the websocket; generally it is preferable to
  429. return from the handler method."""
  430. try:
  431. self._send_closing_frame(True)
  432. self.socket.shutdown(True)
  433. except SocketError as e:
  434. if e.errno != errno.ENOTCONN:
  435. self.log.write('{ctx} socket shutdown error: {e}'.format(ctx=self.log_context, e=e))
  436. finally:
  437. self.socket.close()
  438. class ConnectionClosedError(Exception):
  439. pass
  440. class FailedConnectionError(Exception):
  441. def __init__(self, status, message):
  442. super(FailedConnectionError, self).__init__(status, message)
  443. self.message = message
  444. self.status = status
  445. class ProtocolError(ValueError):
  446. pass
  447. class RFC6455WebSocket(WebSocket):
  448. def __init__(self, sock, environ, version=13, protocol=None, client=False, extensions=None):
  449. super(RFC6455WebSocket, self).__init__(sock, environ, version)
  450. self.iterator = self._iter_frames()
  451. self.client = client
  452. self.protocol = protocol
  453. self.extensions = extensions or {}
  454. self._deflate_enc = None
  455. self._deflate_dec = None
  456. class UTF8Decoder(object):
  457. def __init__(self):
  458. if utf8validator:
  459. self.validator = utf8validator.Utf8Validator()
  460. else:
  461. self.validator = None
  462. decoderclass = codecs.getincrementaldecoder('utf8')
  463. self.decoder = decoderclass()
  464. def reset(self):
  465. if self.validator:
  466. self.validator.reset()
  467. self.decoder.reset()
  468. def decode(self, data, final=False):
  469. if self.validator:
  470. valid, eocp, c_i, t_i = self.validator.validate(data)
  471. if not valid:
  472. raise ValueError('Data is not valid unicode')
  473. return self.decoder.decode(data, final)
  474. def _get_permessage_deflate_enc(self):
  475. options = self.extensions.get("permessage-deflate")
  476. if options is None:
  477. return None
  478. def _make():
  479. return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED,
  480. -options.get("client_max_window_bits" if self.client
  481. else "server_max_window_bits",
  482. zlib.MAX_WBITS))
  483. if options.get("client_no_context_takeover" if self.client
  484. else "server_no_context_takeover"):
  485. # This option means we have to make a new one every time
  486. return _make()
  487. else:
  488. if self._deflate_enc is None:
  489. self._deflate_enc = _make()
  490. return self._deflate_enc
  491. def _get_permessage_deflate_dec(self, rsv1):
  492. options = self.extensions.get("permessage-deflate")
  493. if options is None or not rsv1:
  494. return None
  495. def _make():
  496. return zlib.decompressobj(-options.get("server_max_window_bits" if self.client
  497. else "client_max_window_bits",
  498. zlib.MAX_WBITS))
  499. if options.get("server_no_context_takeover" if self.client
  500. else "client_no_context_takeover"):
  501. # This option means we have to make a new one every time
  502. return _make()
  503. else:
  504. if self._deflate_dec is None:
  505. self._deflate_dec = _make()
  506. return self._deflate_dec
  507. def _get_bytes(self, numbytes):
  508. data = b''
  509. while len(data) < numbytes:
  510. d = self.socket.recv(numbytes - len(data))
  511. if not d:
  512. raise ConnectionClosedError()
  513. data = data + d
  514. return data
  515. class Message(object):
  516. def __init__(self, opcode, decoder=None, decompressor=None):
  517. self.decoder = decoder
  518. self.data = []
  519. self.finished = False
  520. self.opcode = opcode
  521. self.decompressor = decompressor
  522. def push(self, data, final=False):
  523. self.finished = final
  524. self.data.append(data)
  525. def getvalue(self):
  526. data = b"".join(self.data)
  527. if not self.opcode & 8 and self.decompressor:
  528. data = self.decompressor.decompress(data + b'\x00\x00\xff\xff')
  529. if self.decoder:
  530. data = self.decoder.decode(data, self.finished)
  531. return data
  532. @staticmethod
  533. def _apply_mask(data, mask, length=None, offset=0):
  534. if length is None:
  535. length = len(data)
  536. cnt = range(length)
  537. return b''.join(six.int2byte(six.indexbytes(data, i) ^ mask[(offset + i) % 4]) for i in cnt)
  538. def _handle_control_frame(self, opcode, data):
  539. if opcode == 8: # connection close
  540. if not data:
  541. status = 1000
  542. elif len(data) > 1:
  543. status = struct.unpack_from('!H', data)[0]
  544. if not status or status not in VALID_CLOSE_STATUS:
  545. raise FailedConnectionError(
  546. 1002,
  547. "Unexpected close status code.")
  548. try:
  549. data = self.UTF8Decoder().decode(data[2:], True)
  550. except (UnicodeDecodeError, ValueError):
  551. raise FailedConnectionError(
  552. 1002,
  553. "Close message data should be valid UTF-8.")
  554. else:
  555. status = 1002
  556. self.close(close_data=(status, ''))
  557. raise ConnectionClosedError()
  558. elif opcode == 9: # ping
  559. self.send(data, control_code=0xA)
  560. elif opcode == 0xA: # pong
  561. pass
  562. else:
  563. raise FailedConnectionError(
  564. 1002, "Unknown control frame received.")
  565. def _iter_frames(self):
  566. fragmented_message = None
  567. try:
  568. while True:
  569. message = self._recv_frame(message=fragmented_message)
  570. if message.opcode & 8:
  571. self._handle_control_frame(
  572. message.opcode, message.getvalue())
  573. continue
  574. if fragmented_message and message is not fragmented_message:
  575. raise RuntimeError('Unexpected message change.')
  576. fragmented_message = message
  577. if message.finished:
  578. data = fragmented_message.getvalue()
  579. fragmented_message = None
  580. yield data
  581. except FailedConnectionError:
  582. exc_typ, exc_val, exc_tb = sys.exc_info()
  583. self.close(close_data=(exc_val.status, exc_val.message))
  584. except ConnectionClosedError:
  585. return
  586. except Exception:
  587. self.close(close_data=(1011, 'Internal Server Error'))
  588. raise
  589. def _recv_frame(self, message=None):
  590. recv = self._get_bytes
  591. # Unpacking the frame described in Section 5.2 of RFC6455
  592. # (https://tools.ietf.org/html/rfc6455#section-5.2)
  593. header = recv(2)
  594. a, b = struct.unpack('!BB', header)
  595. finished = a >> 7 == 1
  596. rsv123 = a >> 4 & 7
  597. rsv1 = rsv123 & 4
  598. if rsv123:
  599. if rsv1 and "permessage-deflate" not in self.extensions:
  600. # must be zero - unless it's compressed then rsv1 is true
  601. raise FailedConnectionError(
  602. 1002,
  603. "RSV1, RSV2, RSV3: MUST be 0 unless an extension is"
  604. " negotiated that defines meanings for non-zero values.")
  605. opcode = a & 15
  606. if opcode not in (0, 1, 2, 8, 9, 0xA):
  607. raise FailedConnectionError(1002, "Unknown opcode received.")
  608. masked = b & 128 == 128
  609. if not masked and not self.client:
  610. raise FailedConnectionError(1002, "A client MUST mask all frames"
  611. " that it sends to the server")
  612. length = b & 127
  613. if opcode & 8:
  614. if not finished:
  615. raise FailedConnectionError(1002, "Control frames must not"
  616. " be fragmented.")
  617. if length > 125:
  618. raise FailedConnectionError(
  619. 1002,
  620. "All control frames MUST have a payload length of 125"
  621. " bytes or less")
  622. elif opcode and message:
  623. raise FailedConnectionError(
  624. 1002,
  625. "Received a non-continuation opcode within"
  626. " fragmented message.")
  627. elif not opcode and not message:
  628. raise FailedConnectionError(
  629. 1002,
  630. "Received continuation opcode with no previous"
  631. " fragments received.")
  632. if length == 126:
  633. length = struct.unpack('!H', recv(2))[0]
  634. elif length == 127:
  635. length = struct.unpack('!Q', recv(8))[0]
  636. if masked:
  637. mask = struct.unpack('!BBBB', recv(4))
  638. received = 0
  639. if not message or opcode & 8:
  640. decoder = self.UTF8Decoder() if opcode == 1 else None
  641. decompressor = self._get_permessage_deflate_dec(rsv1)
  642. message = self.Message(opcode, decoder=decoder, decompressor=decompressor)
  643. if not length:
  644. message.push(b'', final=finished)
  645. else:
  646. while received < length:
  647. d = self.socket.recv(length - received)
  648. if not d:
  649. raise ConnectionClosedError()
  650. dlen = len(d)
  651. if masked:
  652. d = self._apply_mask(d, mask, length=dlen, offset=received)
  653. received = received + dlen
  654. try:
  655. message.push(d, final=finished)
  656. except (UnicodeDecodeError, ValueError):
  657. raise FailedConnectionError(
  658. 1007, "Text data must be valid utf-8")
  659. return message
  660. def _pack_message(self, message, masked=False,
  661. continuation=False, final=True, control_code=None):
  662. is_text = False
  663. if isinstance(message, six.text_type):
  664. message = message.encode('utf-8')
  665. is_text = True
  666. compress_bit = 0
  667. compressor = self._get_permessage_deflate_enc()
  668. if message and compressor:
  669. message = compressor.compress(message)
  670. message += compressor.flush(zlib.Z_SYNC_FLUSH)
  671. assert message[-4:] == b"\x00\x00\xff\xff"
  672. message = message[:-4]
  673. compress_bit = 1 << 6
  674. length = len(message)
  675. if not length:
  676. # no point masking empty data
  677. masked = False
  678. if control_code:
  679. if control_code not in (8, 9, 0xA):
  680. raise ProtocolError('Unknown control opcode.')
  681. if continuation or not final:
  682. raise ProtocolError('Control frame cannot be a fragment.')
  683. if length > 125:
  684. raise ProtocolError('Control frame data too large (>125).')
  685. header = struct.pack('!B', control_code | 1 << 7)
  686. else:
  687. opcode = 0 if continuation else ((1 if is_text else 2) | compress_bit)
  688. header = struct.pack('!B', opcode | (1 << 7 if final else 0))
  689. lengthdata = 1 << 7 if masked else 0
  690. if length > 65535:
  691. lengthdata = struct.pack('!BQ', lengthdata | 127, length)
  692. elif length > 125:
  693. lengthdata = struct.pack('!BH', lengthdata | 126, length)
  694. else:
  695. lengthdata = struct.pack('!B', lengthdata | length)
  696. if masked:
  697. # NOTE: RFC6455 states:
  698. # A server MUST NOT mask any frames that it sends to the client
  699. rand = Random(time.time())
  700. mask = [rand.getrandbits(8) for _ in six.moves.xrange(4)]
  701. message = RFC6455WebSocket._apply_mask(message, mask, length)
  702. maskdata = struct.pack('!BBBB', *mask)
  703. else:
  704. maskdata = b''
  705. return b''.join((header, lengthdata, maskdata, message))
  706. def wait(self):
  707. for i in self.iterator:
  708. return i
  709. def _send(self, frame):
  710. self._sendlock.acquire()
  711. try:
  712. self.socket.sendall(frame)
  713. finally:
  714. self._sendlock.release()
  715. def send(self, message, **kw):
  716. kw['masked'] = self.client
  717. payload = self._pack_message(message, **kw)
  718. self._send(payload)
  719. def _send_closing_frame(self, ignore_send_errors=False, close_data=None):
  720. if self.version in (8, 13) and not self.websocket_closed:
  721. if close_data is not None:
  722. status, msg = close_data
  723. if isinstance(msg, six.text_type):
  724. msg = msg.encode('utf-8')
  725. data = struct.pack('!H', status) + msg
  726. else:
  727. data = ''
  728. try:
  729. self.send(data, control_code=8)
  730. except SocketError:
  731. # Sometimes, like when the remote side cuts off the connection,
  732. # we don't care about this.
  733. if not ignore_send_errors: # pragma NO COVER
  734. raise
  735. self.websocket_closed = True
  736. def close(self, close_data=None):
  737. """Forcibly close the websocket; generally it is preferable to
  738. return from the handler method."""
  739. try:
  740. self._send_closing_frame(close_data=close_data, ignore_send_errors=True)
  741. self.socket.shutdown(socket.SHUT_WR)
  742. except SocketError as e:
  743. if e.errno != errno.ENOTCONN:
  744. self.log.write('{ctx} socket shutdown error: {e}'.format(ctx=self.log_context, e=e))
  745. finally:
  746. self.socket.close()