فهرست منبع

HUE-88. Hue-Hadoop connectivity is not appropriately thread-safe

- Adding /debug/who_am_i endpoint.
- Moving user info and request_context into a thread-local.
- Adding test for hadoopfs threadedness.

Kudos to Todd Lipcon and BC Wong for review.
Philip Zeyliger 15 سال پیش
والد
کامیت
7dcab4fcdf

+ 1 - 0
desktop/core/src/desktop/urls.py

@@ -60,6 +60,7 @@ dynamic_patterns = patterns('',
   (r'^admin/', include(admin.site.urls)),
   (r'^depender/', include(depender.urls)),
   (r'^debug/threads$', 'desktop.views.threads'),
+  (r'^debug/who_am_i$', 'desktop.views.who_am_i'),
   (r'^log_frontend_event$', 'desktop.views.log_frontend_event'),
   # Top level web page!
   (r'^$', 'desktop.views.index'),

+ 11 - 0
desktop/core/src/desktop/views.py

@@ -244,3 +244,14 @@ def log_frontend_event(request):
     get("message", "")[:_MAX_LOG_FRONTEND_EVENT_LENGTH])
   _LOG_FRONTEND_LOGGER.log(level, msg)
   return HttpResponse("")
+
+def who_am_i(request):
+  """
+  Returns username and FS username, and optionally sleeps.
+  """
+  try:
+    sleep = float(request.REQUEST.get("sleep") or 0.0)
+  except ValueError:
+    sleep = 0.0
+  time.sleep(sleep)
+  return HttpResponse(request.user.username + "\t" + request.fs.user + "\n")

+ 26 - 5
desktop/libs/hadoop/src/hadoop/fs/hadoopfs.py

@@ -26,6 +26,7 @@ import stat as statconsts
 import subprocess
 import sys
 import urlparse
+import threading
 
 from thrift.transport import TTransport
 from thrift.transport import TSocket
@@ -101,7 +102,10 @@ class HadoopFileSystem(object):
     self.nn_client = thrift_util.get_client(Namenode.Client, host, thrift_port, service_name="HDFS Namenode",
                                             timeout_seconds=NN_THRIFT_TIMEOUT)
 
-    self.request_context = RequestContext()
+    # The file systems are cached globally.  We store
+    # user information in a thread-local variable so that
+    # safety can be preserved there.
+    self.thread_local = threading.local()
     self.setuser(DEFAULT_USER, DEFAULT_GROUPS)
     LOG.debug("Initialized HadoopFS: %s:%d (%s)", host, thrift_port, hadoop_bin_path)
 
@@ -158,14 +162,31 @@ class HadoopFileSystem(object):
   def setuser(self, user, groups=None):
     # Hadoop UGI *must* have at least one group, so we mirror
     # the username as a group if not specified
+    self.thread_local.request_context = RequestContext()
     if not groups:
       groups = [user]
     if not self.request_context.confOptions:
       self.request_context.confOptions = {}
-    self.ugi = ",".join([user] + groups)
-    self.request_context.confOptions['hadoop.job.ugi'] = self.ugi
-    self.user = user
-    self.groups = groups
+    self.thread_local.ugi = ",".join([user] + groups)
+    self.thread_local.request_context.confOptions['hadoop.job.ugi'] = self.thread_local.ugi
+    self.thread_local.user = user
+    self.thread_local.groups = groups
+
+  @property
+  def user(self):
+    return self.thread_local.user
+
+  @property
+  def groups(self):
+    return self.thread_local.groups
+
+  @property
+  def request_context(self):
+    return self.thread_local.request_context
+
+  @property
+  def ugi(self):
+    return self.thread_local.ugi
 
   @_coerce_exceptions
   def open(self, path, mode="r", *args, **kwargs):

+ 22 - 0
desktop/libs/hadoop/src/hadoop/fs/hadoopfs_test.py

@@ -23,6 +23,7 @@ from nose.plugins.attrib import attr
 import logging
 import posixfile
 import random
+from threading import Thread
 
 from hadoop import mini_cluster
 from hadoop.fs.exceptions import PermissionDeniedException
@@ -344,3 +345,24 @@ def test_i18n_namespace():
     except Exception, ex:
       LOG.error('Failed to cleanup %s: %s' % (prefix, ex))
     cluster.shutdown()
