Quellcode durchsuchen

[impala] Move specific nested column logic to impala lib

Romain Rigaux vor 10 Jahren
Ursprung
Commit
7bab221

+ 1 - 1
apps/beeswax/src/beeswax/api.py

@@ -120,7 +120,7 @@ def _autocomplete(db, database=None, table=None, column=None, nested=None):
           table_obj = db.get_table(database, table)
           table_obj = db.get_table(database, table)
           sample = db.get_sample(database, table_obj, column, nested)
           sample = db.get_sample(database, table_obj, column, nested)
           if sample:
           if sample:
-            sample = set([row[0] for row in db.get_sample(database, table_obj, column, nested).rows()])
+            sample = set([row[0] for row in sample.rows()])
             response['sample'] = sorted(list(sample))
             response['sample'] = sorted(list(sample))
       else:
       else:
         raise Exception('Could not find column `%s`.`%s`.`%s`' % (database, table, column))
         raise Exception('Could not find column `%s`.`%s`.`%s`' % (database, table, column))

+ 11 - 68
apps/beeswax/src/beeswax/server/dbms.py

@@ -36,13 +36,17 @@ from beeswax.design import hql_query
 from beeswax.hive_site import hiveserver2_use_ssl
 from beeswax.hive_site import hiveserver2_use_ssl
 from beeswax.models import QueryHistory, QUERY_TYPES
 from beeswax.models import QueryHistory, QUERY_TYPES
 
 
-
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
 
 
+try:
+  from impala.dbms import ImpalaDbms 
+except ImportError, e:
+  LOG.info('Impala app enabled: %s' % e)
+
+
 DBMS_CACHE = {}
 DBMS_CACHE = {}
 DBMS_CACHE_LOCK = threading.Lock()
 DBMS_CACHE_LOCK = threading.Lock()
 
 
-
 def get(user, query_server=None):
 def get(user, query_server=None):
   global DBMS_CACHE
   global DBMS_CACHE
   global DBMS_CACHE_LOCK
   global DBMS_CACHE_LOCK
@@ -149,40 +153,6 @@ class HiveServer2Dbms(object):
       cleaned = "*%s*" % identifier.strip().strip("*")
       cleaned = "*%s*" % identifier.strip().strip("*")
     return cleaned
     return cleaned
 
 
-
-  @classmethod
-  def get_impala_nested_select(cls, database, table, column, nested=None):
-    """
-    Given a column or nested type, return the corresponding SELECT and FROM clauses in Impala's nested-type syntax
-    """
-    select_tokens = [column]
-    from_tokens = [database, table]
-
-    if nested:
-      nested_tokens = nested.strip('/').split('/')
-      while nested_tokens:
-        token = nested_tokens.pop(0)
-        if token not in ['key', 'value', 'item']:
-          select_tokens.append(token)
-        else:
-          # if we encounter a reserved keyword, move current select_tokens to from_tokens and reset the select_tokens
-          from_tokens.extend(select_tokens)
-          select_tokens = []
-          # if reserved keyword is the last token, make it the only select_token, otherwise we ignore and continue
-          if not nested_tokens:
-            select_tokens = [token]
-
-    select_clause = '.'.join(select_tokens)
-    from_clause = '.'.join('`%s`' % token for token in from_tokens)
-    return select_clause, from_clause
-
-
-  @classmethod
-  def get_histogram_query(cls, database, table, column, nested=None):
-    select_clause, from_clause = cls.get_impala_nested_select(database, table, column, nested)
-    return 'SELECT histogram(%s) FROM %s' % (select_clause, from_clause)
-
-
   def get_databases(self, database_names='*'):
   def get_databases(self, database_names='*'):
     identifier = self.to_matching_wildcard(database_names)
     identifier = self.to_matching_wildcard(database_names)
 
 
@@ -247,34 +217,6 @@ class HiveServer2Dbms(object):
         return col
         return col
     return None
     return None
 
 
