
[sparksql] Improve complex data type result parsing and add UTs (#3422)

Harsh Gupta 2 years ago
commit 481b1d3a1a

+ 25 - 8
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -303,22 +303,39 @@ class SparkApi(Api):
 
 
   def _handle_result_data(self, result, is_complex_type=False):
-    data = []
+    """
+    Parse the data from the 'result' dict based on whether it has complex datatypes or not.
+
+    If the 'is_complex_type' flag is True, it parses the result dict, checking each element for 'schema'
+    and 'values' keys and, if found, formatting them into an appropriate result data dict representing
+    that result column. If the flag is False, it simply returns the 'data' as is.
+
+    Args:
+      result (dict): A dict containing the query result data from Livy to be parsed.
+      is_complex_type (bool, optional): A flag indicating whether the data has complex datatypes.
 
+    Returns:
+      list: A list of result data where each element represents a result row and each result row contains formatted columns.
+    """
+    data = []
     if is_complex_type:
+      # If the query result contains complex datatypes, format each row.
       for row in result['data']:
         row_data = []
-        for ele in row:
-          if isinstance(ele, dict):
-            row_schema = []
-            for val in ele['schema']:
-              row_schema.append(val['name'])
-            row_data.append(dict(zip(row_schema, ele['values'])))
+        for element in row:
+          if isinstance(element, dict) and 'schema' in element and 'values' in element:
+            # Extract the row_schema from the 'schema' dict.
+            row_schema = [val['name'] for val in element['schema']]
+
+            # Zip row_schema with 'values' into a dict and append it as a formatted result column.
+            row_data.append(dict(zip(row_schema, element['values'])))
           else:
-            row_data.append(ele)
+            # If the element is not a struct-like dict, add it to row_data as is.
+            row_data.append(element)
 
         data.append(row_data)
     else:
+      # If the query result has no complex datatypes, return the 'data' as is.
       data = result['data']
     
     return data
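
For illustration, a minimal standalone sketch of the struct-column handling above. The payload shape is an assumption modeled on the Livy response the method expects, trimmed to the fields the parser actually reads:

    # Hypothetical struct column, reduced to the 'schema'/'values' fields the parser uses.
    struct_column = {
      'schema': [{'name': 'city'}, {'name': 'State'}],
      'values': ['Toronto', 'ON'],
    }

    # Same zip-based formatting as _handle_result_data.
    row_schema = [field['name'] for field in struct_column['schema']]
    print(dict(zip(row_schema, struct_column['values'])))  # {'city': 'Toronto', 'State': 'ON'}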

+ 39 - 0
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -176,6 +176,45 @@ class TestSparkApi(object):
         assert_raises(Exception, self.api.execute, notebook, snippet)
 
 
+  def test_handle_result_data(self):
+    # When result data has no complex type.
+    data = {
+      'data': [[1, 'Test']]
+    }
+    processed_data = self.api._handle_result_data(data, is_complex_type=False)
+    assert_equal(processed_data, [[1, 'Test']])
+
+    # When result data has struct complex type with 'schema' and 'values'.
+    data = {
+      'data': [[
+        1, 'Test',
+        {
+          'schema': [
+            {'dataType': {}, 'metadata': {'map': {}}, 'name': 'city', 'nullable': True},
+            {'dataType': {}, 'metadata': {'map': {}}, 'name': 'State', 'nullable': True}
+          ],
+          'values': ['Toronto', 'ON']
+        }
+      ]]
+    }
+
+    processed_data = self.api._handle_result_data(data, is_complex_type=True)
+    assert_equal(processed_data, [[1, 'Test', {'State': 'ON', 'city': 'Toronto'}]])
+
+    # When result data has map complex type.
+    data = {
+      'data': [['0', 535.0, {'site_id': 'BEB'}, {'c_id': 'EF'}, '2023-06-16T23:53:31Z']]
+    }
+
+    processed_data = self.api._handle_result_data(data, is_complex_type=True)
+    assert_equal(processed_data, [['0', 535.0, {'site_id': 'BEB'}, {'c_id': 'EF'}, '2023-06-16T23:53:31Z']])
+
+
   def test_check_status(self):
     with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
       with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
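
The map-type test above passes the column through unchanged because a map arrives as a plain dict without 'schema' and 'values' keys, so the struct branch is skipped. A quick sketch of that guard in isolation (same logic as the connector, extracted here only for illustration):

    # Hypothetical helper mirroring the connector's struct-vs-passthrough check.
    def is_struct_column(element):
      return isinstance(element, dict) and 'schema' in element and 'values' in element

    print(is_struct_column({'schema': [], 'values': []}))   # True  -> reformatted via zip
    print(is_struct_column({'site_id': 'BEB'}))             # False -> map kept as is
    print(is_struct_column('2023-06-16T23:53:31Z'))         # False -> scalar kept as is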