Explorar o código

[hbase] Decorating Thrift calls to set the doAs header

No other way, except maybe making it generic in the thrift_util
lib but this is overkill for now.
Romain Rigaux %!s(int64=10) %!d(string=hai) anos
pai
achega
5c852c0

+ 83 - 1
apps/hbase/gen-py/hbased/Hbase.py

@@ -3,6 +3,8 @@
 #
 # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
 #
+# Warning: module edited below.
+#
 #  options string: py
 #
 
@@ -16,6 +18,43 @@ try:
 except:
   fastbinary = None
 
+from django.utils.functional import wraps
+
+
+### Do as / Impersonation support.
+### This should be put back if rerenerating the Thrift.
+### As well as all the @do_as in Client.
+
+import logging
+
+from django.utils.encoding import smart_str
+from hbase.hbase_site import is_impersonation_enabled
+
+LOG = logging.getLogger(__name__)
+
+
+def do_as(func):
+  def decorate(*args, **kwargs):
+    self = args[0]
+    username = kwargs.pop('doas')
+
+    try:
+      if is_impersonation_enabled():
+        if hasattr(self._oprot.trans, 'TFramedTransport'):
+          trans_client = self._oprot.trans._TFramedTransport__trans
+        else:
+          trans_client = self._oprot.trans._TBufferedTransport__trans
+
+        trans_client.setCustomHeaders({'doAs': username})
+
+    except AttributeError, e:
+      LOG.error('Could not set doAs parameter: %s' % smart_str(e))
+
+    return func(*args, **kwargs)
+  return wraps(func)(decorate)
+
+###
+
 
 class Iface:
   def enableTable(self, tableName):
@@ -619,6 +658,7 @@ class Client(Iface):
       self._oprot = oprot
     self._seqid = 0
 
+  @do_as
   def enableTable(self, tableName):
     """
     Brings a table on-line (enables it)
@@ -651,6 +691,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def disableTable(self, tableName):
     """
     Disables a table (takes it off-line) If it is being served, the master
@@ -684,6 +725,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def isTableEnabled(self, tableName):
     """
     @return true if table is on-line
@@ -718,6 +760,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "isTableEnabled failed: unknown result");
 
+  @do_as
   def compact(self, tableNameOrRegionName):
     """
     Parameters:
@@ -748,6 +791,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def majorCompact(self, tableNameOrRegionName):
     """
     Parameters:
@@ -778,7 +822,8 @@ class Client(Iface):
       raise result.io
     return
 
-  def getTableNames(self, ):
+  @do_as
+  def getTableNames(self):
     """
     List all the userspace tables.
 
@@ -810,6 +855,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getTableNames failed: unknown result");
 
+  @do_as
   def getColumnDescriptors(self, tableName):
     """
     List all the column families assoicated with a table.
@@ -846,6 +892,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getColumnDescriptors failed: unknown result");
 
+  @do_as
   def getTableRegions(self, tableName):
     """
     List the regions associated with a table.
@@ -882,6 +929,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getTableRegions failed: unknown result");
 
+  @do_as
   def createTable(self, tableName, columnFamilies):
     """
     Create a table with the specified column families.  The name
@@ -927,6 +975,7 @@ class Client(Iface):
       raise result.exist
     return
 
+  @do_as
   def deleteTable(self, tableName):
     """
     Deletes a table
@@ -962,6 +1011,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def get(self, tableName, row, column, attributes):
     """
     Get a single TCell for the specified table, row, and column at the
@@ -1005,6 +1055,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "get failed: unknown result");
 
+  @do_as
   def getVer(self, tableName, row, column, numVersions, attributes):
     """
     Get the specified number of versions for the specified table,
@@ -1050,6 +1101,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getVer failed: unknown result");
 
+  @do_as
   def getVerTs(self, tableName, row, column, timestamp, numVersions, attributes):
     """
     Get the specified number of versions for the specified table,
@@ -1098,6 +1150,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getVerTs failed: unknown result");
 
+  @do_as
   def getRow(self, tableName, row, attributes):
     """
     Get all the data for the specified table and row at the latest
@@ -1139,6 +1192,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRow failed: unknown result");
 
+  @do_as
   def getRowWithColumns(self, tableName, row, columns, attributes):
     """
     Get the specified columns for the specified table and row at the latest
@@ -1182,6 +1236,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowWithColumns failed: unknown result");
 
+  @do_as
   def getRowTs(self, tableName, row, timestamp, attributes):
     """
     Get all the data for the specified table and row at the specified
