sqlflow.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to Cloudera, Inc. under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. Cloudera, Inc. licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import logging
import os

import sqlflow
from sqlflow.rows import Rows

from desktop.lib.i18n import force_unicode
from notebook.connectors.base import Api, QueryError
from notebook.decorators import rewrite_ssh_api_url, ssh_error_handler
from notebook.models import escape_rows


LOG = logging.getLogger()


def query_error_handler(func):
  def decorator(*args, **kwargs):
    try:
      return func(*args, **kwargs)
    except Exception as e:
      message = force_unicode(str(e))
      raise QueryError(message)  # Surface any connector error as a QueryError for the notebook UI
  return decorator
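

# Hue notebook connector for SQLFlow: statements are executed synchronously
# through the `sqlflow` Python client and the results are reshaped into the
# tabular payload the notebook frontend expects.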
class SqlFlowApi(Api):

  def __init__(self, user, interpreter=None):
    Api.__init__(self, user, interpreter=interpreter)

    self.options = interpreter['options']
    self.url = self.options['url']

    if self.options.get('has_ssh'):
      self.url = rewrite_ssh_api_url(self.url)['url']

  def _get_db(self):
    os.environ['SQLFLOW_DATASOURCE'] = self.interpreter['options']['datasource']

    return sqlflow.Client(server_url='172.18.1.3:50051')  # TODO Send as param instead of ENV
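
  # The data source is handed to the client through the SQLFLOW_DATASOURCE
  # environment variable (see the TODO above), and a fresh client is created
  # for every executed statement.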

  @query_error_handler
  @ssh_error_handler
  def execute(self, notebook, snippet):
    statement = snippet['statement']
    statement = statement.replace('LIMIT 5000', '')  # Strip the LIMIT 5000 Hue can append to editor statements

    result = self._execute(statement)
    has_result_set = len(result['data']) > 0

    return {
      'sync': True,
      'has_result_set': has_result_set,
      'result': {
        'has_more': False,
        'data': result['data'] if has_result_set else [],
        'meta': [{
            'name': col[0],
            'type': col[1],
            'comment': col[2]
          }
          for col in result['description']
        ] if has_result_set else [],
        'type': 'table'
      }
    }

  def _execute(self, statement):
    db = self._get_db()

    compound_message = db.execute(statement)

    data = []
    description = []

    if compound_message:
      for r in compound_message._messages:
        if isinstance(r[0], Rows):
          description = [(c, '', '') for c in r[0].column_names()]
          data.extend([row for row in r[0].rows()])
        else:
          description = ['']
          data.extend([row for row in r[0].rows()])
    else:
      # Need to grab from sqlflow.client logs
      pass

    return {
      'data': data,
      'description': description,
    }
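
  # For reference, a successful SELECT yields a payload shaped roughly like the
  # sketch below (illustrative values only; actual rows come from the server):
  #
  #   {
  #     'data': [[5.1, 'setosa'], [6.2, 'virginica']],
  #     'description': [('sepal_length', '', ''), ('class', '', '')]
  #   }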

  @query_error_handler
  def check_status(self, notebook, snippet):
    return {'status': 'available'}

  @query_error_handler
  @ssh_error_handler
  def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
    response = {}

    if database is None:
      response['databases'] = self._execute('SHOW DATABASES')['data']
    elif table is None:
      response['tables_meta'] = [
        {'name': t[0], 'type': '', 'comment': ''}
        for t in self._execute('SHOW TABLES in %s' % database)['data']
      ]
    elif column is None:
      columns = self._execute('DESCRIBE %s.%s' % (database, table))['data']

      response['columns'] = [col[0] for col in columns]
      response['extended_columns'] = [{
          'comment': col[2],
          'name': col[0],
          'type': col[1]
        }
        for col in columns
      ]

    return response
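
  # Depending on which arguments are given, the response holds one of the
  # following shapes (sketch with illustrative values):
  #
  #   {'databases': [['default'], ['iris']]}
  #   {'tables_meta': [{'name': 'train', 'type': '', 'comment': ''}]}
  #   {'columns': ['sepal_length'], 'extended_columns': [{'name': 'sepal_length', 'type': 'FLOAT', 'comment': ''}]}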

  @query_error_handler
  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
    result = self._execute('SELECT * FROM %s.%s LIMIT 10' % (database, table))

    response = {
      'status': 0,
    }

    response['rows'] = escape_rows(result['data'])
    response['full_headers'] = [{
        'name': col[0],  # description entries are (name, type, comment) tuples
        'type': 'STRING_TYPE',
        'comment': ''
      }
      for col in result['description']
    ]

    return response

  def fetch_result(self, notebook, snippet, rows, start_over):
    """Only called at the end of a live query."""
    return {
      'has_more': False,
      'data': [],
      'meta': []
    }
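
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the interpreter values below are
# assumptions, not part of this module). In Hue the interpreter dict comes
# from the connector configuration and `user` is the authenticated user.
#
#   interpreter = {
#     'options': {
#       'url': 'sqlflow-server:50051',
#       'datasource': 'mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0'
#     }
#   }
#   api = SqlFlowApi(user, interpreter=interpreter)
#   response = api.execute(notebook={}, snippet={'statement': 'SELECT * FROM iris.train'})
# ---------------------------------------------------------------------------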