瀏覽代碼

[core] Do not re-add default group to users at login time

Do it only if the user does not have any groups.
This way, admins can remove some users from the default group if
then assign them to some other groups.
Romain Rigaux 10 年之前
父節點
當前提交
b4ccb8f

+ 2 - 1
apps/beeswax/src/beeswax/test_base.py

@@ -334,8 +334,9 @@ class BeeswaxSampleProvider(object):
     cls.db_name = get_db_prefix(name='hive')
     cls.db_name = get_db_prefix(name='hive')
     cls.cluster, shutdown = get_shared_beeswax_server(cls.db_name)
     cls.cluster, shutdown = get_shared_beeswax_server(cls.db_name)
     cls.client = make_logged_in_client(username='test', is_superuser=False)
     cls.client = make_logged_in_client(username='test', is_superuser=False)
+    add_to_group()
     add_to_group('test')
     add_to_group('test')
-    grant_access("test", "test", "beeswax")
+    grant_access('test', 'test', 'beeswax')
     # Weird redirection to avoid binding nonsense.
     # Weird redirection to avoid binding nonsense.
     cls.shutdown = [ shutdown ]
     cls.shutdown = [ shutdown ]
     cls.init_beeswax_db()
     cls.init_beeswax_db()

+ 1 - 0
apps/beeswax/src/beeswax/tests.py

@@ -111,6 +111,7 @@ class TestBeeswaxWithHadoop(BeeswaxSampleProvider):
 
 
   def setUp(self):
   def setUp(self):
     self.user = User.objects.get(username='test')
     self.user = User.objects.get(username='test')
+    add_to_group()
     add_to_group('test')
     add_to_group('test')
     self.db = dbms.get(self.user, get_query_server_config())
     self.db = dbms.get(self.user, get_query_server_config())
     self.cluster.fs.do_as_user('test', self.cluster.fs.create_home_dir, '/user/test')
     self.cluster.fs.do_as_user('test', self.cluster.fs.create_home_dir, '/user/test')

+ 19 - 27
desktop/core/src/desktop/auth/backend.py

@@ -123,6 +123,14 @@ def find_or_create_user(username, password=None):
     user = create_user(username, password)
     user = create_user(username, password)
   return user
   return user
 
 
+def ensure_has_a_group(user):
+  default_group = get_default_user_group()
+
+  if not user.groups.exists() and default_group is not None:
+    user.groups.add(default_group)
+    user.save()
+
+
 class DesktopBackendBase(object):
 class DesktopBackendBase(object):
   """
   """
   Abstract base class for providing external authentication schemes.
   Abstract base class for providing external authentication schemes.
@@ -173,10 +181,7 @@ class AllowFirstUserDjangoBackend(django.contrib.auth.backends.ModelBackend):
       userprofile.first_login = False
       userprofile.first_login = False
       userprofile.save()
       userprofile.save()
 
 
-      default_group = get_default_user_group()
-      if default_group is not None:
-        user.groups.add(default_group)
-        user.save()
+      ensure_has_a_group(user)
 
 
       return user
       return user
 
 
@@ -214,9 +219,7 @@ class OAuthBackend(DesktopBackendBase):
     user.is_superuser = False
     user.is_superuser = False
     user.save()
     user.save()
 
 
-    default_group = get_default_user_group()
-    if default_group is not None:
-      user.groups.add(default_group)
+    ensure_has_a_group(user)
 
 
     return user
     return user
 
 
@@ -239,10 +242,9 @@ class AllowAllBackend(DesktopBackendBase):
       user = create_user(username, password)
       user = create_user(username, password)
       user.is_superuser = False
       user.is_superuser = False
       user.save()
       user.save()
-      
-    default_group = get_default_user_group()
-    if default_group is not None:
-      user.groups.add(default_group)
+
+    ensure_has_a_group(user)
+
     return user
     return user
 
 
   @classmethod
   @classmethod
@@ -264,9 +266,8 @@ class DemoBackend(django.contrib.auth.backends.ModelBackend):
 
 
       user.is_superuser = False
       user.is_superuser = False
       user.save()
       user.save()
