瀏覽代碼

[hplsql] modifying the get_statement according to hplsql mode

ayush.goyal 4 年之前
父節點
當前提交
e6f16397b2

+ 4 - 1
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -990,7 +990,10 @@ class HiveServerClient(object):
 
 
     # The query can override the default configuration
     # The query can override the default configuration
     configuration.update(self._get_query_configuration(query))
     configuration.update(self._get_query_configuration(query))
-    query_statement = query.get_query_statement(statement)
+    if HPLSQL.get() and self.query_server['server_name'] == 'beeswax':
+      query_statement = query.hql_query
+    else:
+      query_statement = query.get_query_statement(statement)
 
 
     return self.execute_async_statement(statement=query_statement, conf_overlay=configuration, session=session)
     return self.execute_async_statement(statement=query_statement, conf_overlay=configuration, session=session)
 
 

+ 21 - 1
desktop/libs/notebook/src/notebook/sql_utils.py

@@ -22,6 +22,7 @@ import os
 import re
 import re
 import sys
 import sys
 
 
+from beeswax.conf import HPLSQL
 from desktop.lib.i18n import smart_str
 from desktop.lib.i18n import smart_str
 
 
 if sys.version_info[0] > 2:
 if sys.version_info[0] > 2:
@@ -50,6 +51,22 @@ def get_statements(hql_query):
     })
     })
   return statements
   return statements
 
 
+def get_hplsql_statements(hplsql_query):
+  statements = []
+  statements.append(
+    {
+      'start': {
+        'row': 0,
+        'column': 0
+      },
+      'end': {
+        'row': 0,
+        'column': len(hplsql_query) - 1
+      },
+      'statement': strip_trailing_semicolon(hplsql_query.rstrip())
+    })
+  return statements
+
 def get_current_statement(snippet):
 def get_current_statement(snippet):
   # Multiquery, if not first statement or arrived to the last query
   # Multiquery, if not first statement or arrived to the last query
   should_close = False
   should_close = False
@@ -57,7 +74,10 @@ def get_current_statement(snippet):
   statement_id = handle.get('statement_id', 0)
   statement_id = handle.get('statement_id', 0)
   statements_count = handle.get('statements_count', 1)
   statements_count = handle.get('statements_count', 1)
 
 
-  statements = get_statements(snippet['statement'])
+  if HPLSQL.get() and snippet['dialect'] == 'hive':
+    statements = get_hplsql_statements(snippet['statement'])
+  else:
+    statements = get_statements(snippet['statement'])
 
 
   statement_id = min(statement_id, len(statements) - 1) # In case of removal of statements
   statement_id = min(statement_id, len(statements) - 1) # In case of removal of statements
   previous_statement_hash = compute_statement_hash(statements[statement_id]['statement'])
   previous_statement_hash = compute_statement_hash(statements[statement_id]['statement'])

+ 16 - 3
desktop/libs/notebook/src/notebook/sql_utils_tests.py

@@ -17,7 +17,7 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from beeswax.design import hql_query
 from beeswax.design import hql_query
-from notebook.sql_utils import strip_trailing_semicolon, split_statements
+from notebook.sql_utils import strip_trailing_semicolon, split_statements, get_hplsql_statements
 
 
 from nose.tools import assert_equal
 from nose.tools import assert_equal
 
 
@@ -27,8 +27,14 @@ def test_split_statements():
   assert_equal(["select * where id == '10'"], hql_query("select * where id == '10'").statements)
   assert_equal(["select * where id == '10'"], hql_query("select * where id == '10'").statements)
   assert_equal(["select * where id == '10'"], hql_query("select * where id == '10';").statements)
   assert_equal(["select * where id == '10'"], hql_query("select * where id == '10';").statements)
   assert_equal(['select', "select * where id == '10;' limit 100"], hql_query("select; select * where id == '10;' limit 100;").statements)
   assert_equal(['select', "select * where id == '10;' limit 100"], hql_query("select; select * where id == '10;' limit 100;").statements)
-  assert_equal(['select', "select * where id == \"10;\" limit 100"], hql_query("select; select * where id == \"10;\" limit 100;").statements)
-  assert_equal(['select', "select * where id == '\"10;\"\"\"' limit 100"], hql_query("select; select * where id == '\"10;\"\"\"' limit 100;").statements)
+  assert_equal(
+    ['select', "select * where id == \"10;\" limit 100"],
+    hql_query("select; select * where id == \"10;\" limit 100;").statements
+  )
+  assert_equal(
+    ['select', "select * where id == '\"10;\"\"\"' limit 100"],
+    hql_query("select; select * where id == '\"10;\"\"\"' limit 100;").statements
+  )
 
 
 
 
 def teststrip_trailing_semicolon():
 def teststrip_trailing_semicolon():
@@ -42,3 +48,10 @@ def teststrip_trailing_semicolon():
   assert_equal("fo;o;", strip_trailing_semicolon("fo;o;;     "))
   assert_equal("fo;o;", strip_trailing_semicolon("fo;o;;     "))
   # No semicolons
   # No semicolons
   assert_equal("foo", strip_trailing_semicolon("foo"))
   assert_equal("foo", strip_trailing_semicolon("foo"))
+
+def test_get_hplsql_statements():
+  # Not spliting statements at semicolon
+  assert_equal(
+    "CREATE FUNCTION hello()\n RETURNS STRING\nBEGIN\n RETURN 'Hello, world';\nEND",
+    get_hplsql_statements("CREATE FUNCTION hello()\n RETURNS STRING\nBEGIN\n RETURN 'Hello, world';\nEND;")[0]['statement']
+  )