@@ -1225,6 +1280,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowTs failed: unknown result");
 
+  @do_as
   def getRowWithColumnsTs(self, tableName, row, columns, timestamp, attributes):
     """
     Get the specified columns for the specified table and row at the specified
@@ -1270,6 +1326,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowWithColumnsTs failed: unknown result");
 
+  @do_as
   def getRows(self, tableName, rows, attributes):
     """
     Get all the data for the specified table and rows at the latest
@@ -1311,6 +1368,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRows failed: unknown result");
 
+  @do_as
   def getRowsWithColumns(self, tableName, rows, columns, attributes):
     """
     Get the specified columns for the specified table and rows at the latest
@@ -1354,6 +1412,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowsWithColumns failed: unknown result");
 
+  @do_as
   def getRowsTs(self, tableName, rows, timestamp, attributes):
     """
     Get all the data for the specified table and rows at the specified
@@ -1397,6 +1456,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowsTs failed: unknown result");
 
+  @do_as
   def getRowsWithColumnsTs(self, tableName, rows, columns, timestamp, attributes):
     """
     Get the specified columns for the specified table and rows at the specified
@@ -1442,6 +1502,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowsWithColumnsTs failed: unknown result");
 
+  @do_as
   def mutateRow(self, tableName, row, mutations, attributes):
     """
     Apply a series of mutations (updates/deletes) to a row in a
@@ -1485,6 +1546,7 @@ class Client(Iface):
       raise result.ia
     return
 
+  @do_as
   def mutateRowTs(self, tableName, row, mutations, timestamp, attributes):
     """
     Apply a series of mutations (updates/deletes) to a row in a
@@ -1530,6 +1592,7 @@ class Client(Iface):
       raise result.ia
     return
 
+  @do_as
   def mutateRows(self, tableName, rowBatches, attributes):
     """
     Apply a series of batches (each a series of mutations on a single row)
@@ -1571,6 +1634,7 @@ class Client(Iface):
       raise result.ia
     return
 
+  @do_as
   def mutateRowsTs(self, tableName, rowBatches, timestamp, attributes):
     """
     Apply a series of batches (each a series of mutations on a single row)
@@ -1614,6 +1678,7 @@ class Client(Iface):
       raise result.ia
     return
 
+  @do_as
   def atomicIncrement(self, tableName, row, column, value):
     """
     Atomically increment the column value specified.  Returns the next value post increment.
@@ -1656,6 +1721,7 @@ class Client(Iface):
       raise result.ia
     raise TApplicationException(TApplicationException.MISSING_RESULT, "atomicIncrement failed: unknown result");
 
+  @do_as
   def deleteAll(self, tableName, row, column, attributes):
     """
     Delete all cells that match the passed row and column.
@@ -1694,6 +1760,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def deleteAllTs(self, tableName, row, column, timestamp, attributes):
     """
     Delete all cells that match the passed row and column and whose
@@ -1735,6 +1802,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def deleteAllRow(self, tableName, row, attributes):
     """
     Completely delete the row's cells.
@@ -1771,6 +1839,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def increment(self, increment):
     """
     Increment a cell by the ammount.
@@ -1806,6 +1875,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def incrementRows(self, increments):
     """
     Parameters:
@@ -1836,6 +1906,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def deleteAllRowTs(self, tableName, row, timestamp, attributes):
     """
     Completely delete the row's cells marked with a timestamp
@@ -1875,6 +1946,7 @@ class Client(Iface):
       raise result.io
     return
 
+  @do_as
   def scannerOpenWithScan(self, tableName, scan, attributes):
     """
     Get a scanner on the current table, using the Scan instance
@@ -1914,6 +1986,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpenWithScan failed: unknown result");
 
+  @do_as
   def scannerOpen(self, tableName, startRow, columns, attributes):
     """
     Get a scanner on the current table starting at the specified row and
@@ -1960,6 +2033,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpen failed: unknown result");
 
+  @do_as
   def scannerOpenWithStop(self, tableName, startRow, stopRow, columns, attributes):
     """
     Get a scanner on the current table starting and stopping at the
@@ -2010,6 +2084,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpenWithStop failed: unknown result");
 
+  @do_as
   def scannerOpenWithPrefix(self, tableName, startAndPrefix, columns, attributes):
     """
     Open a scanner for a given prefix.  That is all rows will have the specified
@@ -2053,6 +2128,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpenWithPrefix failed: unknown result");
 