-
-  def get_histogram(self, database, table, column, nested=None):
-    """
-    Returns the results of an Impala SELECT histogram() FROM query for a given column or nested type.
-
-    Assumes that the column/nested type is scalar.
-    """
-    results = []
-
-    if self.server_name == 'impala':  # Currently histogram() is only supported by Impala
-      hql = self.get_histogram_query(database, table, column, nested)
-      query = hql_query(hql)
-      handle = self.execute_and_wait(query, timeout_sec=5.0)
-
-      if handle:
-        result = self.fetch(handle)
-        try:
-          histogram = list(result.rows())[0][0]  # actual histogram results is in first-and-only result row
-          unique_values = set(histogram.split(', '))
-          results = list(unique_values)
-        except IndexError, e:
-          LOG.warn('Failed to get histogram results, result set has unexpected format: %s' % smart_str(e))
-        finally:
-          self.close(handle)
-
-    return results
-
-
   def execute_query(self, query, design):
   def execute_query(self, query, design):
     return self.execute_and_watch(query, design=design)
     return self.execute_and_watch(query, design=design)
 
 
@@ -320,16 +262,17 @@ class HiveServer2Dbms(object):
 
 
   def get_sample(self, database, table, column=None, nested=None):
   def get_sample(self, database, table, column=None, nested=None):
     result = None
     result = None
+    hql = None
 
 
-    # No samples if it's a view (HUE-526)
     if not table.is_view:
     if not table.is_view:
 
 
       limit = min(100, BROWSE_PARTITIONED_TABLE_LIMIT.get())
       limit = min(100, BROWSE_PARTITIONED_TABLE_LIMIT.get())
 
 
-      if (column or nested) and self.server_name == 'impala':  # SELECT column or nested type
-          select_clause, from_clause = self.get_impala_nested_select(database, table.name, column, nested)
+      if column or nested: # Could do column for any type, then nested with partitions 
+        if self.server_name == 'impala':
+          select_clause, from_clause = ImpalaDbms.get_nested_select(database, table.name, column, nested)
           hql = 'SELECT %s FROM %s LIMIT %s' % (select_clause, from_clause, limit)
           hql = 'SELECT %s FROM %s LIMIT %s' % (select_clause, from_clause, limit)
-      elif not column and not nested:  # SELECT * FROM table
+      else:
         partition_query = ""
         partition_query = ""
         if table.partition_keys:
         if table.partition_keys:
           partitions = self.get_partitions(database, table, partition_spec=None, max_parts=1)
           partitions = self.get_partitions(database, table, partition_spec=None, max_parts=1)

+ 0 - 16
apps/beeswax/src/beeswax/tests.py

@@ -3030,19 +3030,3 @@ def test_apply_natural_sort():
                                                             {'name': 'test_2', 'comment': 'Test'},
                                                             {'name': 'test_2', 'comment': 'Test'},
                                                             {'name': 'test_100', 'comment': 'Test'},
                                                             {'name': 'test_100', 'comment': 'Test'},
                                                             {'name': 'test_200', 'comment': 'Test'}])
                                                             {'name': 'test_200', 'comment': 'Test'}])
-
-
-def test_get_impala_nested_select():
-  select_fn = dbms.HiveServer2Dbms.get_impala_nested_select
-
-  assert_equal(select_fn('default', 'customers', 'id', None), ('id', '`default`.`customers`'))
-  assert_equal(select_fn('default', 'customers', 'email_preferences', 'categories/promos/'),
-               ('email_preferences.categories.promos', '`default`.`customers`'))
-  assert_equal(select_fn('default', 'customers', 'addresses', 'key'),
-               ('key', '`default`.`customers`.`addresses`'))
-  assert_equal(select_fn('default', 'customers', 'addresses', 'value/street_1/'),
-               ('street_1', '`default`.`customers`.`addresses`'))
-  assert_equal(select_fn('default', 'customers', 'orders', 'item/order_date'),
-               ('order_date', '`default`.`customers`.`orders`'))
-  assert_equal(select_fn('default', 'customers', 'orders', 'item/items/item/product_id'),
-               ('product_id', '`default`.`customers`.`orders`.`items`'))

+ 86 - 0
apps/impala/src/impala/dbms.py