-      default_group = get_default_user_group()
-      if default_group is not None:
-        user.groups.add(default_group)
+
+      ensure_has_a_group(user)
 
 
     user = rewrite_user(user)
     user = rewrite_user(user)
 
 
@@ -308,9 +309,7 @@ class PamBackend(DesktopBackendBase):
           profile.save()
           profile.save()
           user.is_superuser = is_super
           user.is_superuser = is_super
 
 
-          default_group = get_default_user_group()
-          if default_group is not None:
-            user.groups.add(default_group)
+          ensure_has_a_group(user)
 
 
           user.save()
           user.save()
 
 
@@ -445,10 +444,7 @@ class LdapBackend(object):
       user.is_superuser = is_super
       user.is_superuser = is_super
       user = rewrite_user(user)
       user = rewrite_user(user)
 
 
-      default_group = get_default_user_group()
-      if default_group is not None:
-        user.groups.add(default_group)
-        user.save()
+      ensure_has_a_group(user)
 
 
       if desktop.conf.LDAP.SYNC_GROUPS_ON_LOGIN.get():
       if desktop.conf.LDAP.SYNC_GROUPS_ON_LOGIN.get():
         self.import_groups(server, user)
         self.import_groups(server, user)
@@ -497,9 +493,7 @@ class SpnegoDjangoBackend(django.contrib.auth.backends.ModelBackend):
         profile.save()
         profile.save()
         user.is_superuser = is_super
         user.is_superuser = is_super
 
 
-        default_group = get_default_user_group()
-        if default_group is not None:
-          user.groups.add(default_group)
+        ensure_has_a_group(user)
 
 
         user.save()
         user.save()
 
 
@@ -542,9 +536,7 @@ class RemoteUserDjangoBackend(django.contrib.auth.backends.RemoteUserBackend):
         profile.save()
         profile.save()
         user.is_superuser = is_super
         user.is_superuser = is_super
 
 
-        default_group = get_default_user_group()
-        if default_group is not None:
-          user.groups.add(default_group)
+        ensure_has_a_group(user)
 
 
         user.save()
         user.save()
 
 

+ 89 - 1
desktop/core/src/desktop/auth/views_test.py

@@ -18,17 +18,19 @@
 from nose.tools import assert_true, assert_false, assert_equal
 from nose.tools import assert_true, assert_false, assert_equal
 
 
 from django.conf import settings
 from django.conf import settings
-from django.contrib.auth.models import User
+from django.contrib.auth.models import User, Group
 from django.test.client import Client
 from django.test.client import Client
 
 
 from desktop import conf, middleware
 from desktop import conf, middleware
 from desktop.auth import backend
 from desktop.auth import backend
 from django_auth_ldap import backend as django_auth_ldap_backend
 from django_auth_ldap import backend as django_auth_ldap_backend
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.django_test_util import make_logged_in_client
+from desktop.lib.test_utils import add_to_group
 from hadoop.test_base import PseudoHdfsTestBase
 from hadoop.test_base import PseudoHdfsTestBase
 from hadoop import pseudo_hdfs4
 from hadoop import pseudo_hdfs4
 
 
 from useradmin import ldap_access
 from useradmin import ldap_access
+from useradmin.models import get_default_user_group
 from useradmin.tests import LdapTestConnection
 from useradmin.tests import LdapTestConnection
 from useradmin.views import import_ldap_groups
 from useradmin.views import import_ldap_groups
 
 
@@ -163,6 +165,47 @@ class TestLdapLogin(PseudoHdfsTestBase):
     assert_equal(200, response.status_code, "Expected ok status.")
     assert_equal(200, response.status_code, "Expected ok status.")
     assert_false(response.context['first_login_ever'])
     assert_false(response.context['first_login_ever'])
 
 
