Ver código fonte

[api] Make sure request user is wrapped with UserProfile

Romain Rigaux 4 anos atrás
pai
commit
264ba76c67

+ 22 - 17
desktop/core/src/desktop/api_public.py

@@ -27,36 +27,33 @@ from desktop.auth.backend import rewrite_user
 
 @api_view(["POST"])
 def get_config(request):
-  django_request = request._request
-
-  django_request.user = rewrite_user(django_request.user)
-
+  django_request = get_django_request(request)
   return desktop_api.get_config(django_request)
 
 @api_view(["GET"])
 def get_context_namespaces(request, interface):
-  django_request = request._request
+  django_request = get_django_request(request)
   return desktop_api.get_context_namespaces(django_request, interface)
 
 
 @api_view(["POST"])
 def create_notebook(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.create_notebook(django_request)
 
 @api_view(["POST"])
 def create_session(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.create_session(django_request)
 
 @api_view(["POST"])
 def close_session(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.close_session(django_request)
 
 @api_view(["POST"])
 def execute(request, dialect=None):
-  django_request = request._request
+  django_request = get_django_request(request)
 
   if not request.POST.get('notebook'):
     interpreter = _get_interpreter_from_dialect(dialect=dialect, user=django_request.user)
@@ -79,7 +76,7 @@ def execute(request, dialect=None):
 
 @api_view(["POST"])
 def check_status(request):
-  django_request = request._request
+  django_request = get_django_request(request)
 
   _patch_operation_id_request(django_request)
 
@@ -87,7 +84,7 @@ def check_status(request):
 
 @api_view(["POST"])
 def fetch_result_data(request):
-  django_request = request._request
+  django_request = get_django_request(request)
 
   _patch_operation_id_request(django_request)
 
@@ -95,27 +92,27 @@ def fetch_result_data(request):
 
 @api_view(["POST"])
 def fetch_result_metadata(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.fetch_result_metadata(django_request)
 
 @api_view(["POST"])
 def fetch_result_size(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.fetch_result_size(django_request)
 
 @api_view(["POST"])
 def cancel_statement(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.cancel_statement(django_request)
 
 @api_view(["POST"])
 def close_statement(request):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.close_statement(django_request)
 
 @api_view(["POST"])
 def get_logs(request):
-  django_request = request._request
+  django_request = get_django_request(request)
 
   _patch_operation_id_request(django_request)
 
@@ -124,7 +121,7 @@ def get_logs(request):
 
 @api_view(["POST"])
 def autocomplete(request, server=None, database=None, table=None, column=None, nested=None):
-  django_request = request._request
+  django_request = get_django_request(request)
   return notebook_api.autocomplete(django_request, server, database, table, column, nested)
 
 
@@ -151,3 +148,11 @@ def _patch_operation_id_request(django_request):
 
     django_request.POST = QueryDict(mutable=True)
     django_request.POST.update(data)
+
+
+def get_django_request(request):
+  django_request = request._request
+
+  django_request.user = rewrite_user(django_request.user)
+
+  return django_request

+ 18 - 0
desktop/core/src/desktop/api_public_tests.py

@@ -16,14 +16,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
+
 from nose.tools import assert_true, assert_false, assert_equal, assert_not_equal, assert_raises
 from django.urls import reverse
 
+from desktop.api_public import get_django_request
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.test_utils import grant_access
 
 from useradmin.models import User
 
+if sys.version_info[0] > 2:
+  from unittest.mock import patch, Mock, MagicMock
+else:
+  from mock import patch, Mock, MagicMock
+
 
 class TestEditorApi():
 
@@ -39,3 +47,13 @@ class TestEditorApi():
 
   def test_urls_exist(self):
     assert_equal(reverse('api:editor_execute', args=['hive']), '/api/editor/execute/hive')
+
+
+  def test_get_django_request(self):
+    request = Mock()
+
+    django_request = get_django_request(request)
+
+    assert_true(
+      hasattr(django_request.user, 'has_hue_permission')
+    )