Explorar o código

HUE-1297 [beeswax] Support other databases than default for parameterized queries

Romain Rigaux %!s(int64=12) %!d(string=hai) anos
pai
achega
f956e857df

+ 1 - 1
apps/beeswax/src/beeswax/server/dbms.py

@@ -81,7 +81,7 @@ def get_query_server_config(name='beeswax'):
         'server_interface': SERVER_INTERFACE.get(),
         'principal': kerberos_principal
     }
-    LOG.debug("Query Server:\n\tName: %(server_name)s\n\tHost: %(server_host)s\n\tPort: %(server_port)s\n\tInterface: %(server_interface)s\n\tKerberos Principal: %(principal)s" % query_server)
+    LOG.debug("Query Server: %s" % query_server)
 
   return query_server
 

+ 14 - 7
apps/beeswax/src/beeswax/tests.py

@@ -377,22 +377,29 @@ for x in sys.stdin:
     assert_true("parameterization.mako", response.template)
 
     # Now fill it out
-    response = self.client.post("/beeswax/execute_parameterized/%d" % design_id,
-      { "parameterization-x": str(1), "parameterization-y": str(2)}, follow=True)
-
+    response = self.client.post("/beeswax/execute_parameterized/%d" % design_id, {
+                                "parameterization-x": str(1), "parameterization-y": str(2)}, follow=True)
     assert_true("watch_wait.mako" in response.template)
+
     # Check that substitution happened!
-    assert_equal("SELECT foo FROM test WHERE foo='1' and bar='2'",
-      response.context["query"].query)
+    assert_equal("SELECT foo FROM test WHERE foo='1' and bar='2'", response.context["query"].query)
 
     # Check that error handling is reasonable
     response = self.client.post("/beeswax/execute_parameterized/%d" % design_id,
-      { "parameterization-x": "'_this_is_not SQL ", "parameterization-y": str(2) },
-      follow=True)
+                                {"parameterization-x": "'_this_is_not SQL ", "parameterization-y": str(2)},
+                                follow=True)
     assert_true("execute.mako" in response.template)
     log = response.context["log"]
     assert_true(search_log_line('ql.Driver', 'FAILED: ParseException', log), log)
 
+    # Check multi DB with a non default DB
+    response = _make_query(self.client, "SELECT foo FROM test WHERE foo='$x' and bar='$y'", database='other_db')
+    assert_true("parameterization.mako", response.template)
+    design_id = response.context["design"].id
+    response = self.client.post("/beeswax/execute_parameterized/%d" % design_id, {
+                                "parameterization-x": str(1), "parameterization-y": str(2)}, follow=True)
+    assert_equal('other_db', response.context['design'].get_design().query['database'])
+
   def test_explain_query(self):
     c = self.client
     response = _make_query(c, "SELECT KITTENS ARE TASTY", submission_type="Explain")

+ 14 - 3
apps/beeswax/src/beeswax/views.py

@@ -382,8 +382,7 @@ def execute_query(request, design_id=None):
 
   query_server = get_query_server_config(app_name)
   db = dbms.get(request.user, query_server)
-  dbs = db.get_databases()
-  databases = ((db, db) for db in dbs)
+  databases = _get_db_choices(request)
 
   if request.method == 'POST':
     form.bind(request.POST)
@@ -1073,8 +1072,13 @@ def _run_parameterized_query(request, design_id, explain):
   query_form = QueryForm()
   params = design_obj.get_query_dict()
   params.update(request.POST)
+
+  databases = _get_db_choices(request)
   query_form.bind(params)
-  assert query_form.is_valid()
+  query_form.query.fields['database'].choices = databases # Could not do it in the form
+
+  if not query_form.is_valid():
+    raise PopupException(_("Query form is invalid: %s") % query_form.errors)
 
   query_str = query_form.query.cleaned_data["query"]
   query_server = get_query_server_config(get_app_name(request))
@@ -1422,6 +1426,13 @@ def _update_query_state(query_history):
     query_history.save_state(state_enum)
   return True
 
+def _get_db_choices(request):
+  app_name = get_app_name(request)
+  query_server = get_query_server_config(app_name)
+  db = dbms.get(request.user, query_server)
+  dbs = db.get_databases()
+  return ((db, db) for db in dbs)
+
 WHITESPACE = re.compile("\s+", re.MULTILINE)
 def collapse_whitespace(s):
   return WHITESPACE.sub(" ", s).strip()