+  def test_login_does_not_reset_groups(self):
+    client = make_logged_in_client(username=self.test_username, password="test")
+
+    user = User.objects.get(username=self.test_username)
+    test_group, created = Group.objects.get_or_create(name=self.test_username)
+    default_group = get_default_user_group()
+
+    user.groups.all().delete()
+    assert_false(user.groups.exists())
+
+    # No groups
+    response = client.post('/accounts/login/', dict(username=self.test_username, password="test"), follow=True)
+    assert_equal(200, response.status_code, "Expected ok status.")
+    assert_equal([default_group.name], list(user.groups.values_list('name', flat=True)))
+
+    add_to_group(self.test_username, self.test_username)
+
+    # Two groups
+    client.get('/accounts/logout')
+    response = client.post('/accounts/login/', dict(username=self.test_username, password="test"), follow=True)
+    assert_equal(200, response.status_code, "Expected ok status.")
+    assert_equal(set([default_group.name, test_group.name]), set(user.groups.values_list('name', flat=True)))
+
+    user.groups.filter(name=default_group.name).delete()
+    assert_equal(set([test_group.name]), set(user.groups.values_list('name', flat=True)))
+
+    # Keep manual group only, don't re-add default group
+    client.get('/accounts/logout')
+    response = client.post('/accounts/login/', dict(username=self.test_username, password="test"), follow=True)
+    assert_equal(200, response.status_code, "Expected ok status.")
+    assert_equal([test_group.name], list(user.groups.values_list('name', flat=True)))
+
+    user.groups.remove(test_group)
+    assert_false(user.groups.exists())
+
+    # Re-add default group
+    client.get('/accounts/logout')
+    response = client.post('/accounts/login/', dict(username=self.test_username, password="test"), follow=True)
+    assert_equal(200, response.status_code, "Expected ok status.")
+    assert_equal([default_group.name], list(user.groups.values_list('name', flat=True)))
+
   def test_login_home_creation_failure(self):
   def test_login_home_creation_failure(self):
     response = self.c.get('/accounts/login/')
     response = self.c.get('/accounts/login/')
     assert_equal(200, response.status_code, "Expected ok status.")
     assert_equal(200, response.status_code, "Expected ok status.")
@@ -556,6 +599,51 @@ class TestLogin(PseudoHdfsTestBase):
     assert_equal(200, response.status_code, "Expected unauthorized status.")
     assert_equal(200, response.status_code, "Expected unauthorized status.")
 
 
 
 
+class TestLoginNoHadoop(object):
+
+  reset = []
+  test_username = "test_login_no_hadoop"
+
+  @classmethod
+  def setup_class(cls):
+    # Simulate first login ever
+    User.objects.all().delete()
+
+    cls.auth_backends = settings.AUTHENTICATION_BACKENDS
+    settings.AUTHENTICATION_BACKENDS = ('desktop.auth.backend.AllowFirstUserDjangoBackend',)
+
+  @classmethod
+  def teardown_class(cls):
+    settings.AUTHENTICATION_BACKENDS = cls.auth_backends
+
+  def setUp(self):
+    self.c = Client()
+
+    self.reset.append( conf.AUTH.BACKEND.set_for_testing(['desktop.auth.backend.AllowFirstUserDjangoBackend']) )
+
+  def tearDown(self):
+    for finish in self.reset:
+      finish()
+
+    User.objects.all().delete()
+    if Group.objects.filter(name=self.test_username).exists():
+      Group.objects.filter(name=self.test_username).delete()
+
+  def test_login_does_not_reset_groups(self):
+    self.reset.append( conf.AUTH.BACKEND.set_for_testing(["desktop.auth.backend.AllowFirstUserDjangoBackend"]) )
+
+    client = make_logged_in_client(username=self.test_username, password="test")
+    client.get('/accounts/logout')
+    user = User.objects.get(username=self.test_username)
+    group, created = Group.objects.get_or_create(name=self.test_username)
+
+    user.groups.all().delete()
+    assert_false(user.groups.exists())
+
+    response = client.post('/accounts/login/', dict(username=self.test_username, password="test"), follow=True)
+    assert_equal(200, response.status_code, "Expected ok status.")
+
+
 class MockLdapBackend(object):
 class MockLdapBackend(object):
   settings = django_auth_ldap_backend.LDAPSettings()
   settings = django_auth_ldap_backend.LDAPSettings()