+  @do_as
   def scannerOpenTs(self, tableName, startRow, columns, timestamp, attributes):
     """
     Get a scanner on the current table starting at the specified row and
@@ -2102,6 +2178,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpenTs failed: unknown result");
 
+  @do_as
   def scannerOpenWithStopTs(self, tableName, startRow, stopRow, columns, timestamp, attributes):
     """
     Get a scanner on the current table starting and stopping at the
@@ -2155,6 +2232,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerOpenWithStopTs failed: unknown result");
 
+  @do_as
   def scannerGet(self, id):
     """
     Returns the scanner's current row value and advances to the next
@@ -2200,6 +2278,7 @@ class Client(Iface):
       raise result.ia
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerGet failed: unknown result");
 
+  @do_as
   def scannerGetList(self, id, nbRows):
     """
     Returns, starting at the scanner's current row value nbRows worth of
@@ -2247,6 +2326,7 @@ class Client(Iface):
       raise result.ia
     raise TApplicationException(TApplicationException.MISSING_RESULT, "scannerGetList failed: unknown result");
 
+  @do_as
   def scannerClose(self, id):
     """
     Closes the server-state associated with an open scanner.
@@ -2283,6 +2363,7 @@ class Client(Iface):
       raise result.ia
     return
 
+  @do_as
   def getRowOrBefore(self, tableName, row, family):
     """
     Get the row just before the specified one.
@@ -2323,6 +2404,7 @@ class Client(Iface):
       raise result.io
     raise TApplicationException(TApplicationException.MISSING_RESULT, "getRowOrBefore failed: unknown result");
 
