Bladeren bron

[flink] Fix RecursionError in create_session method (#4085)

Grzegorz Kołakowski 7 maanden geleden
bovenliggende
commit
c854354e6f

+ 20 - 19
desktop/libs/notebook/src/notebook/connectors/flink_sql.py

@@ -79,6 +79,7 @@ class FlinkSqlApi(Api):
 
   @query_error_handler
   def create_session(self, lang=None, properties=None):
+    LOG.info("Creating session for %s.", lang)
     session = self._get_session()
 
     response = {
@@ -162,26 +163,30 @@ class FlinkSqlApi(Api):
     session = self._get_session_info_from_user()
 
     if not session:
-      session = self.db.create_session()
-      if self.default_database:
-        self._use_database(self.default_catalog, self.default_database)
-      elif self.default_catalog:
-        self._use_catalog(self.default_catalog)
-
+      session = self._create_session()
     try:
-      self.db.session_heartbeat(session_handle=session['sessionHandle'])
+      self.db.session_heartbeat(session_handle=session['id'])
     except Exception as e:
       if "Session '%(sessionHandle)s' does not exist" % session in str(e):
         LOG.warning('Session %(sessionHandle)s does not exist, opening a new one' % session)
-        session = self.db.create_session()
+        session = self._create_session()
       else:
         raise e
 
-    session['id'] = session['sessionHandle']
     self._set_session_info_to_user(session)
 
     return session
 
+  def _create_session(self):
+    session = self.db.create_session()
+    session['id'] = session['sessionHandle']
+
+    if self.default_database:
+      self._use_database(session, self.default_catalog, self.default_database)
+    elif self.default_catalog:
+      self._use_catalog(session, self.default_catalog)
+    return session
+
   @query_error_handler
   def execute(self, notebook, snippet):
     session = self._get_session()
@@ -461,19 +466,15 @@ class FlinkSqlApi(Api):
 
     return [{'name': function[0]} for function in function_list]
 
-  def _use_catalog(self, catalog):
-    session = self._get_session()
-    self.db.configure_session(session_handle=(session['id']), statement="USE CATALOG `%s`" % catalog)
+  def _use_catalog(self, session, catalog):
+    self.db.configure_session(session['id'], "USE CATALOG `%s`" % catalog)
 
-  def _use_database(self, catalog, database):
-    session = self._get_session()
+  def _use_database(self, session, catalog, database):
     if catalog:
-      self.db.configure_session(session_handle=(session['id']),
-                                statement="USE `%(catalog)s`.`%(database)s`" % {'catalog': catalog,
-                                                                                'database': database})
+      self.db.configure_session(session['id'], "USE `%(catalog)s`.`%(database)s`" % {'catalog': catalog,
+                                                                                     'database': database})
     else:
-      self.db.configure_session(session_handle=(session['id']),
-                                statement="USE `%(database)s`" % {'database': database})
+      self.db.configure_session(session['id'], "USE `%(database)s`" % {'database': database})
 
 
 class FlinkSqlClient:

+ 74 - 0
desktop/libs/notebook/src/notebook/connectors/flink_sql_tests.py

@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, patch
+
+from django.test import TestCase
+
+from desktop.lib.django_test_util import make_logged_in_client
+from notebook.connectors.flink_sql import FlinkSqlApi
+from useradmin.models import User
+
+
+class TestFlinkApi(TestCase):
+  def setup_method(self, test_method):
+    self.client = make_logged_in_client(username="hue_test", groupname="default", recreate=True, is_superuser=False)
+    self.user = User.objects.get(username="hue_test")
+    self.interpreter = {
+      'options': {
+        'url': 'https://example.com:8081',
+      },
+      'name': 'flink',
+    }
+
+  @patch('notebook.connectors.flink_sql.FlinkSqlClient')
+  def test_create_session(self, client_mock):
+    # given: mock interactions
+    mock_client_instance = MagicMock()
+    client_mock.return_value = mock_client_instance
+    mock_client_instance.create_session.return_value = {'sessionHandle': '657c12d4-5509-477f-a460-ea6af927906d'}
+
+    # and: FlinkSqlApi instance
+    flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter)
+
+    # when
+    created_session = flink_api.create_session(lang='flink', properties=None)
+
+    # then
+    assert created_session == {'id': '657c12d4-5509-477f-a460-ea6af927906d', 'type': 'flink'}
+    assert mock_client_instance.session_heartbeat.call_count == 1
+
+  @patch('notebook.connectors.flink_sql.FlinkSqlClient')
+  def test_create_session_with_default_catalog_and_database(self, client_mock):
+    # given: mock interactions
+    mock_client_instance = MagicMock()
+    client_mock.return_value = mock_client_instance
+    mock_client_instance.create_session.return_value = {'sessionHandle': '657c12d4-5509-477f-a460-ea6af927906d'}
+
+    # and: FlinkSqlApi instance with configuration
+    self.interpreter['options']['default_catalog'] = 'default_catalog'
+    self.interpreter['options']['default_database'] = 'default_database'
+    flink_api = FlinkSqlApi(self.user, interpreter=self.interpreter)
+
+    # when
+    created_session = flink_api.create_session(lang='flink', properties=None)
+
+    # then
+    assert created_session == {'id': '657c12d4-5509-477f-a460-ea6af927906d', 'type': 'flink'}
+    mock_client_instance.configure_session.assert_called_once_with(
+      '657c12d4-5509-477f-a460-ea6af927906d', "USE `default_catalog`.`default_database`"
+    )
+    assert mock_client_instance.session_heartbeat.call_count == 1