+
+@attr('requires_hadoop')
+def test_threadedness():
+  # Start a second thread to change the user, and
+  # make sure that isn't reflected.
+  cluster = mini_cluster.shared_cluster()
+  try:
+    fs = cluster.fs
+    fs.setuser("alpha")
+    class T(Thread):
+      def run(self):
+        fs.setuser("beta")
+        assert_equals("beta", fs.user)
+    t = T()
+    t.start()
+    t.join()
+    assert_equals("alpha", fs.user)
+    fs.setuser("gamma")
+    assert_equals("gamma", fs.user)
+  finally:
+    cluster.shutdown()

+ 43 - 26
desktop/libs/hadoop/src/hadoop/job_tracker.py

@@ -27,6 +27,8 @@ from hadoop.api.jobtracker.ttypes import ThriftJobID, ThriftTaskAttemptID, \
     JobTrackerState, JobNotFoundException, ThriftTaskQueryState
 from hadoop.api.common.ttypes import RequestContext
 
+import threading
+
 VALID_TASK_STATES = set(["succeeded", "failed", "running", "pending", "killed"])
 VALID_TASK_TYPES = set(["map", "reduce", "job_cleanup", "job_setup"])
 
@@ -50,7 +52,10 @@ class LiveJobTracker(object):
       timeout_seconds=JT_THRIFT_TIMEOUT)
     self.host = host
     self.thrift_port = thrift_port
-    self.request_context = RequestContext()
+    # We allow a single LiveJobTracker to be used across multiple
+    # threads by restricting the stateful components to a thread
+    # thread-local.
+    self.thread_local = threading.local()
     self.setuser(DEFAULT_USER, DEFAULT_GROUPS)
 
   def thriftjobid_from_string(self, jobid):
@@ -132,25 +137,37 @@ class LiveJobTracker(object):
   def setuser(self, user, groups=None):
     # Hadoop UGI *must* have at least one group, so we mirror
     # the username as a group if not specified
+    self.thread_local.request_context = RequestContext()
     if not groups:
       groups = [user]
-    if not self.request_context.confOptions:
-      self.request_context.confOptions = {}
-    self.ugi = ",".join([user] + groups)
-    self.request_context.confOptions['hadoop.job.ugi'] = self.ugi
+    if not self.thread_local.request_context.confOptions:
+      self.thread_local.request_context.confOptions = {}
+    self.thread_local.ugi = ",".join([user] + groups)
+    self.thread_local.request_context.confOptions['hadoop.job.ugi'] = self.thread_local.ugi
+
+  @property
+  def ugi(self):
+    # Here for backwards-compatibility.
+    return self.thread_local.ugi
+
+  @property
+  def request_context(self):
+    # Here for backwards-compatibility.
+    return self.thread_local.request_context
+
 
   def queues(self):
     """
     Returns a ThriftJobQueueList
     """
-    qs = self.client.getQueues(self.request_context)
+    qs = self.client.getQueues(self.thread_local.request_context)
     return qs
 
   def cluster_status(self):
     """
     Returns a ThriftClusterStatus
     """
-    cs = self.client.getClusterStatus(self.request_context)
+    cs = self.client.getClusterStatus(self.thread_local.request_context)
     fixup_enums(cs, {"state":JobTrackerState})
     return cs
 
@@ -158,13 +175,13 @@ class LiveJobTracker(object):
     """
     Returns a RuntimeInfo
     """
-    return self.client.getRuntimeInfo(self.request_context)
+    return self.client.getRuntimeInfo(self.thread_local.request_context)
 
   def all_task_trackers(self):
     """
     Returns a ThriftTaskTrackerStatusList
     """
-    tts = self.client.getAllTrackers(self.request_context)
+    tts = self.client.getAllTrackers(self.thread_local.request_context)
     for tracker in tts.trackers:
       self._fixup_tasktracker(tracker)
     return tts
@@ -173,7 +190,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftTaskTrackerStatusList
     """
-    tts = self.client.getActiveTrackers(self.request_context)
+    tts = self.client.getActiveTrackers(self.thread_local.request_context)
     for tracker in tts.trackers:
       self._fixup_tasktracker(tracker)
     return tts
@@ -182,7 +199,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftTaskTrackerStatusList
     """
-    tts = self.client.getBlacklistedTrackers(self.request_context)
+    tts = self.client.getBlacklistedTrackers(self.thread_local.request_context)
     for tracker in tts.trackers:
       self._fixup_tasktracker(tracker)
     return tts