+  @do_as
   def getRegionInfo(self, row):
     """
     Get the regininfo for the specified row. It scans

+ 25 - 27
apps/hbase/src/hbase/api.py

@@ -27,9 +27,10 @@ from django.utils.encoding import smart_str
 from desktop.lib import thrift_util
 from desktop.lib.exceptions_renderable import PopupException
 
-from hbase.server.hbase_lib import get_thrift_type, get_client_type
 from hbase import conf
-from hbase.hbase_site import get_server_principal, get_server_authentication, is_using_thrift_ssl, is_using_thrift_http
+from hbase.hbase_site import get_server_principal, get_server_authentication, is_using_thrift_ssl, is_using_thrift_http, is_impersonation_enabled
+from hbase.server.hbase_lib import get_thrift_type, get_client_type
+
 
 LOG = logging.getLogger(__name__)
 
@@ -55,7 +56,7 @@ class HbaseApi(object):
   def queryCluster(self, action, cluster, *args):
     client = self.connectCluster(cluster)
     method = getattr(client, action)
-    return method(*args)
+    return method(*args, doas=self.user.username)
 
   def getClusters(self):
     clusters = []
@@ -101,9 +102,6 @@ class HbaseApi(object):
                                   http_url=('https://' if is_using_thrift_ssl() else 'http://') + target['host'] + ':' + str(target['port'])
     )
 
-    if hasattr(client, 'setCustomHeaders'):
-      client.setCustomHeaders({'doAs': self.user.username})
-
     return client
 
   @classmethod
@@ -125,29 +123,29 @@ class HbaseApi(object):
 
   def get(self, cluster, tableName, row, column, attributes):
     client = self.connectCluster(cluster)
-    return client.get(tableName, smart_str(row), smart_str(column), attributes)
+    return client.get(tableName, smart_str(row), smart_str(column), attributes, doas=self.user.username)
 
   def getVerTs(self, cluster, tableName, row, column, timestamp, numVersions, attributesargs):
     client = self.connectCluster(cluster)
-    return client.getVerTs(tableName, smart_str(row), smart_str(column), timestamp, numVersions, attributesargs)
+    return client.getVerTs(tableName, smart_str(row), smart_str(column), timestamp, numVersions, attributesargs, doas=self.user.username)
 
   def createTable(self, cluster, tableName, columns):
     client = self.connectCluster(cluster)
-    client.createTable(tableName, [get_thrift_type('ColumnDescriptor')(**column['properties']) for column in columns])
+    client.createTable(tableName, [get_thrift_type('ColumnDescriptor')(**column['properties']) for column in columns], doas=self.user.username)
     return "%s successfully created" % tableName
 
   def getTableList(self, cluster):
     client = self.connectCluster(cluster)
-    return [{'name': name, 'enabled': client.isTableEnabled(name)} for name in client.getTableNames()]
+    return [{'name': name, 'enabled': client.isTableEnabled(name, doas=self.user.username)} for name in client.getTableNames(doas=self.user.username)]
 
   def getRows(self, cluster, tableName, columns, startRowKey, numRows, prefix=False):
     client = self.connectCluster(cluster)
     if prefix == False:
-      scanner = client.scannerOpen(tableName, smart_str(startRowKey), columns, None)
+      scanner = client.scannerOpen(tableName, smart_str(startRowKey), columns, None, doas=self.user.username)
     else:
-      scanner = client.scannerOpenWithPrefix(tableName, smart_str(startRowKey), columns, None)
-    data = client.scannerGetList(scanner, numRows)
-    client.scannerClose(scanner)
+      scanner = client.scannerOpenWithPrefix(tableName, smart_str(startRowKey), columns, None, doas=self.user.username)
+    data = client.scannerGetList(scanner, numRows, doas=self.user.username)
+    client.scannerClose(scanner, doas=self.user.username)
     return data
 
   def getAutocompleteRows(self, cluster, tableName, numRows, query):
@@ -155,8 +153,8 @@ class HbaseApi(object):
     try:
       client = self.connectCluster(cluster)
       scan = get_thrift_type('TScan')(startRow=query, stopRow=None, timestamp=None, columns=[], caching=None, filterString="PrefixFilter('" + query + "') AND ColumnPaginationFilter(1,0)", batchSize=None)
-      scanner = client.scannerOpenWithScan(tableName, scan, None)
-      return [result.row for result in client.scannerGetList(scanner, numRows)]
+      scanner = client.scannerOpenWithScan(tableName, scan, None, doas=self.user.username)
+      return [result.row for result in client.scannerGetList(scanner, numRows, doas=self.user.username)]
     except Exception, e:
       LOG.error('Autocomplete error: %s' % smart_str(e))
       return []
@@ -169,7 +167,7 @@ class HbaseApi(object):
 
   def getRowsFull(self, cluster, tableName, startRowKey, numRows):
     client = self.connectCluster(cluster)
-    return self.getRows(cluster, tableName, [smart_str(column) for column in client.getColumnDescriptors(tableName)], smart_str(startRowKey), numRows)
+    return self.getRows(cluster, tableName, [smart_str(column) for column in client.getColumnDescriptors(tableName, aaa=11, doas=self.user.username)], smart_str(startRowKey), numRows)
 
   def getRowFull(self, cluster, tableName, startRowKey, numRows):
     row = self.getRowsFull(cluster, tableName, smart_str(startRowKey), 1)
@@ -180,21 +178,21 @@ class HbaseApi(object):
   def getRowPartial(self, cluster, tableName, rowKey, offset, number):
     client = self.connectCluster(cluster)
     scan = get_thrift_type('TScan')(startRow=rowKey, stopRow=None, timestamp=None, columns=[], caching=None, filterString="ColumnPaginationFilter(%i, %i)" % (number, offset), batchSize=None)
-    scanner = client.scannerOpenWithScan(tableName, scan, None)
-    return client.scannerGetList(scanner, 1)
+    scanner = client.scannerOpenWithScan(tableName, scan, None, doas=self.user.username)
+    return client.scannerGetList(scanner, 1, doas=self.user.username)
 
   def deleteColumns(self, cluster, tableName, row, columns):
     client = self.connectCluster(cluster)
     Mutation = get_thrift_type('Mutation')
     mutations = [Mutation(isDelete = True, column=smart_str(column)) for column in columns]
-    return client.mutateRow(tableName, smart_str(row), mutations, None)
+    return client.mutateRow(tableName, smart_str(row), mutations, None, doas=self.user.username)
 
   def deleteColumn(self, cluster, tableName, row, column):
-    return self.deleteColumns(cluster, tableName, smart_str(row), [smart_str(column)])
+    return self.deleteColumns(cluster, tableName, smart_str(row), [smart_str(column)], doas=self.user.username)
 
   def deleteAllRow(self, cluster, tableName, row, attributes):
     client = self.connectCluster(cluster)
-    return client.deleteAllRow(tableName, smart_str(row), attributes)
+    return client.deleteAllRow(tableName, smart_str(row), attributes, doas=self.user.username)
 
   def putRow(self, cluster, tableName, row, data):
     client = self.connectCluster(cluster)
@@ -202,7 +200,7 @@ class HbaseApi(object):
     Mutation = get_thrift_type('Mutation')
     for column in data.keys():
       mutations.append(Mutation(column=smart_str(column), value=smart_str(data[column]))) # must use str for API, does thrift coerce by itself?
-    return client.mutateRow(tableName, smart_str(row), mutations, None)
+    return client.mutateRow(tableName, smart_str(row), mutations, None, doas=self.user.username)
 
   def putColumn(self, cluster, tableName, row, column, value):
     return self.putRow(cluster, tableName, smart_str(row), {column: value})
@@ -210,7 +208,7 @@ class HbaseApi(object):
   def putUpload(self, cluster, tableName, row, column, value):
     client = self.connectCluster(cluster)
     Mutation = get_thrift_type('Mutation')
-    return client.mutateRow(tableName, smart_str(row), [Mutation(column=smart_str(column), value=value.file.read(value.size))], None)
+    return client.mutateRow(tableName, smart_str(row), [Mutation(column=smart_str(column), value=value.file.read(value.size))], None, doas=self.user.username)
 
   def getRowQuerySet(self, cluster, tableName, columns, queries):
     client = self.connectCluster(cluster)
@@ -227,8 +225,8 @@ class HbaseApi(object):
       filterstring = "(ColumnPaginationFilter(%i,0) AND PageFilter(%i))" % (limit, limit) + (fs or "")
       scan_columns = [smart_str(column.strip(':')) for column in query['columns']] or [smart_str(column.strip(':')) for column in columns]
       scan = get_thrift_type('TScan')(startRow=smart_str(query['row_key']), stopRow=None, timestamp=None, columns=scan_columns, caching=None, filterString=filterstring, batchSize=None)
-      scanner = client.scannerOpenWithScan(tableName, scan, None)
-      aggregate_data += client.scannerGetList(scanner, query['scan_length'])
+      scanner = client.scannerOpenWithScan(tableName, scan, None, doas=self.user.username)
+      aggregate_data += client.scannerGetList(scanner, query['scan_length'], doas=self.user.username)
     return aggregate_data
 
   def bulkUpload(self, cluster, tableName, data):
@@ -249,5 +247,5 @@ class HbaseApi(object):
         if str(row[column_index]) != "":
           mutations.append(Mutation(column=smart_str(columns[column_index]), value=smart_str(row[column_index])))
       batches += [BatchMutation(row=row_key, mutations=mutations)]
-    client.mutateRows(tableName, batches, None)
+    client.mutateRows(tableName, batches, None, doas=self.user.username)
     return True

+ 2 - 2
apps/hbase/src/hbase/hbase_site.py

@@ -62,10 +62,10 @@ def get_server_authentication():
   return get_conf().get(_CNF_HBASE_AUTHENTICATION, 'NOSASL').upper()
 
 def is_impersonation_enabled():
-  return get_conf().get(_CNF_HBASE_IMPERSONATION_ENABLED, 'FALSE').upper() == 'TRUE' or USE_DOAS.get()
+  return get_conf().get(_CNF_HBASE_IMPERSONATION_ENABLED, 'FALSE').upper() == 'TRUE' and USE_DOAS.get()
 
 def is_using_thrift_http():
-  return get_conf().get(_CNF_HBASE_USE_THRIFT_HTTP, 'FALSE').upper() == 'TRUE' or USE_DOAS.get()
+  return get_conf().get(_CNF_HBASE_USE_THRIFT_HTTP, 'FALSE').upper() == 'TRUE' and USE_DOAS.get()
 
 def is_using_thrift_ssl():
   return get_conf().get(_CNF_HBASE_USE_THRIFT_SSL, 'FALSE').upper() == 'TRUE'

+ 68 - 1
apps/hbase/src/hbase/tests.py

@@ -21,8 +21,13 @@ import tempfile
 
 from nose.tools import assert_true, assert_equal
 
+from django.contrib.auth.models import User
+
+from desktop.lib.django_test_util import make_logged_in_client
+from desktop.lib.test_utils import grant_access
+
 from hbase.api import HbaseApi
-from hbase.conf import HBASE_CONF_DIR
+from hbase.conf import HBASE_CONF_DIR, USE_DOAS
 from hbase.hbase_site import get_server_authentication, get_server_principal, reset
 
 
@@ -92,3 +97,65 @@ def hbase_site_xml(
     'kerberos_principal': kerberos_principal,
     'authentication': authentication,
   }
+
+
+def test_impersonation_is_decorator_is_there():
+  # Decorator is still there
+  from hbased.Hbase import do_as
+
+
+def test_impersonation():
+  from hbased import Hbase as thrift_hbase
+
+  c = make_logged_in_client(username='test_hbase', is_superuser=False)
+  grant_access('test_hbase', 'test_hbase', 'hbase')
+  user = User.objects.get(username='test_hbase')
+
+  proto = MockProtocol()
+  client = thrift_hbase.Client(proto)
+
+  reset = USE_DOAS.set_for_testing(False)
+  try:
+    client.getTableNames(doas=user.username)
+  except AttributeError:
+    pass # We don't mock everything
+  finally:
+    reset()
+
+  assert_equal({}, proto.get_headers())
+
+
+  reset = USE_DOAS.set_for_testing(True)
+
+  try:
+    client.getTableNames(doas=user.username)
+  except AttributeError:
+    pass # We don't mock everything
+  finally:
+    reset()
+
+  assert_equal({'doAs': u'test_hbase'}, proto.get_headers())
+
+
+
+class MockHttpClient():
+  def __init__(self):
+    self.headers = {}
+
+  def setCustomHeaders(self, headers):
+    self.headers = headers
+
+class MockTransport():
+  def __init__(self):
+    self._TBufferedTransport__trans = MockHttpClient()
+
+class MockProtocol():
+  def __init__(self):
+    self.trans = MockTransport()
+
+  def getTableNames(self):
+    pass
+
+  def get_headers(self):
+    return self.trans._TBufferedTransport__trans.headers
+

+ 11 - 8
desktop/core/src/desktop/lib/rest/resource.py

@@ -108,18 +108,18 @@ class Resource(object):
     return self.invoke("DELETE", relpath, params)
 
 
-  def post(self, relpath=None, params=None, data=None, contenttype=None):
+  def post(self, relpath=None, params=None, data=None, contenttype=None, headers=None):
     """
     Invoke the POST method on a resource.
     @param relpath: Optional. A relative path to this resource's path.
     @param params: Key-value data.
     @param data: Optional. Body of the request.
     @param contenttype: Optional.
