Jelajahi Sumber

HUE-6249 [optimizer] Adding Sentry filtering on API arguments

Romain Rigaux 8 tahun lalu
induk
melakukan
0b53f8c

+ 21 - 14
desktop/libs/metadata/src/metadata/optimizer_api.py

@@ -26,6 +26,7 @@ from django.views.decorators.http import require_POST
 from desktop.lib.django_util import JsonResponse
 from desktop.lib.i18n import force_unicode
 from desktop.models import Document2
+from libsentry.privilege_checker import MissingSentryPrivilegeException
 from notebook.api import _get_statement
 from notebook.models import Notebook
 
@@ -55,6 +56,12 @@ def error_handler(view_fn):
         'status': -1,
         'message': e.message
       }
+    except MissingSentryPrivilegeException, e:
+      LOG.exception(e)
+      response = {
+        'status': -1,
+        'message': 'Missing privileges for %s' % force_unicode(str(e))
+      }
     except Exception, e:
       LOG.exception(e)
       response = {
@@ -72,7 +79,7 @@ def get_tenant(request):
 
   email = request.POST.get('email')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.get_tenant(email=email)
 
   if data:
@@ -121,7 +128,7 @@ def table_details(request):
   database_name = request.POST.get('databaseName')
   table_name = request.POST.get('tableName')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   data = api.table_details(database_name=database_name, table_name=table_name)
 
@@ -143,7 +150,7 @@ def query_compatibility(request):
   target_platform = request.POST.get('targetPlatform')
   query = request.POST.get('query')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   data = api.query_compatibility(source_platform=source_platform, target_platform=target_platform, query=query)
 
@@ -165,7 +172,7 @@ def query_risk(request):
   source_platform = request.POST.get('sourcePlatform')
   db_name = request.POST.get('dbName')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   data = api.query_risk(query=query, source_platform=source_platform, db_name=db_name)
 
@@ -186,7 +193,7 @@ def similar_queries(request):
   source_platform = request.POST.get('sourcePlatform')
   query = json.loads(request.POST.get('query'))
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   data = api.similar_queries(source_platform=source_platform, query=query)
 
@@ -205,9 +212,9 @@ def top_filters(request):
   response = {'status': -1}
 
   db_tables = json.loads(request.POST.get('dbTables'), '[]')
-  column_name = request.POST.get('columnName') # Unsused
+  column_name = request.POST.get('columnName') # Unused
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.top_filters(db_tables=db_tables)
 
   if data:
@@ -226,7 +233,7 @@ def top_joins(request):
 
   db_tables = json.loads(request.POST.get('dbTables'), '[]')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.top_joins(db_tables=db_tables)
 
   if data:
@@ -245,7 +252,7 @@ def top_aggs(request):
 
   db_tables = json.loads(request.POST.get('dbTables'), '[]')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.top_aggs(db_tables=db_tables)
 
   if data:
@@ -262,7 +269,7 @@ def top_aggs(request):
 def top_databases(request):
   response = {'status': -1}
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.top_databases()
 
   if data:
@@ -281,7 +288,7 @@ def top_columns(request):
 
   db_tables = json.loads(request.POST.get('dbTables'), '[]')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
   data = api.top_columns(db_tables=db_tables)
 
   if data:
@@ -322,7 +329,7 @@ def upload_history(request):
 
   queries = _convert_queries([Notebook(document=doc).get_data() for doc in history])
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   response['upload_history'] = api.upload(data=queries, data_type='queries', source_platform=source_platform)
   response['status'] = 0
@@ -394,7 +401,7 @@ def upload_table_stats(request):
     except Exception, e:
       LOG.exception('Skipping upload of %s: %s' % (db_table, e))
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   response['upload_table_stats'] = api.upload(data=table_stats, data_type='table_stats', source_platform=source_platform)
   response['status'] = 0 if response['upload_table_stats']['status']['state'] in ('WAITING', 'FINISHED', 'IN_PROGRESS') else -1
@@ -414,7 +421,7 @@ def upload_status(request):
 
   workload_id = request.POST.get('workloadId')
 
-  api = OptimizerApi()
+  api = OptimizerApi(request.user)
 
   response['upload_status'] = api.upload_status(workload_id=workload_id)
   response['status'] = 0

+ 33 - 10
desktop/libs/metadata/src/metadata/optimizer_client.py

@@ -24,6 +24,7 @@ import uuid
 from tempfile import NamedTemporaryFile
 from urlparse import urlparse
 
+from django.utils.functional import wraps
 from django.utils.translation import ugettext as _
 
 from desktop.lib.exceptions_renderable import PopupException
@@ -54,9 +55,34 @@ class NavOptException(Exception):
     return smart_unicode(self.message)
 
 
+def check_privileges(view_func):
+  def decorate(*args, **kwargs):
+
+    if OPTIMIZER.APPLY_SENTRY_PERMISSIONS.get():
+      checker = get_checker(user=args[0].user)
+      action = 'SELECT'
+      objects = []
+
+      if kwargs.get('db_tables'):
+        for db_table in kwargs['db_tables']:
+          objects.append({'server': get_hive_sentry_provider(), 'db': _get_table_name(db_table)['database'], 'table': _get_table_name(db_table)['table']})
+      else:
+        objects = [{'server': get_hive_sentry_provider()}]
+        if kwargs.get('database_name'):
+          objects[0]['db'] = kwargs['database_name']
+        if kwargs.get('database_name'):
+          objects[0]['table'] = kwargs['table_name']
+
+      if len(list(checker.filter_objects(objects, action))) != len(objects):
+        raise MissingSentryPrivilegeException(objects)
+
+    return view_func(*args, **kwargs)
+  return wraps(view_func)(decorate)
+
+
 class OptimizerApi(object):
 
-  def __init__(self, user=None, api_url=None, product_name=None, product_secret=None, ssl_cert_ca_verify=OPTIMIZER.SSL_CERT_CA_VERIFY.get(), product_auth_secret=None):
+  def __init__(self, user, api_url=None, product_name=None, product_secret=None, ssl_cert_ca_verify=OPTIMIZER.SSL_CERT_CA_VERIFY.get(), product_auth_secret=None):
     self.user = user
     self._api_url = (api_url or get_optimizer_url()).strip('/')
     self._email = OPTIMIZER.EMAIL.get()
@@ -160,15 +186,8 @@ class OptimizerApi(object):
   def upload_status(self, workload_id):
     return self._call('uploadStatus', {'tenant' : self._product_name, 'workloadId': workload_id})
 
-
+  @check_privileges
   def top_tables(self, workfloadId=None, database_name='default', page_size=1000, startingToken=None):
-    if OPTIMIZER.APPLY_SENTRY_PERMISSIONS.get():
-      checker = get_checker(user=self.user)
-      action = 'SELECT'
-      objects = [{'server': get_hive_sentry_provider(), 'db': database_name}]
-      if not checker.filter_objects(objects, action):
-        raise MissingSentryPrivilegeException(objects)
-
     data = self._call('getTopTables', {'tenant' : self._product_name, 'dbName': database_name.lower(), 'pageSize': page_size, startingToken: None})
 
     if OPTIMIZER.APPLY_SENTRY_PERMISSIONS.get():
@@ -183,6 +202,7 @@ class OptimizerApi(object):
 
     return data
 
+  @check_privileges
   def table_details(self, database_name, table_name, page_size=100, startingToken=None):
     return self._call('getTablesDetail', {'tenant' : self._product_name, 'dbName': database_name.lower(), 'tableName': table_name.lower(), 'pageSize': page_size, startingToken: None})
 
@@ -216,6 +236,7 @@ class OptimizerApi(object):
     return self._call('getSimilarQueries', {'tenant' : self._product_name, 'sourcePlatform': source_platform, 'query': query, 'pageSize': page_size, startingToken: None})
 
 
+  @check_privileges
   def top_filters(self, db_tables=None, page_size=100, startingToken=None):
     args = {
       'tenant' : self._product_name,
@@ -227,7 +248,7 @@ class OptimizerApi(object):
 
     return self._call('getTopFilters', args)
 
-
+  @check_privileges
   def top_aggs(self, db_tables=None, page_size=100, startingToken=None):
     args = {
       'tenant' : self._product_name,
@@ -240,6 +261,7 @@ class OptimizerApi(object):
     return self._call('getTopAggs', args)
 
 
+  @check_privileges
   def top_columns(self, db_tables=None, page_size=100, startingToken=None):
     args = {
       'tenant' : self._product_name,
@@ -252,6 +274,7 @@ class OptimizerApi(object):
     return self._call('getTopColumns', args)
 
 
+  @check_privileges
   def top_joins(self, db_tables=None, page_size=100, startingToken=None):
     args = {
       'tenant' : self._product_name,

+ 3 - 3
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -530,7 +530,7 @@ DROP TABLE IF EXISTS `%(table)s`;
     response = self._get_current_statement(db, snippet)
     query = response['statement']
 
-    api = OptimizerApi()
+    api = OptimizerApi(self.user)
 
     return api.query_risk(query=query, source_platform=snippet['type'], db_name=snippet.get('database') or 'default')
 
@@ -541,7 +541,7 @@ DROP TABLE IF EXISTS `%(table)s`;
     response = self._get_current_statement(db, snippet)
     query = response['statement']
 
-    api = OptimizerApi()
+    api = OptimizerApi(self.user)
 
     return api.query_compatibility(source_platform, target_platform, query)
 
@@ -552,7 +552,7 @@ DROP TABLE IF EXISTS `%(table)s`;
     response = self._get_current_statement(db, snippet)
     query = response['statement']
 
-    api = OptimizerApi()
+    api = OptimizerApi(self.user)
 
     return api.similar_queries(source_platform, query)