Quellcode durchsuchen

[spark] Switch to Impala API for Impala snippets

Romain Rigaux vor 11 Jahren
Ursprung
Commit
e6056f51f8
1 geänderte Dateien mit 16 neuen und 5 gelöschten Zeilen
  1. 16 5
      apps/spark/src/spark/models.py

+ 16 - 5
apps/spark/src/spark/models.py

@@ -23,6 +23,7 @@ from beeswax.design import hql_query
 from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, QueryHistory
 from beeswax.views import safe_get_design, save_design
 from beeswax.server import dbms
+from beeswax.server.dbms import get_query_server_config
 
 from spark.job_server_api import get_api as get_spark_api
 from desktop.lib.i18n import smart_str
@@ -91,6 +92,16 @@ class HS2Api():
     snippet['result']['handle']['secret'], snippet['result']['handle']['guid'] = HiveServerQueryHandle.get_decoded(snippet['result']['handle']['secret'], snippet['result']['handle']['guid'])
     return HiveServerQueryHandle(**snippet['result']['handle'])
     
+  def _get_db(self, snippet):
+    if snippet['type'] == 'hive':
+      name = 'beeswax'
+    elif snippet['type'] == 'impala':
+      name = 'impala'
+    else:
+      name = 'spark-sql'
+      
+    return dbms.get(self.user, query_server=get_query_server_config(name=name))
+    
   def create_session(self, lang):
     return {
         'type': lang,
@@ -98,7 +109,7 @@ class HS2Api():
     }
   
   def execute(self, notebook, snippet):
-    db = dbms.get(self.user)
+    db = self._get_db(snippet)
     query = hql_query(snippet['statement'], QUERY_TYPES[0])
     handle = db.client.query(query)
     
@@ -126,7 +137,7 @@ class HS2Api():
     }    
 
   def check_status(self, notebook, snippet):
-    db = dbms.get(self.user)
+    db = self._get_db(snippet)
       
     handle = self._get_handle(snippet)
     status =  db.get_state(handle)
@@ -141,7 +152,7 @@ class HS2Api():
     }
 
   def fetch_result(self, notebook, snippet, rows):
-    db = dbms.get(self.user)
+    db = self._get_db(snippet)
       
     handle = self._get_handle(snippet)
     results = db.fetch(handle, start_over=False, rows=rows)
@@ -160,14 +171,14 @@ class HS2Api():
     pass 
 
   def cancel(self, notebook, snippet):
-    db = dbms.get(self.user)
+    db = self._get_db(snippet)
       
     handle = self._get_handle(snippet)
     db.cancel_operation(handle)
     return {'status': 'canceled'}    
 
   def get_log(self, snippet):
-    db = dbms.get(self.user)
+    db = self._get_db(snippet)
       
     handle = self._get_handle(snippet)    
     return db.get_log(handle)