+    @param headers: Optional. Base set of headers.
 
     @return: A dictionary of the JSON result.
     """
-    return self.invoke("POST", relpath, params, data,
-                       self._make_headers(contenttype))
+    return self.invoke("POST", relpath, params, data, self._make_headers(contenttype, headers))
 
 
   def put(self, relpath=None, params=None, data=None, contenttype=None):
@@ -132,11 +132,14 @@ class Resource(object):
 
     @return: A dictionary of the JSON result.
     """
-    return self.invoke("PUT", relpath, params, data,
-                       self._make_headers(contenttype))
+    return self.invoke("PUT", relpath, params, data, self._make_headers(contenttype))
 
 
-  def _make_headers(self, contenttype=None):
+  def _make_headers(self, contenttype=None, headers=None):
+    if headers is None:
+      headers = {}
+
     if contenttype:
-      return { 'Content-Type': contenttype }
-    return None
+      headers.update({'Content-Type': contenttype})
+
+    return headers

+ 1 - 1
desktop/core/src/desktop/lib/thrift_/http_client.py

@@ -85,4 +85,4 @@ class THttpClient(TTransportBase):
 
     # POST
     self._root = Resource(self._client)
-    self._data = self._root.post('', data=data)
+    self._data = self._root.post('', data=data, headers=self._headers)

+ 2 - 3
desktop/core/src/desktop/lib/thrift_util.py

@@ -158,8 +158,7 @@ class ConnectionPooler(object):
     self.poolsize = poolsize
     self.dictlock = threading.Lock()
 
-  def get_client(self, conf,
-                 get_client_timeout=None):
+  def get_client(self, conf, get_client_timeout=None):
     """
     Could block while we wait for the pool to become non-empty.
 
@@ -344,7 +343,7 @@ class PooledClient(object):
         try:
           # Poke it to see if it's closed on the other end. This can happen if a connection
           # sits in the connection pool longer than the read timeout of the server.
-          sock = self.conf.transport_mode == 'TCP' and _grab_transport_from_wrapper(superclient.transport).handle
+          sock = self.conf.transport_mode != 'http' and _grab_transport_from_wrapper(superclient.transport).handle
           if sock and create_synchronous_io_multiplexer().read([sock]):
             # the socket is readable, meaning there is either data from a previous call
             # (i.e our protocol is out of sync), or the connection was shut down on the