@@ -191,7 +208,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftTaskTrackerStatus or None
     """
-    tracker = self.client.getTracker(self.request_context, name)
+    tracker = self.client.getTracker(self.thread_local.request_context, name)
     if not tracker:
       return None
     self._fixup_tasktracker(tracker)
@@ -202,7 +219,7 @@ class LiveJobTracker(object):
     Returns a ThriftJobInProgress (including task info)
     """
     try:
-      job = self.client.getJob(self.request_context, jobid)
+      job = self.client.getJob(self.thread_local.request_context, jobid)
     except JobNotFoundException, e:
       e.response_data = dict(code="JT_JOB_NOT_FOUND", message="Could not find job %s on JobTracker." % jobid.asString, data=jobid)
       raise
@@ -213,7 +230,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftJobList (does not include task info)
     """
-    joblist = self.client.getRunningJobs(self.request_context)
+    joblist = self.client.getRunningJobs(self.thread_local.request_context)
     for job in joblist.jobs:
       self._fixup_job(job)
     return joblist
@@ -222,7 +239,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftJobList (does not include task info)
     """
-    joblist = self.client.getCompletedJobs(self.request_context)
+    joblist = self.client.getCompletedJobs(self.thread_local.request_context)
     for job in joblist.jobs:
       self._fixup_job(job)
     return joblist
@@ -231,7 +248,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftJobList (does not include task info)
     """
-    joblist = self.client.getFailedJobs(self.request_context)
+    joblist = self.client.getFailedJobs(self.thread_local.request_context)
     for job in joblist.jobs:
       self._fixup_job(job)
     return joblist
@@ -249,7 +266,7 @@ class LiveJobTracker(object):
     """
     Returns a ThriftJobList (does not include task info)
     """
-    joblist = self.client.getAllJobs(self.request_context)
+    joblist = self.client.getAllJobs(self.thread_local.request_context)
     for job in joblist.jobs:
       self._fixup_job(job)
     return joblist
@@ -258,19 +275,19 @@ class LiveJobTracker(object):
     """
     Returns a ThriftUserJobCounts.
     """
-    return self.client.getUserJobCounts(self.request_context, user)
+    return self.client.getUserJobCounts(self.thread_local.request_context, user)
 
   def get_job_counters(self, jobid):
     """
     Returns a ThriftGroupList
     """
-    return self.client.getJobCounters(self.request_context, jobid)
+    return self.client.getJobCounters(self.thread_local.request_context, jobid)
 
   def get_job_counter_rollups(self, jobid):
     """
     Returns a ThriftGroupList
     """
-    return self.client.getJobCounterRollups(self.request_context, jobid)
+    return self.client.getJobCounterRollups(self.thread_local.request_context, jobid)
 
 
   def get_task_list(self, jobid, task_types, task_states, task_text, count, offset):
@@ -281,7 +298,7 @@ class LiveJobTracker(object):
     ttask_types = [ ThriftTaskType._NAMES_TO_VALUES[x.upper()] for x in task_types ]
     ttask_states = [ ThriftTaskQueryState._NAMES_TO_VALUES[x.upper()] for x in task_states ]
     tip_list = self.client.getTaskList(
-          self.request_context, jobid, ttask_types, ttask_states, task_text, count, offset)
+          self.thread_local.request_context, jobid, ttask_types, ttask_states, task_text, count, offset)
 
     for tip in tip_list.tasks:
       self._fixup_task_in_progress(tip)
@@ -308,28 +325,28 @@ class LiveJobTracker(object):
     """
     Returns an integer timestamp
     """
-    return self.client.getCurrentTime(self.request_context)
+    return self.client.getCurrentTime(self.thread_local.request_context)
 
   def get_job_xml(self, jobid):
     """
     Returns a string representation of the job XML
     """
-    return self.client.getJobConfXML(self.request_context, jobid)
+    return self.client.getJobConfXML(self.thread_local.request_context, jobid)
 
   def kill_job(self, jobid):
     """
     Kill a job
     """
-    return self.client.killJob(self.request_context, jobid)
+    return self.client.killJob(self.thread_local.request_context, jobid)
 
   def kill_task_attempt(self, attemptid):
     """
     Kill a task attempt
     """
-    return self.client.killTaskAttempt(self.request_context, attemptid)
+    return self.client.killTaskAttempt(self.thread_local.request_context, attemptid)
 
   def set_job_priority(self, jobid, priority):
     """
     Set a job's priority
     """
-    return self.client.setJobPriority(self.request_context, jobid, priority)
+    return self.client.setJobPriority(self.thread_local.request_context, jobid, priority)