sqlalchemy.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. from sqlalchemy.engine import default
  6. from sqlalchemy.sql import compiler
  7. from sqlalchemy import types
  8. import pydruid.db
  9. from pydruid.db import exceptions
  10. RESERVED_SCHEMAS = ['INFORMATION_SCHEMA']
  11. type_map = {
  12. 'char': types.String,
  13. 'varchar': types.String,
  14. 'float': types.Float,
  15. 'decimal': types.Float,
  16. 'real': types.Float,
  17. 'double': types.Float,
  18. 'boolean': types.Boolean,
  19. 'tinyint': types.BigInteger,
  20. 'smallint': types.BigInteger,
  21. 'integer': types.BigInteger,
  22. 'bigint': types.BigInteger,
  23. 'timestamp': types.TIMESTAMP,
  24. 'date': types.DATE,
  25. 'other': types.BLOB,
  26. }
  27. class UniversalSet(object):
  28. def __contains__(self, item):
  29. return True
  30. class DruidIdentifierPreparer(compiler.IdentifierPreparer):
  31. reserved_words = UniversalSet()
  32. class DruidCompiler(compiler.SQLCompiler):
  33. pass
  34. class DruidTypeCompiler(compiler.GenericTypeCompiler):
  35. def visit_REAL(self, type_, **kwargs):
  36. return "DOUBLE"
  37. def visit_NUMERIC(self, type_, **kwargs):
  38. return "LONG"
  39. visit_DECIMAL = visit_NUMERIC
  40. visit_INTEGER = visit_NUMERIC
  41. visit_SMALLINT = visit_NUMERIC
  42. visit_BIGINT = visit_NUMERIC
  43. visit_BOOLEAN = visit_NUMERIC
  44. visit_TIMESTAMP = visit_NUMERIC
  45. visit_DATE = visit_NUMERIC
  46. def visit_CHAR(self, type_, **kwargs):
  47. return "STRING"
  48. visit_NCHAR = visit_CHAR
  49. visit_VARCHAR = visit_CHAR
  50. visit_NVARCHAR = visit_CHAR
  51. visit_TEXT = visit_CHAR
  52. def visit_DATETIME(self, type_, **kwargs):
  53. raise exceptions.NotSupportedError('Type DATETIME is not supported')
  54. def visit_TIME(self, type_, **kwargs):
  55. raise exceptions.NotSupportedError('Type TIME is not supported')
  56. def visit_BINARY(self, type_, **kwargs):
  57. raise exceptions.NotSupportedError('Type BINARY is not supported')
  58. def visit_VARBINARY(self, type_, **kwargs):
  59. raise exceptions.NotSupportedError('Type VARBINARY is not supported')
  60. def visit_BLOB(self, type_, **kwargs):
  61. raise exceptions.NotSupportedError('Type BLOB is not supported')
  62. def visit_CLOB(self, type_, **kwargs):
  63. raise exceptions.NotSupportedError('Type CBLOB is not supported')
  64. def visit_NCLOB(self, type_, **kwargs):
  65. raise exceptions.NotSupportedError('Type NCBLOB is not supported')
  66. class DruidDialect(default.DefaultDialect):
  67. name = 'druid'
  68. scheme = 'http'
  69. driver = 'rest'
  70. preparer = DruidIdentifierPreparer
  71. statement_compiler = DruidCompiler
  72. type_compiler = DruidTypeCompiler
  73. supports_alter = False
  74. supports_pk_autoincrement = False
  75. supports_default_values = False
  76. supports_empty_insert = False
  77. supports_unicode_statements = True
  78. supports_unicode_binds = True
  79. returns_unicode_strings = True
  80. description_encoding = None
  81. supports_native_boolean = True
  82. @classmethod
  83. def dbapi(cls):
  84. return pydruid.db
  85. def create_connect_args(self, url):
  86. kwargs = {
  87. 'host': url.host,
  88. 'port': url.port or 8082,
  89. 'path': url.database,
  90. 'scheme': self.scheme,
  91. }
  92. return ([], kwargs)
  93. def get_schema_names(self, connection, **kwargs):
  94. # Each Druid datasource appears as a table in the "druid" schema. This
  95. # is also the default schema, so Druid datasources can be referenced as
  96. # either druid.dataSourceName or simply dataSourceName.
  97. result = connection.execute(
  98. 'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA')
  99. return [
  100. row.SCHEMA_NAME for row in result
  101. if row.SCHEMA_NAME not in RESERVED_SCHEMAS
  102. ]
  103. def has_table(self, connection, table_name, schema=None):
  104. query = """
  105. SELECT COUNT(*) > 0 AS exists_
  106. FROM INFORMATION_SCHEMA.TABLES
  107. WHERE TABLE_NAME = '{table_name}'
  108. """.format(table_name=table_name)
  109. result = connection.execute(query)
  110. return result.fetchone().exists_
  111. def get_table_names(self, connection, schema=None, **kwargs):
  112. query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES"
  113. if schema:
  114. query = "{query} WHERE TABLE_SCHEMA = '{schema}'".format(
  115. query=query, schema=schema)
  116. result = connection.execute(query)
  117. return [row.TABLE_NAME for row in result]
  118. def get_view_names(self, connection, schema=None, **kwargs):
  119. return []
  120. def get_table_options(self, connection, table_name, schema=None, **kwargs):
  121. return {}
  122. def get_columns(self, connection, table_name, schema=None, **kwargs):
  123. query = """
  124. SELECT COLUMN_NAME,
  125. DATA_TYPE,
  126. IS_NULLABLE,
  127. COLUMN_DEFAULT
  128. FROM INFORMATION_SCHEMA.COLUMNS
  129. WHERE TABLE_NAME = '{table_name}'
  130. """.format(table_name=table_name)
  131. if schema:
  132. query = "{query} AND TABLE_SCHEMA = '{schema}'".format(
  133. query=query, schema=schema)
  134. result = connection.execute(query)
  135. return [
  136. {
  137. 'name': row.COLUMN_NAME,
  138. 'type': type_map[row.DATA_TYPE.lower()],
  139. 'nullable': get_is_nullable(row.IS_NULLABLE),
  140. 'default': get_default(row.COLUMN_DEFAULT),
  141. }
  142. for row in result
  143. ]
  144. def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
  145. return {'constrained_columns': [], 'name': None}
  146. def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
  147. return []
  148. def get_check_constraints(
  149. self,
  150. connection,
  151. table_name,
  152. schema=None,
  153. **kwargs
  154. ):
  155. return []
  156. def get_table_comment(self, connection, table_name, schema=None, **kwargs):
  157. return {'text': ''}
  158. def get_indexes(self, connection, table_name, schema=None, **kwargs):
  159. return []
  160. def get_unique_constraints(
  161. self,
  162. connection,
  163. table_name,
  164. schema=None,
  165. **kwargs
  166. ):
  167. return []
  168. def get_view_definition(
  169. self,
  170. connection,
  171. view_name,
  172. schema=None,
  173. **kwargs
  174. ):
  175. pass
  176. def do_rollback(self, dbapi_connection):
  177. pass
  178. def _check_unicode_returns(self, connection, additional_tests=None):
  179. return True
  180. def _check_unicode_description(self, connection):
  181. return True
  182. DruidHTTPDialect = DruidDialect
  183. class DruidHTTPSDialect(DruidDialect):
  184. scheme = 'https'
  185. def get_is_nullable(druid_is_nullable):
  186. # this should be 'YES' or 'NO'; we default to no
  187. return druid_is_nullable.lower() == 'yes'
  188. def get_default(druid_column_default):
  189. # currently unused, returns ''
  190. return str(druid_column_default) if druid_column_default != '' else None