Browse Source

Improve SAML group check logic, We are now checking one of required groups must be available from the SAML response. (#2182)

(cherry picked from commit 6965660f78eb0647d666a250092b1e05a620d55b)
Prakash Ranade 4 years ago
parent
commit
11ba711bcb
2 changed files with 63 additions and 11 deletions
  1. 58 10
      apps/useradmin/src/useradmin/tests.py
  2. 5 1
      desktop/core/src/desktop/views.py

+ 58 - 10
apps/useradmin/src/useradmin/tests.py

@@ -29,6 +29,7 @@ import urllib.request, urllib.parse, urllib.error
 from nose.plugins.skip import SkipTest
 from nose.plugins.skip import SkipTest
 from nose.tools import assert_true, assert_equal, assert_false, assert_not_equal
 from nose.tools import assert_true, assert_equal, assert_false, assert_not_equal
 from datetime import datetime
 from datetime import datetime
+from django.conf import settings
 from django.contrib.sessions.models import Session
 from django.contrib.sessions.models import Session
 from django.db.models import Q
 from django.db.models import Q
 from django.urls import reverse
 from django.urls import reverse
@@ -43,10 +44,11 @@ from desktop.conf import APP_BLACKLIST, ENABLE_ORGANIZATIONS, ENABLE_PROMETHEUS
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.i18n import smart_unicode
 from desktop.lib.i18n import smart_unicode
 from desktop.lib.test_utils import grant_access
 from desktop.lib.test_utils import grant_access
-from desktop.views import home
+from desktop.views import home, samlgroup_check
 from hadoop import pseudo_hdfs4
 from hadoop import pseudo_hdfs4
 from hadoop.pseudo_hdfs4 import is_live_cluster
 from hadoop.pseudo_hdfs4 import is_live_cluster
 
 
+import libsaml.conf
 import useradmin.conf
 import useradmin.conf
 import useradmin.ldap_access
 import useradmin.ldap_access
 from useradmin.forms import UserChangeForm
 from useradmin.forms import UserChangeForm
@@ -60,6 +62,15 @@ if sys.version_info[0] > 2:
 else:
 else:
   from mock import patch, Mock
   from mock import patch, Mock
 
 
+class MockRequest(dict):
+  pass
+
+class MockUser(dict):
+  def is_authenticated(self):
+    return True
+
+class MockSession(dict):
+  pass
 
 
 def reset_all_users():
 def reset_all_users():
   """Reset to a clean state by deleting all users"""
   """Reset to a clean state by deleting all users"""
@@ -352,6 +363,52 @@ class TestUserProfile(BaseUserAdminTests):
     userprofile = get_profile(user)
     userprofile = get_profile(user)
     assert_equal('es', userprofile.data['language_preference'])
     assert_equal('es', userprofile.data['language_preference'])
 
 
+class TestSAMLGroupsCheck(BaseUserAdminTests):
+  def test_saml_group_conditions_check(self):
+    if sys.version_info[0] > 2:
+      raise SkipTest
+    reset = []
+    old_settings = settings.AUTHENTICATION_BACKENDS
+    try:
+      c = make_logged_in_client(username='test2', password='test2', is_superuser=False, recreate=True)
+      settings.AUTHENTICATION_BACKENDS = ["libsaml.backend.SAML2Backend"]
+      request = MockRequest()
+
+      user = User.objects.get(username='test2')
+      userprofile = get_profile(user)
+      request.user = user
+
+      # In case of no valid saml response from server.
+      reset.append(libsaml.conf.REQUIRED_GROUPS_ATTRIBUTE.set_for_testing("groups"))
+      reset.append(libsaml.conf.REQUIRED_GROUPS.set_for_testing(["ddd"]))
+      assert_false(desktop.views.samlgroup_check(request))
+
+      # mock saml response
+      userprofile.update_data({"saml_attributes":{"first_name":["test2"],
+                                                  "last_name":["test2"],
+                                                  "email":["test2@test.com"],
+                                                  "groups":["aaa","bbb","ccc"]}})
+      userprofile.save()
+
+      # valid one or more valid required groups
+      reset.append(libsaml.conf.REQUIRED_GROUPS_ATTRIBUTE.set_for_testing("groups"))
+      reset.append(libsaml.conf.REQUIRED_GROUPS.set_for_testing(["aaa","ddd"]))
+      assert_true(desktop.views.samlgroup_check(request))
+
+      # invalid required group
+      reset.append(libsaml.conf.REQUIRED_GROUPS_ATTRIBUTE.set_for_testing("groups"))
+      reset.append(libsaml.conf.REQUIRED_GROUPS.set_for_testing(["ddd"]))
+      assert_false(desktop.views.samlgroup_check(request))
+
+      # different samlresponse for group attribute
+      reset.append(libsaml.conf.REQUIRED_GROUPS_ATTRIBUTE.set_for_testing("members"))
+      reset.append(libsaml.conf.REQUIRED_GROUPS.set_for_testing(["ddd"]))
+      assert_false(desktop.views.samlgroup_check(request))
+    finally:
+      settings.AUTHENTICATION_BACKENDS = old_settings
+      for r in reset:
+        r()
+
 
 
 class TestUserAdminMetrics(BaseUserAdminTests):
 class TestUserAdminMetrics(BaseUserAdminTests):
 
 
@@ -1436,15 +1493,6 @@ class LastActivityMiddlewareTests(object):
       for f in reset:
       for f in reset:
         f()
         f()
 
 
-class MockRequest(dict):
-  pass
-
-class MockUser(dict):
-  def is_authenticated(self):
-    return True
-
-class MockSession(dict):
-  pass
 
 
 class ConcurrentUserSessionMiddlewareTests(object):
 class ConcurrentUserSessionMiddlewareTests(object):
   def setUp(self):
   def setUp(self):

+ 5 - 1
desktop/core/src/desktop/views.py

@@ -101,11 +101,15 @@ def samlgroup_check(request):
         LOG.info("Missing %s in SAMLResponse for %s user" % (REQUIRED_GROUPS_ATTRIBUTE.get(), request.user.username))
         LOG.info("Missing %s in SAMLResponse for %s user" % (REQUIRED_GROUPS_ATTRIBUTE.get(), request.user.username))
         return False
         return False
 
 
-      saml_group_found = set(REQUIRED_GROUPS.get()).issubset(
+      # Earlier we had AND condition, It means user has to be there in all given groups.
+      # Now we are doing OR condition, which means user must be in one of the given groups.
+      saml_group_found = set(REQUIRED_GROUPS.get()).intersection(
                          set(json_data['saml_attributes'].get(REQUIRED_GROUPS_ATTRIBUTE.get())))
                          set(json_data['saml_attributes'].get(REQUIRED_GROUPS_ATTRIBUTE.get())))
       if not saml_group_found:
       if not saml_group_found:
         LOG.info("User %s not found in required SAML groups, %s" % (request.user.username, REQUIRED_GROUPS.get()))
         LOG.info("User %s not found in required SAML groups, %s" % (request.user.username, REQUIRED_GROUPS.get()))
         return False
         return False
+
+      LOG.info("User %s found in the required SAML groups %s" % (request.user.username, ",".join(saml_group_found)))
   return True
   return True
 
 
 def hue(request):
 def hue(request):