浏览代码

HUE-1764 [dbquery] use database options when creating a connection

all: update connection params.
     force keys and values from unicode to string.
postgresql: remove autocommit from params if it exists so
            that the dbquery app may handle it.
oracle: add connection params
sqlite: check_same_thread set to false to ensure connection is shareable.
Add test for sqlite to test options.
Abraham Elmahrek 12 年之前
父节点
当前提交
d7c1019

+ 22 - 2
apps/rdbms/src/rdbms/server/dbms.py

@@ -17,14 +17,34 @@
 
 
 import logging
 import logging
 
 
+from desktop.lib.i18n import smart_str
 from beeswax.models import QueryHistory
 from beeswax.models import QueryHistory
-
 from rdbms.conf import RDBMS
 from rdbms.conf import RDBMS
 
 
 
 
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
 
 
 
 
+def force_dict_to_strings(dictionary):
+  if not dictionary:
+    return dictionary
+
+  new_dict = {}
+  for k in dictionary:
+    new_key = smart_str(k)
+    if isinstance(dictionary[k], basestring):
+      # Strings should not be unicode.
+      new_dict[new_key] = smart_str(dictionary[k])
+    elif isinstance(dictionary[k], dict):
+      # Recursively force dicts to strings.
+      new_dict[new_key] = force_dict_to_strings(dictionary[k])
+    else:
+      # Normal objects, or other literals, should not be converted.
+      new_dict[new_key] = dictionary[k]
+
+  return new_dict
+
+
 def get(user, query_server=None):
 def get(user, query_server=None):
   if query_server is None:
   if query_server is None:
     query_server = get_query_server_config()
     query_server = get_query_server_config()
@@ -61,7 +81,7 @@ def get_query_server_config(server=None):
       'server_port': RDBMS[name].PORT.get(),
       'server_port': RDBMS[name].PORT.get(),
       'username': RDBMS[name].USER.get(),
       'username': RDBMS[name].USER.get(),
       'password': RDBMS[name].PASSWORD.get(),
       'password': RDBMS[name].PASSWORD.get(),
-      'password': RDBMS[name].PASSWORD.get(),
+      'options': force_dict_to_strings(RDBMS[name].OPTIONS.get()),
       'alias': name
       'alias': name
     }
     }
 
 

+ 3 - 0
apps/rdbms/src/rdbms/server/mysql_lib.py

@@ -64,6 +64,9 @@ class MySQLClient(BaseRDMSClient):
       'port': self.query_server['server_port']
       'port': self.query_server['server_port']
     }
     }
 
 
+    if self.query_server['options']:
+      params.update(self.query_server['options'])
+
     if 'name' in self.query_server:
     if 'name' in self.query_server:
       params['db'] = self.query_server['name']
       params['db'] = self.query_server['name']
 
 

+ 10 - 1
apps/rdbms/src/rdbms/server/oracle_lib.py

@@ -43,8 +43,17 @@ class OracleClient(BaseRDMSClient):
 
 
   def __init__(self, *args, **kwargs):
   def __init__(self, *args, **kwargs):
     super(OracleClient, self).__init__(*args, **kwargs)
     super(OracleClient, self).__init__(*args, **kwargs)
-    self.connection = Database.connect(self._conn_string, **{})
+    if self.__conn_params:
+      self.connection = Database.connect(self._conn_string, **self.__conn_params)
+    else:
+      self.connection = Database.connect(self._conn_string)
 
 
+  @property
+  def _conn_params(self):
+    if self.query_server['options']:
+      return self.query_server['options'].copy()
+    else:
+      return None
 
 
   @property
   @property
   def _conn_string(self):
   def _conn_string(self):

+ 7 - 0
apps/rdbms/src/rdbms/server/postgresql_lib.py

@@ -55,6 +55,13 @@ class PostgreSQLClient(BaseRDMSClient):
       'port': self.query_server['server_port'] == 0 and 5432 or self.query_server['server_port'],
       'port': self.query_server['server_port'] == 0 and 5432 or self.query_server['server_port'],
       'database': self.query_server['name']
       'database': self.query_server['name']
     }
     }
+
+    if self.query_server['options']:
+      params.update(self.query_server['options'])
+      # handle transaction commits manually.
+      if 'autocommit' in params:
+        del params['autocommit']
+
     return params
     return params
 
 
 
 

+ 9 - 1
apps/rdbms/src/rdbms/server/sqlite_lib.py

@@ -51,11 +51,19 @@ class SQLiteClient(BaseRDMSClient):
 
 
   @property
   @property
   def _conn_params(self):
   def _conn_params(self):
-    return {
+    params = {
       'database': self.query_server['name'],
       'database': self.query_server['name'],
       'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
       'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
     }
     }
 
 
+    if self.query_server['options']:
+      params.update(self.query_server['options'])
+
+    # Make sure connection is shareable.
+    params['check_same_thread'] = False
+
+    return params
+
 
 
   def use(self, database):
   def use(self, database):
     # Do nothing because SQLite has one database per path.
     # Do nothing because SQLite has one database per path.

+ 8 - 0
apps/rdbms/src/rdbms/tests.py

@@ -130,3 +130,11 @@ class TestAPI(TestSQLiteRdbmsBase):
     response = self.client.post(reverse('rdbms:api_execute_query'), data, follow=True)
     response = self.client.post(reverse('rdbms:api_execute_query'), data, follow=True)
     response_dict = json.loads(response.content)
     response_dict = json.loads(response.content)
     assert_equal(1, len(response_dict['results']['rows']), response_dict)
     assert_equal(1, len(response_dict['results']['rows']), response_dict)
+
+  def test_options(self):
+    finish = rdbms_conf.RDBMS['sqlitee'].OPTIONS.set_for_testing({'nonsensical': None})
+    try:
+      self.client.get(reverse('rdbms:api_tables', args=['sqlitee', self.database]))
+    except TypeError, e:
+      assert_true('nonsensical' in str(e), e)
+    finish()