浏览代码

HUE-8747 [editor] Support task multi statement execution

jdesjean 6 年之前
父节点
当前提交
d24364c600

+ 1 - 79
apps/beeswax/src/beeswax/design.py

@@ -30,6 +30,7 @@ from django import forms
 from django.forms import ValidationError
 from django.utils.translation import ugettext as _
 
+from notebook.sql_utils import split_statements, strip_trailing_semicolon
 from desktop.lib.django_forms import BaseSimpleFormSet, MultiForm
 from hadoop.cluster import get_hdfs
 
@@ -232,74 +233,6 @@ class HQLdesign(object):
     return not self.__eq__(other)
 
 
-# Note: Might be replaceable by sqlparse.split
-def split_statements(hql):
-  """
-  Split statements at semicolons ignoring the ones inside quotes and comments.
-  The comment symbols that come inside quotes should be ignored.
-  """
-  statements = []
-  current = ''
-  prev = ''
-  between_quotes = None
-  is_comment = None
-  start_row = 0
-  start_col = 0
-  end_row = 0
-  end_col = len(hql) - 1
-
-  if hql.find(';') in (-1, len(hql) - 1):
-    return [((start_row, start_col), (end_row, end_col), hql)]
-
-  lines = hql.splitlines()
-
-  for row, line in enumerate(lines):
-    end_col = 0
-    end_row = row
-
-    if start_row == row and line.strip() == '':  # ignore leading whitespace rows
-      start_row += 1
-    elif current.strip() == '':  # reset start_row
-      start_row = row
-      start_col = 0
-
-    for col, c in enumerate(line):
-      current += c
-
-      if c in ('"', "'") and prev != '\\' and is_comment is None:
-        if between_quotes == c:
-          between_quotes = None
-        elif between_quotes is None:
-          between_quotes = c
-      elif c == '-' and prev == '-' and between_quotes is None and is_comment is None:
-        is_comment = True
-      elif c == ';':
-        if between_quotes is None and is_comment is None:
-          current = current.strip()
-          # Strip off the trailing semicolon
-          current = current[:-1]
-          if len(current) > 1:
-            statements.append(((start_row, start_col), (row, col + 1), current))
-            start_col = col + 1
-          current = ''
-      # This character holds no significance if it was escaped within a string
-      if prev == '\\' and between_quotes is not None:
-        c = ''
-      prev = c
-      end_col = col
-
-    is_comment = None
-    prev = os.linesep
-
-    if current != '':
-      current += os.linesep
-
-  if current and current != ';':
-    current = current.strip()
-    statements.append(((start_row, start_col), (end_row, end_col+1), current))
-
-  return statements
-
 def normalize_form_dict(form, attr_list):
   """
   normalize_form_dict(form, attr_list) -> A dictionary of (attr, value)
@@ -355,14 +288,3 @@ def denormalize_formset_dict(data_dict_list, formset, attr_list):
 
   def __str__(self):
     return '%s: %s' % (self.__class__, self.query)
-
-
-_SEMICOLON_WHITESPACE = re.compile(";\s*$")
-
-def strip_trailing_semicolon(query):
-  """As a convenience, we remove trailing semicolons from queries."""
-  s = _SEMICOLON_WHITESPACE.split(query, 2)
-  if len(s) > 1:
-    assert len(s) == 2
-    assert s[1] == ''
-  return s[0]

+ 1 - 22
apps/beeswax/src/beeswax/tests.py

@@ -74,7 +74,7 @@ from beeswax.conf import HIVE_SERVER_HOST, AUTH_USERNAME, AUTH_PASSWORD, AUTH_PA
 from beeswax.views import collapse_whitespace, _save_design, parse_out_jobs
 from beeswax.test_base import make_query, wait_for_query_to_finish, verify_history, get_query_server_config,\
   fetch_query_result_data
-from beeswax.design import hql_query, strip_trailing_semicolon
+from beeswax.design import hql_query
 from beeswax.data_export import upload, download
 from beeswax.models import SavedQuery, QueryHistory, HQL, HIVE_SERVER2
 from beeswax.server import dbms
@@ -2184,19 +2184,6 @@ def test_history_page():
   do_view('q-user=:all')
 
 
-def teststrip_trailing_semicolon():
-  # Note that there are two queries (both an execute and an explain) scattered
-  # in this file that use semicolons all the way through.
-
-  # Single semicolon
-  assert_equal("foo", strip_trailing_semicolon("foo;\n"))
-  assert_equal("foo\n", strip_trailing_semicolon("foo\n;\n\n\n"))
-  # Multiple semicolons: strip only last one
-  assert_equal("fo;o;", strip_trailing_semicolon("fo;o;;     "))
-  # No semicolons
-  assert_equal("foo", strip_trailing_semicolon("foo"))
-
-
 def test_hadoop_extraction():
   sample_log = """
 Starting Job = job_201003191517_0002, Tracking URL = http://localhost:50030/jobdetails.jsp?jobid=job_201003191517_0002
