mysql_lib.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #!/usr/bin/env python
  2. # Licensed to Cloudera, Inc. under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. Cloudera, Inc. licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import logging
  18. try:
  19. import MySQLdb as Database
  20. except ImportError, e:
  21. from django.core.exceptions import ImproperlyConfigured
  22. raise ImproperlyConfigured("Error loading MySQLdb module: %s" % e)
  23. # We want version (1, 2, 1, 'final', 2) or later. We can't just use
  24. # lexicographic ordering in this check because then (1, 2, 1, 'gamma')
  25. # inadvertently passes the version test.
  26. version = Database.version_info
  27. if (version < (1,2,1) or (version[:3] == (1, 2, 1) and
  28. (len(version) < 5 or version[3] != 'final' or version[4] < 2))):
  29. from django.core.exceptions import ImproperlyConfigured
  30. raise ImproperlyConfigured("MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__)
  31. from librdbms.server.rdbms_base_lib import BaseRDBMSDataTable, BaseRDBMSResult, BaseRDMSClient
  32. LOG = logging.getLogger(__name__)
  33. class DataTable(BaseRDBMSDataTable): pass
  34. class Result(BaseRDBMSResult): pass
  35. class MySQLClient(BaseRDMSClient):
  36. """Same API as Beeswax"""
  37. data_table_cls = DataTable
  38. result_cls = Result
  39. def __init__(self, *args, **kwargs):
  40. super(MySQLClient, self).__init__(*args, **kwargs)
  41. self.connection = Database.connect(**self._conn_params)
  42. @property
  43. def _conn_params(self):
  44. params = {
  45. 'user': self.query_server['username'],
  46. 'passwd': self.query_server['password'],
  47. 'host': self.query_server['server_host'],
  48. 'port': self.query_server['server_port']
  49. }
  50. if self.query_server['options']:
  51. params.update(self.query_server['options'])
  52. if 'name' in self.query_server:
  53. params['db'] = self.query_server['name']
  54. return params
  55. def use(self, database):
  56. if 'db' in self._conn_params and self._conn_params['db'] != database:
  57. raise RuntimeError("Tried to use database %s when %s was specified." % (database, self._conn_params['db']))
  58. else:
  59. cursor = self.connection.cursor()
  60. cursor.execute("USE %s" % database)
  61. self.connection.commit()
  62. def execute_statement(self, statement):
  63. cursor = self.connection.cursor()
  64. cursor.execute(statement)
  65. self.connection.commit()
  66. if cursor.description:
  67. columns = [column[0] for column in cursor.description]
  68. else:
  69. columns = []
  70. return self.data_table_cls(cursor, columns)
  71. def get_databases(self):
  72. cursor = self.connection.cursor()
  73. cursor.execute("SHOW DATABASES")
  74. self.connection.commit()
  75. return [row[0] for row in cursor.fetchall()]
  76. def get_tables(self, database, table_names=[]):
  77. cursor = self.connection.cursor()
  78. cursor.execute("SHOW TABLES")
  79. self.connection.commit()
  80. return [row[0] for row in cursor.fetchall()]
  81. def get_columns(self, database, table):
  82. cursor = self.connection.cursor()
  83. cursor.execute("SHOW COLUMNS FROM %s.%s" % (database, table))
  84. self.connection.commit()
  85. return [row[0] for row in cursor.fetchall()]