@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from desktop.lib.i18n import smart_str
+
+from beeswax.design import hql_query
+
+
+LOG = logging.getLogger(__name__)
+
+
+class ImpalaDbms():
+
+  def get_histogram(self, database, table, column, nested=None):
+    """
+    Returns the results of an Impala SELECT histogram() FROM query for a given column or nested type.
+
+    Assumes that the column/nested type is scalar.
+    """
+    results = []
+
+    hql = self.get_histogram_query(database, table, column, nested)
+    query = hql_query(hql)
+    handle = self.execute_and_wait(query, timeout_sec=5.0)
+
+    if handle:
+      result = self.fetch(handle)
+      try:
+        histogram = list(result.rows())[0][0]  # actual histogram results is in first-and-only result row
+        unique_values = set(histogram.split(', '))
+        results = list(unique_values)
+      except IndexError, e:
+        LOG.warn('Failed to get histogram results, result set has unexpected format: %s' % smart_str(e))
+      finally:
+        self.close(handle)
+
+    return results
+
+  @classmethod
+  def get_nested_select(cls, database, table, column, nested=None):
+    """
+    Given a column or nested type, return the corresponding SELECT and FROM clauses in Impala's nested-type syntax
+    """
+    select_tokens = [column]
+    from_tokens = [database, table]
+
+    if nested:
+      nested_tokens = nested.strip('/').split('/')
+      while nested_tokens:
+        token = nested_tokens.pop(0)
+        if token not in ['key', 'value', 'item']:
+          select_tokens.append(token)
+        else:
+          # if we encounter a reserved keyword, move current select_tokens to from_tokens and reset the select_tokens
+          from_tokens.extend(select_tokens)
+          select_tokens = []
+          # if reserved keyword is the last token, make it the only select_token, otherwise we ignore and continue
+          if not nested_tokens:
+            select_tokens = [token]
+
+    select_clause = '.'.join(select_tokens)
+    from_clause = '.'.join('`%s`' % token for token in from_tokens)
+    return select_clause, from_clause
+
+
+  @classmethod
+  def get_histogram_query(cls, database, table, column, nested=None):
+    select_clause, from_clause = cls.get_nested_select(database, table, column, nested)
+    return 'SELECT histogram(%s) FROM %s' % (select_clause, from_clause)
+        

+ 18 - 3
apps/impala/src/impala/tests.py

@@ -17,9 +17,7 @@
 
 
 import json
 import json
 import logging
 import logging
-import os
 import re
 import re
-import sys
 
 
 import desktop.conf as desktop_conf
 import desktop.conf as desktop_conf
 
 
@@ -30,7 +28,7 @@ from django.contrib.auth.models import User
 from django.core.urlresolvers import reverse
 from django.core.urlresolvers import reverse
 
 
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.django_test_util import make_logged_in_client
-from desktop.lib.test_utils import grant_access, add_to_group
+from desktop.lib.test_utils import add_to_group
 from desktop.models import Document
 from desktop.models import Document
 
 
 from beeswax.design import hql_query
 from beeswax.design import hql_query
@@ -42,6 +40,7 @@ from hadoop.pseudo_hdfs4 import get_db_prefix, is_live_cluster
 
 
 from impala import conf
 from impala import conf
 from impala.conf import SERVER_HOST
 from impala.conf import SERVER_HOST
+from impala.dbms import ImpalaDbms
 
 
 
 
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
@@ -294,3 +293,19 @@ def test_ssl_validate():
     finally:
     finally:
       for reset in resets:
       for reset in resets:
         reset()
         reset()
+
+
+class TestImpalaDbms():
+
+  def test_get_impala_nested_select(self):
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'id', None), ('id', '`default`.`customers`'))
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'email_preferences', 'categories/promos/'),
+                 ('email_preferences.categories.promos', '`default`.`customers`'))
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'addresses', 'key'),
+                 ('key', '`default`.`customers`.`addresses`'))
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'addresses', 'value/street_1/'),
+                 ('street_1', '`default`.`customers`.`addresses`'))
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'orders', 'item/order_date'),
+                 ('order_date', '`default`.`customers`.`orders`'))
+    assert_equal(ImpalaDbms.get_nested_select('default', 'customers', 'orders', 'item/items/item/product_id'),
+                 ('product_id', '`default`.`customers`.`orders`.`items`'))