@@ -2349,14 +2336,6 @@ def test_search_log_line():
   assert_false(search_log_line('FAILED: Parse Error', logs))
 
 
-def test_split_statements():
-  assert_equal([''], hql_query(";;;").statements)
-  assert_equal(["select * where id == '10'"], hql_query("select * where id == '10'").statements)
-  assert_equal(["select * where id == '10'"], hql_query("select * where id == '10';").statements)
-  assert_equal(['select', "select * where id == '10;' limit 100"], hql_query("select; select * where id == '10;' limit 100;").statements)
-  assert_equal(['select', "select * where id == \"10;\" limit 100"], hql_query("select; select * where id == \"10;\" limit 100;").statements)
-  assert_equal(['select', "select * where id == '\"10;\"\"\"' limit 100"], hql_query("select; select * where id == '\"10;\"\"\"' limit 100;").statements)
-
   query_with_comments = """--First query;
 select concat('--', name)  -- The '--' in quotes is not a comment
 where id = '10';

+ 2 - 3
desktop/libs/librdbms/src/librdbms/design.py

@@ -24,9 +24,8 @@ import logging
 import django.http
 from django.utils.translation import ugettext as _
 
-from beeswax.design import normalize_form_dict, denormalize_form_dict, strip_trailing_semicolon,\
-                           split_statements
-
+from beeswax.design import normalize_form_dict, denormalize_form_dict, split_statements
+from notebook.sql_utils import strip_trailing_semicolon
 
 LOG = logging.getLogger(__name__)
 

+ 12 - 2
desktop/libs/notebook/src/notebook/connectors/base.py

@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
 import logging
 import re
@@ -29,6 +28,7 @@ from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.i18n import smart_unicode
 
 from notebook.conf import get_ordered_interpreters, CONNECTORS
+from notebook.sql_utils import get_current_statement
 
 
 LOG = logging.getLogger(__name__)
@@ -530,6 +530,16 @@ class Api(object):
   def describe_database(self, notebook, snippet, database=None):
     return {}
 
+  def _get_current_statement(self, notebook, snippet):
+    should_close, resp = get_current_statement(snippet)
+    if should_close:
+      try:
+        self.close_statement(notebook, snippet)  # Close all the time past multi queries
+      except:
+        LOG.warn('Could not close previous multiquery query')
+
+    return resp
+
 def _get_snippet_name(notebook, unique=False, table_format=False):
   name = (('%(name)s' + ('-%(id)s' if unique else '') if notebook.get('name') else '%(type)s-%(id)s') % notebook)
   if table_format:
@@ -546,7 +556,7 @@ class ResultWrapper():
 
   def fetch(self, start_over=None, rows=None):
     if start_over:
-      if not self.snippet['result']['handle'] or not self.api.can_start_over(self.notebook, self.snippet):
+      if not self.snippet['result']['handle'] or not self.snippet['result']['handle'].get('guid') or not self.api.can_start_over(self.notebook, self.snippet):
         start_over = False
         handle = self.api.execute(self.notebook, self.snippet)
         self.snippet['result']['handle'] = handle

+ 10 - 84
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -15,15 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import base64
 import binascii
 import copy
-import hashlib
 import logging
 import json
 import re
-import StringIO
-import struct
 import urllib
 
 from django.urls import reverse
@@ -33,7 +29,7 @@ from desktop.conf import USE_DEFAULT_CONFIGURATION
 from desktop.lib.conf import BoundConfig
 from desktop.lib.exceptions import StructuredException
 from desktop.lib.exceptions_renderable import PopupException
-from desktop.lib.i18n import force_unicode, smart_str
+from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import RestException
 from desktop.lib.thrift_util import unpack_guid, unpack_guid_base64
 from desktop.models import DefaultConfiguration, Document2
@@ -51,7 +47,7 @@ try:
   from beeswax.api import _autocomplete, _get_sample_data
   from beeswax.conf import CONFIG_WHITELIST as hive_settings, DOWNLOAD_ROW_LIMIT, DOWNLOAD_BYTES_LIMIT
   from beeswax.data_export import upload
-  from beeswax.design import hql_query, strip_trailing_semicolon, split_statements
+  from beeswax.design import hql_query
   from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, HiveServerQueryHistory, QueryHistory, Session
   from beeswax.server import dbms
   from beeswax.server.dbms import get_query_server_config, QueryServerException
@@ -244,7 +240,7 @@ class HS2Api(Api):
   def execute(self, notebook, snippet):
     db = self._get_db(snippet, cluster=self.cluster)
 
-    statement = self._get_current_statement(db, snippet)
+    statement = self._get_current_statement(notebook, snippet)
     session = self._get_session(notebook, snippet['type'])
 
     query = self._prepare_hql_query(snippet, statement['statement'], session)
@@ -452,7 +448,7 @@ class HS2Api(Api):
       document.can_read_or_exception(self.user)
       notebook = Notebook(document=document).get_data()
       snippet = notebook['snippets'][0]
-      query = self._get_current_statement(db, snippet)['statement']
+      query = self._get_current_statement(notebook, snippet)['statement']
       database, table = '', ''
 
     return _autocomplete(db, database, table, column, nested, query=query, cluster=self.cluster)
@@ -470,7 +466,7 @@ class HS2Api(Api):
   @query_error_handler
   def explain(self, notebook, snippet):
     db = self._get_db(snippet, cluster=self.cluster)
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     session = self._get_session(notebook, snippet['type'])
 
     query = self._prepare_hql_query(snippet, response.pop('statement'), session)
@@ -505,7 +501,7 @@ class HS2Api(Api):
   def export_data_as_table(self, notebook, snippet, destination, is_temporary=False, location=None):
     db = self._get_db(snippet, cluster=self.cluster)
 
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     session = self._get_session(notebook, snippet['type'])
     query = self._prepare_hql_query(snippet, response.pop('statement'), session)
 
@@ -527,9 +523,7 @@ class HS2Api(Api):
 
 
   def export_large_data_to_hdfs(self, notebook, snippet, destination):
-    db = self._get_db(snippet, cluster=self.cluster)
-
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     session = self._get_session(notebook, snippet['type'])
     query = self._prepare_hql_query(snippet, response.pop('statement'), session)
 
@@ -561,9 +555,7 @@ DROP TABLE IF EXISTS `%(table)s`;
 
 
   def statement_risk(self, notebook, snippet):
-    db = self._get_db(snippet, cluster=self.cluster)
-
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     query = response['statement']
 
     api = OptimizerApi(self.user)
@@ -572,9 +564,7 @@ DROP TABLE IF EXISTS `%(table)s`;
 
 
   def statement_compatibility(self, notebook, snippet, source_platform, target_platform):
-    db = self._get_db(snippet, cluster=self.cluster)
-
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     query = response['statement']
 
     api = OptimizerApi(self.user)
@@ -583,9 +573,7 @@ DROP TABLE IF EXISTS `%(table)s`;
 
 
   def statement_similarity(self, notebook, snippet, source_platform):
-    db = self._get_db(snippet, cluster=self.cluster)
-
-    response = self._get_current_statement(db, snippet)
+    response = self._get_current_statement(notebook, snippet)
     query = response['statement']
 
     api = OptimizerApi(self.user)
@@ -645,68 +633,6 @@ DROP TABLE IF EXISTS `%(table)s`;
     return engine
 
 
-  def _get_statements(self, hql_query):
-    hql_query = strip_trailing_semicolon(hql_query)
-    hql_query_sio = StringIO.StringIO(hql_query)
-
-    statements = []
-    for (start_row, start_col), (end_row, end_col), statement in split_statements(hql_query_sio.read()):
-      statements.append({
-        'start': {
-          'row': start_row,
-          'column': start_col
-        },
-        'end': {
-          'row': end_row,
-          'column': end_col
-        },
-        'statement': strip_trailing_semicolon(statement.rstrip())
-      })
-    return statements
-
-
-  def _get_current_statement(self, db, snippet):
-    # Multiquery, if not first statement or arrived to the last query
-    statement_id = snippet['result']['handle'].get('statement_id', 0)
-    statements_count = snippet['result']['handle'].get('statements_count', 1)
-
-    statements = self._get_statements(snippet['statement'])
-
-    statement_id = min(statement_id, len(statements) - 1) # In case of removal of statements
-    previous_statement_hash = self.__compute_statement_hash(statements[statement_id]['statement'])
-    non_edited_statement = previous_statement_hash == snippet['result']['handle'].get('previous_statement_hash') or not snippet['result']['handle'].get('previous_statement_hash')
-
-    if snippet['result']['handle'].get('has_more_statements'):
-      try:
-        handle = self._get_handle(snippet)
-        db.close_operation(handle)  # Close all the time past multi queries
-      except:
-        LOG.warn('Could not close previous multiquery query')
-
-      if non_edited_statement:
-        statement_id += 1
-    else:
-      if non_edited_statement:
-        statement_id = 0
-
-    if statements_count != len(statements):
-      statement_id = min(statement_id, len(statements) - 1)
-
-    resp = {
-      'statement_id': statement_id,
-      'has_more_statements': statement_id < len(statements) - 1,
-      'statements_count': len(statements),
-      'previous_statement_hash': self.__compute_statement_hash(statements[statement_id]['statement'])
-    }
-
-    resp.update(statements[statement_id])
-    return resp
-
-
-  def __compute_statement_hash(self, statement):
-    return hashlib.sha224(smart_str(statement)).hexdigest()
-
-
   def _prepare_hql_query(self, snippet, statement, session):
     settings = snippet['properties'].get('settings', None)
     file_resources = snippet['properties'].get('files', None)

+ 155 - 0
desktop/libs/notebook/src/notebook/sql_utils.py

@@ -0,0 +1,155 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import hashlib
+import os
+import re
+import StringIO
+
+from desktop.lib.i18n import smart_str
+
+# Note: Might be replaceable by sqlparse.split
+def get_statements(hql_query):
+  hql_query = strip_trailing_semicolon(hql_query)
+  hql_query_sio = StringIO.StringIO(hql_query)
+
+  statements = []
+  for (start_row, start_col), (end_row, end_col), statement in split_statements(hql_query_sio.read()):
+    statements.append({
+      'start': {
+        'row': start_row,
+        'column': start_col
+      },
+      'end': {
+        'row': end_row,
+        'column': end_col
+      },
+      'statement': strip_trailing_semicolon(statement.rstrip())
+    })
+  return statements
+
+def get_current_statement(snippet):
+  # Multiquery, if not first statement or arrived to the last query
+  statement_id = snippet['result']['handle'].get('statement_id', 0)
+  statements_count = snippet['result']['handle'].get('statements_count', 1)
+
+  statements = get_statements(snippet['statement'])
+
+  statement_id = min(statement_id, len(statements) - 1) # In case of removal of statements
+  previous_statement_hash = compute_statement_hash(statements[statement_id]['statement'])
+  non_edited_statement = previous_statement_hash == snippet['result']['handle'].get('previous_statement_hash') or not snippet['result']['handle'].get('previous_statement_hash')
+  should_close = False
+  if snippet['result']['handle'].get('has_more_statements'):
+    should_close = True
+    if non_edited_statement:
+      statement_id += 1
+  else:
+    if non_edited_statement:
+      statement_id = 0
+
+  if statements_count != len(statements):
+    statement_id = min(statement_id, len(statements) - 1)
+
+  resp = {
+    'statement_id': statement_id,
+    'has_more_statements': statement_id < len(statements) - 1,
+    'statements_count': len(statements),
+    'previous_statement_hash': compute_statement_hash(statements[statement_id]['statement'])
+  }
+
+  resp.update(statements[statement_id])
+  return should_close, resp
+
+
+def compute_statement_hash(statement):
+  return hashlib.sha224(smart_str(statement)).hexdigest()
+
+def split_statements(hql):
+  """
+  Split statements at semicolons ignoring the ones inside quotes and comments.
+  The comment symbols that come inside quotes should be ignored.
+  """
+  statements = []
+  current = ''
+  prev = ''
+  between_quotes = None
+  is_comment = None
+  start_row = 0
+  start_col = 0
+  end_row = 0
+  end_col = len(hql) - 1
+
+  if hql.find(';') in (-1, len(hql) - 1):
+    return [((start_row, start_col), (end_row, end_col), hql)]
+
+  lines = hql.splitlines()
+
+  for row, line in enumerate(lines):
+    end_col = 0
+    end_row = row
+
+    if start_row == row and line.strip() == '':  # ignore leading whitespace rows
+      start_row += 1
+    elif current.strip() == '':  # reset start_row
+      start_row = row
+      start_col = 0
+
+    for col, c in enumerate(line):
+      current += c
+
+      if c in ('"', "'") and prev != '\\' and is_comment is None:
+        if between_quotes == c:
+          between_quotes = None
+        elif between_quotes is None:
+          between_quotes = c
+      elif c == '-' and prev == '-' and between_quotes is None and is_comment is None:
+        is_comment = True
+      elif c == ';':
+        if between_quotes is None and is_comment is None:
+          current = current.strip()
+          # Strip off the trailing semicolon
+          current = current[:-1]
+          if len(current) > 1:
+            statements.append(((start_row, start_col), (row, col + 1), current))
+            start_col = col + 1
+          current = ''
+      # This character holds no significance if it was escaped within a string
+      if prev == '\\' and between_quotes is not None:
+        c = ''
+      prev = c
+      end_col = col
+
+    is_comment = None
+    prev = os.linesep
+
+    if current != '':
+      current += os.linesep
+
+  if current and current != ';':
+    current = current.strip()
+    statements.append(((start_row, start_col), (end_row, end_col+1), current))
+
+  return statements
+
+_SEMICOLON_WHITESPACE = re.compile(";\s*$")
+
+def strip_trailing_semicolon(query):
+  """As a convenience, we remove trailing semicolons from queries."""
+  s = _SEMICOLON_WHITESPACE.split(query, 2)
+  if len(s) > 1:
+    assert len(s) == 2
+    assert s[1] == ''
+  return s[0]

+ 41 - 0
desktop/libs/notebook/src/notebook/sql_utils_test.py

@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from notebook.sql_utils import strip_trailing_semicolon
+
+from nose.tools import assert_equal
+
+def test_split_statements():
+  assert_equal([''], ";;;")
+  assert_equal(["select * where id == '10'"], "select * where id == '10'")
+  assert_equal(["select * where id == '10'"], "select * where id == '10';")
+  assert_equal(['select', "select * where id == '10;' limit 100"], "select; select * where id == '10;' limit 100;")
+  assert_equal(['select', "select * where id == \"10;\" limit 100"], "select; select * where id == \"10;\" limit 100;")
+  assert_equal(['select', "select * where id == '\"10;\"\"\"' limit 100"], "select; select * where id == '\"10;\"\"\"' limit 100;")
+
+def teststrip_trailing_semicolon():
+  # Note that there are two queries (both an execute and an explain) scattered
+  # in this file that use semicolons all the way through.
+
+  # Single semicolon
+  assert_equal("foo", strip_trailing_semicolon("foo;\n"))
+  assert_equal("foo\n", strip_trailing_semicolon("foo\n;\n\n\n"))
+  # Multiple semicolons: strip only last one
+  assert_equal("fo;o;", strip_trailing_semicolon("fo;o;;     "))
+  # No semicolons
+  assert_equal("foo", strip_trailing_semicolon("foo"))

+ 9 - 2
desktop/libs/notebook/src/notebook/tasks.py

@@ -36,6 +36,7 @@ from desktop.conf import TASK_SERVER
 from desktop.lib import export_csvxls
 
 from notebook.connectors.base import get_api, QueryExpired, ResultWrapper
+from notebook.sql_utils import get_current_statement
 
 LOG_TASK = get_task_logger(__name__)
 LOG = logging.getLogger(__name__)
@@ -142,10 +143,15 @@ def _patch_status(notebook):
 
 def execute(*args, **kwargs):
   notebook = args[0]
+  snippet = args[1]
   kwargs['max_rows'] = TASK_SERVER.PREFETCH_RESULT_COUNT.get()
   _patch_status(notebook)
   download_to_file.apply_async(args=args, kwargs=kwargs, task_id=notebook['uuid'])
-  return {'sync': False,
+
+  should_close, resp = get_current_statement(snippet) # This redoes some of the work in api.execute. Other option is to pass statement, but then we'd have to modify notebook.api.
+  #if should_close: #front end already calls close_statement for multi statement execution no need to do here. In addition, we'd have to figure out what was the previous guid.
+
+  resp.update({'sync': False,
       'has_result_set': True,
       'modified_row_count': 0,
       'guid': '',
@@ -154,7 +160,8 @@ def execute(*args, **kwargs):
         'data': [],
         'meta': [],
         'type': 'table'
-      }}
+      }})
+  return resp
 
 def check_status(*args, **kwargs):
   notebook = args[0]