
[spark] Fix merging of custom configuration for connectors with defaults (#1125)

* fix(spark_shell): merge custom config with default

* refactoring

* add test for spark_shell.create_session

* refactor

* rename get_props to get_livy_props

* Fix

* Fix

* Add test for user_config is None

Co-authored-by: Ilya Makarov <makarov_ia@nlmk.com>
Co-authored-by: Romain Rigaux <romain.rigaux@gmail.com>
e11it committed 5 years ago
commit b12709aec8

+ 38 - 38
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -151,21 +151,10 @@ class SparkApi(Api):
   STANDALONE_JOB_RE = re.compile("Got job (\d+)")
 
   @staticmethod
-  def get_properties():
-    return SparkConfiguration.PROPERTIES
-
-  def create_session(self, lang='scala', properties=None):
-    if not properties:
-      config = None
-      if USE_DEFAULT_CONFIGURATION.get():
-        config = DefaultConfiguration.objects.get_configuration_for_user(app='spark', user=self.user)
-
-      if config is not None:
-        properties = config.properties_list
-      else:
-        properties = self.get_properties()
-
-    props = dict([(p['name'], p['value']) for p in properties]) if properties is not None else {}
+  def get_livy_props(lang, properties=None):
+    props = dict([(p['name'], p['value']) for p in SparkConfiguration.PROPERTIES])
+    if properties is not None:
+      props.update(dict([(p['name'], p['value']) for p in properties]))
 
     # HUE-4761: Hue's session request is causing Livy to fail with "JsonMappingException: Can not deserialize
     # instance of scala.collection.immutable.List out of VALUE_STRING token" due to List type values
@@ -176,35 +165,46 @@ class SparkApi(Api):
     # empty list '[]' for these four values.
     # Note also that Livy has a 90 second timeout for the session request to complete, this needs to
     # be increased for requests that take longer, for example when loading large archives.
-    tmp_archives = props['archives']
-    if type(tmp_archives) is not list:
-      props['archives'] = tmp_archives.split(",")
-      LOG.debug("Check List type: archives was not a list")
-
-    tmp_jars = props['jars']
-    if type(tmp_jars) is not list:
-      props['jars'] = tmp_jars.split(",")
-      LOG.debug("Check List type: jars was not a list")
-
-    tmp_files = props['files']
-    if type(tmp_files) is not list:
-      props['files'] = tmp_files.split(",")
-      LOG.debug("Check List type: files was not a list")
-
-    tmp_py_files = props['pyFiles']
-    if type(tmp_py_files) is not list:
-      props['pyFiles'] = tmp_py_files.split(",")
-      LOG.debug("Check List type: pyFiles was not a list")
-
+    for key in ['archives', 'jars', 'files', 'pyFiles']:
+      if key not in props:
+        continue
+      if type(props[key]) is list:
+        continue
+      LOG.debug("Check List type: {} was not a list".format(key))
+      props[key] = props[key].split(",")
+
     # Convert the conf list to a dict for Livy
     LOG.debug("Property Spark Conf kvp list from UI is: " + str(props['conf']))
     props['conf'] = {conf.get('key'): conf.get('value') for i, conf in enumerate(props['conf'])}
     LOG.debug("Property Spark Conf dictionary is: " + str(props['conf']))
-
+
     props['kind'] = lang
+
+    return props
 
-    api = get_spark_api(self.user)
+  @staticmethod
+  def to_properties(props=None):
+    properties = list()
+    for p in SparkConfiguration.PROPERTIES:
+      properties.append(p.copy())
+
+    if props is not None:
+      for p in properties:
+        if p['name'] in props:
+          p['value'] = props[p['name']]
 
+    return properties
+
+  def create_session(self, lang='scala', properties=None):
+    if not properties and USE_DEFAULT_CONFIGURATION.get():
+      user_config = DefaultConfiguration.objects.get_configuration_for_user(app='spark', user=self.user)
+      if user_config is not None:
+        properties = user_config.properties_list
+
+    props = self.get_livy_props(lang, properties)
+
+    api = get_spark_api(self.user)
     response = api.create_session(**props)
 
     status = api.get_session(response['id'])
@@ -222,7 +222,7 @@ class SparkApi(Api):
     return {
         'type': lang,
         'id': response['id'],
-        'properties': properties
+        'properties': self.to_properties(props)
     }
 
   def execute(self, notebook, snippet):
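
In effect, the new get_livy_props() starts from the connector defaults and overlays any user-supplied properties, then normalizes the comma-separated list fields before the session request. A minimal standalone sketch of that merge, with made-up defaults standing in for SparkConfiguration.PROPERTIES:

# Sketch only: DEFAULTS stands in for SparkConfiguration.PROPERTIES.
DEFAULTS = [
  {'name': 'driverCores', 'value': 1},
  {'name': 'files', 'value': []},
]

def merge_livy_props(lang, properties=None):
  # Defaults first; user-supplied values win on conflict
  props = dict((p['name'], p['value']) for p in DEFAULTS)
  if properties is not None:
    props.update((p['name'], p['value']) for p in properties)
  # Livy expects real lists for these keys, not comma-separated strings
  for key in ['archives', 'jars', 'files', 'pyFiles']:
    if key in props and not isinstance(props[key], list):
      props[key] = props[key].split(",")
  props['kind'] = lang
  return props

# merge_livy_props('scala', [{'name': 'files', 'value': 'a.py,b.py'}])
# -> {'driverCores': 1, 'files': ['a.py', 'b.py'], 'kind': 'scala'}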

+ 60 - 1
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -34,6 +34,66 @@ class TestSparkApi(object):
     self.user = 'hue_test'
     self.api = SparkApi(self.user)
 
+  def test_get_livy_props_method(self):
+    test_properties = [{
+        "name": "files",
+        "value": 'file_a,file_b,file_c',
+      }]
+    props = self.api.get_livy_props('scala', test_properties)
+    assert_equal(props['files'], ['file_a', 'file_b', 'file_c'])
+
+  def test_create_session_with_config(self):
+    lang = 'pyspark'
+    properties = None
+
+    with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
+      with patch('notebook.connectors.spark_shell.DefaultConfiguration') as DefaultConfiguration:
+        with patch('notebook.connectors.spark_shell.USE_DEFAULT_CONFIGURATION') as USE_DEFAULT_CONFIGURATION:
+          DefaultConfiguration.objects.get_configuration_for_user.return_value = Mock(
+                properties_list=[
+                  {'multiple': False, 'name': 'driverCores', 'defaultValue': 1, 'value': 2, 'nice_name': 'Driver Cores',
+                   'help_text': 'Number of cores used by the driver, only in cluster mode (Default: 1)', 'type': 'number',
+                   'is_yarn': True}]
+          )
+
+          get_spark_api.return_value = Mock(
+            create_session=Mock(
+              return_value={'id': '1'}
+            ),
+            get_session=Mock(
+              return_value={'state': 'idle', 'log': ''}
+            )
+          )
+          # Case with user configuration. Expected 2 driverCores
+          USE_DEFAULT_CONFIGURATION.get.return_value = True
+          session = self.api.create_session(lang=lang, properties=properties)
+          assert_equal(session['type'], 'pyspark')
+          assert_equal(session['id'], '1')
+          for p in session['properties']:
+            if p['name'] == 'driverCores':
+              cores = p['value']
+          assert_equal(cores, 2)
+
+          # Case where default configuration is enabled but the user has none. Expect 1 driverCore
+          USE_DEFAULT_CONFIGURATION.get.return_value = True
+          DefaultConfiguration.objects.get_configuration_for_user.return_value = None
+          session2 = self.api.create_session(lang=lang, properties=properties)
+          assert_equal(session2['type'], 'pyspark')
+          assert_equal(session2['id'], '1')
+          for p in session2['properties']:
+            if p['name'] == 'driverCores':
+              cores = p['value']
+          assert_equal(cores, 1)
+
+          # Case where USE_DEFAULT_CONFIGURATION is disabled. Expect 1 driverCore
+          USE_DEFAULT_CONFIGURATION.get.return_value = False
+          session3 = self.api.create_session(lang=lang, properties=properties)
+          assert_equal(session3['type'], 'pyspark')
+          assert_equal(session3['id'], '1')
+          for p in session3['properties']:
+            if p['name'] == 'driverCores':
+              cores = p['value']
+          assert_equal(cores, 1)
 
   def test_create_session_plain(self):
     lang = 'pyspark'
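
The driverCores assertions above hold because create_session() now returns self.to_properties(props), which maps the merged Livy props back onto the connector's property list. A hedged sketch of that rebuild, again with illustrative defaults rather than the real SparkConfiguration.PROPERTIES:

# Sketch only: DEFAULTS stands in for SparkConfiguration.PROPERTIES.
DEFAULTS = [
  {'name': 'driverCores', 'value': 1, 'nice_name': 'Driver Cores'},
  {'name': 'executorCores', 'value': 1, 'nice_name': 'Executor Cores'},
]

def rebuild_properties(props=None):
  # Copy each default entry, then override 'value' where the merged props carry that key
  properties = [p.copy() for p in DEFAULTS]
  if props is not None:
    for p in properties:
      if p['name'] in props:
        p['value'] = props[p['name']]
  return properties

# rebuild_properties({'driverCores': 2, 'kind': 'pyspark'})
# -> the driverCores entry reports value 2; keys not in DEFAULTS (like 'kind') are ignored.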
@@ -58,7 +118,6 @@ class TestSparkApi(object):
       assert_true(files_properties, session['properties'])
       assert_equal(files_properties[0]['value'], [], session['properties'])
 
-
   def test_get_jobs(self):
     local_jobs = [
       {'url': u'http://172.21.1.246:4040/jobs/job/?id=0', 'name': u'0'}