rdbms.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. #!/usr/bin/env python
  2. # Licensed to Cloudera, Inc. under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. Cloudera, Inc. licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import json
  18. import logging
  19. from desktop.lib import export_csvxls
  20. from desktop.lib.i18n import force_unicode
  21. from beeswax import data_export
  22. from librdbms.server import dbms
  23. from notebook.connectors.base import Api, QueryError, QueryExpired, _get_snippet_name
  24. LOG = logging.getLogger(__name__)
  25. def query_error_handler(func):
  26. def decorator(*args, **kwargs):
  27. try:
  28. return func(*args, **kwargs)
  29. except Exception, e:
  30. message = force_unicode(e)
  31. if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
  32. raise QueryExpired(e)
  33. else:
  34. raise QueryError(message)
  35. return decorator
  36. class RdbmsApi(Api):
  37. def _execute(self, notebook, snippet):
  38. query_server = dbms.get_query_server_config(server=self.interpreter)
  39. db = dbms.get(self.user, query_server)
  40. db.use(snippet['database']) # TODO: only do the use on the first statement in a multi query
  41. table = db.execute_statement(snippet['statement']) # TODO: execute statement stub in Rdbms
  42. return table
  43. @query_error_handler
  44. def execute(self, notebook, snippet):
  45. table = self._execute(notebook, snippet)
  46. data = list(table.rows())
  47. has_result_set = data is not None
  48. return {
  49. 'sync': True,
  50. 'has_result_set': has_result_set,
  51. 'modified_row_count': 0,
  52. 'result': {
  53. 'has_more': False,
  54. 'data': data if has_result_set else [],
  55. 'meta': [{
  56. 'name': col['name'] if type(col) is dict else col,
  57. 'type': col.get('type', '') if type(col) is dict else '',
  58. 'comment': ''
  59. } for col in table.columns_description] if has_result_set else [],
  60. 'type': 'table'
  61. }
  62. }
  63. @query_error_handler
  64. def check_status(self, notebook, snippet):
  65. return {'status': 'expired'}
  66. @query_error_handler
  67. def fetch_result(self, notebook, snippet, rows, start_over):
  68. return {
  69. 'has_more': False,
  70. 'data': [],
  71. 'meta': [],
  72. 'type': 'table'
  73. }
  74. @query_error_handler
  75. def fetch_result_metadata(self):
  76. pass
  77. @query_error_handler
  78. def cancel(self, notebook, snippet):
  79. return {'status': 0}
  80. @query_error_handler
  81. def get_log(self, notebook, snippet, startFrom=None, size=None):
  82. return 'No logs'
  83. @query_error_handler
  84. def download(self, notebook, snippet, format, user_agent=None):
  85. file_name = _get_snippet_name(notebook)
  86. results = self._execute(notebook, snippet)
  87. db = FixedResult(results)
  88. return data_export.download(None, format, db, id=snippet['id'], file_name=file_name, user_agent=user_agent)
  89. @query_error_handler
  90. def close_statement(self, snippet):
  91. return {'status': -1}
  92. @query_error_handler
  93. def autocomplete(self, snippet, database=None, table=None, column=None, nested=None):
  94. query_server = dbms.get_query_server_config(server=self.interpreter)
  95. db = dbms.get(self.user, query_server)
  96. assist = Assist(db)
  97. response = {'status': -1}
  98. if database is None:
  99. response['databases'] = assist.get_databases()
  100. elif table is None:
  101. tables_meta = []
  102. for t in assist.get_tables(database):
  103. tables_meta.append({'name': t, 'type': 'Table', 'comment': ''})
  104. response['tables_meta'] = tables_meta
  105. elif column is None:
  106. columns = assist.get_columns(database, table)
  107. response['columns'] = [col['name'] for col in columns]
  108. response['extended_columns'] = columns
  109. else:
  110. columns = assist.get_columns(database, table)
  111. response['name'] = next((col['name'] for col in columns if column == col['name']), '')
  112. response['type'] = next((col['type'] for col in columns if column == col['name']), '')
  113. response['status'] = 0
  114. return response
  115. @query_error_handler
  116. def get_sample_data(self, snippet, database=None, table=None, column=None, async=False):
  117. query_server = dbms.get_query_server_config(server=self.interpreter)
  118. db = dbms.get(self.user, query_server)
  119. assist = Assist(db)
  120. response = {'status': -1}
  121. sample_data = assist.get_sample_data(database, table, column)
  122. if sample_data:
  123. response['status'] = 0
  124. response['headers'] = sample_data.columns
  125. response['rows'] = list(sample_data.rows())
  126. else:
  127. response['message'] = _('Failed to get sample data.')
  128. return response
  129. @query_error_handler
  130. def get_browse_query(self, snippet, database, table, partition_spec=None):
  131. return "SELECT * FROM `%s`.`%s` LIMIT 1000" % (database, table)
  132. class Assist():
  133. def __init__(self, db):
  134. self.db = db
  135. def get_databases(self):
  136. return self.db.get_databases()
  137. def get_tables(self, database, table_names=[]):
  138. self.db.use(database)
  139. return self.db.get_tables(database, table_names)
  140. def get_columns(self, database, table):
  141. return self.db.get_columns(database, table, names_only=False)
  142. def get_sample_data(self, database, table, column=None):
  143. return self.db.get_sample_data(database, table, column)
  144. class FixedResult():
  145. def __init__(self, result):
  146. self.result = result
  147. self.has_more = False
  148. def fetch(self, handle=None, start_over=None, rows=None):
  149. return self.result