|
|
@@ -89,10 +89,16 @@ def _secure_results(results, user, action='SELECT'):
|
|
|
|
|
|
if 'dbName' in result:
|
|
|
key['db'] = result['dbName']
|
|
|
+ elif 'database' in result:
|
|
|
+ key['db'] = result['database']
|
|
|
if 'tableName' in result:
|
|
|
key['table'] = result['tableName']
|
|
|
+ elif 'table' in result:
|
|
|
+ key['table'] = result['table']
|
|
|
if 'columnName' in result:
|
|
|
key['column'] = result['columnName']
|
|
|
+ elif 'column' in result:
|
|
|
+ key['column'] = result['column']
|
|
|
|
|
|
return key
|
|
|
|
|
|
@@ -216,7 +222,7 @@ class OptimizerApi(object):
|
|
|
action = 'SELECT'
|
|
|
|
|
|
def getkey(table):
|
|
|
- names = _get_table_name(table['name']),
|
|
|
+ names = _get_table_name(table['name'])
|
|
|
return {'server': get_hive_sentry_provider(), 'db': names['database'], 'table': names['table']}
|
|
|
|
|
|
data['results'] = list(checker.filter_objects(data['results'], action, key=getkey))
|
|
|
@@ -293,8 +299,10 @@ class OptimizerApi(object):
|
|
|
args['dbTableList'] = [db_table.lower() for db_table in db_tables]
|
|
|
|
|
|
results = self._call('getTopColumns', args)
|
|
|
- for section in ['orderbyColumns', 'selectColumns', 'filterColumns', 'joinColumns', 'groupbyColumns']:
|
|
|
- results[section] = list(_secure_results(results[section], self.user))
|
|
|
+
|
|
|
+ if OPTIMIZER.APPLY_SENTRY_PERMISSIONS.get():
|
|
|
+ for section in ['orderbyColumns', 'selectColumns', 'filterColumns', 'joinColumns', 'groupbyColumns']:
|
|
|
+ results[section] = list(_secure_results(results[section], self.user))
|
|
|
return results
|
|
|
|
|
|
|
|
|
@@ -308,7 +316,16 @@ class OptimizerApi(object):
|
|
|
if db_tables:
|
|
|
args['dbTableList'] = [db_table.lower() for db_table in db_tables]
|
|
|
|
|
|
- return self._call('getTopJoins', args)
|
|
|
+ results = self._call('getTopJoins', args)
|
|
|
+
|
|
|
+ if OPTIMIZER.APPLY_SENTRY_PERMISSIONS.get():
|
|
|
+ filtered_joins = []
|
|
|
+ for result in results['results']:
|
|
|
+ cols = [_get_table_name(col) for col in result["joinCols"][0]["columns"]]
|
|
|
+ if len(cols) == len(list(_secure_results(cols, self.user))):
|
|
|
+ filtered_joins.append(cols)
|
|
|
+ results['resulsts'] = filtered_joins
|
|
|
+ return results
|
|
|
|
|
|
|
|
|
def top_databases(self, page_size=100, startingToken=None):
|
|
|
@@ -332,9 +349,16 @@ def OptimizerQueryDataAdapter(data):
|
|
|
yield headers, rows
|
|
|
|
|
|
def _get_table_name(path):
|
|
|
- if '.' in path:
|
|
|
+ column = None
|
|
|
+
|
|
|
+ if path.count('.') == 1:
|
|
|
database, table = path.split('.', 1)
|
|
|
+ elif path.count('.') == 2:
|
|
|
+ database, table, column = path.split('.', 2)
|
|
|
else:
|
|
|
database, table = 'default', path
|
|
|
|
|
|
- return {'database': database, 'table': table}
|
|
|
+ name = {'database': database, 'table': table}
|
|
|
+ if column:
|
|
|
+ name['column'] = column
|
|
|
+ return name
|