
[Editor] Preserve precision of numeric types in results (#4286)

Ayush Goyal 1 month ago
parent commit 94e566472e

+ 18 - 16
desktop/libs/notebook/src/notebook/models.py

@@ -15,12 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import json
-import math
-import uuid
 import logging
+import math
 import numbers
-import datetime
+import uuid
 from builtins import object, str
 from datetime import timedelta
 from urllib.parse import quote as urllib_quote
@@ -37,8 +37,8 @@ from desktop.lib.i18n import smart_str
 from desktop.lib.paths import SAFE_CHARACTERS_URI
 from desktop.models import Directory, Document2
 from notebook.conf import EXAMPLES, get_ordered_interpreters
-from notebook.connectors.base import Notebook, get_api as _get_api, get_interpreter
-from useradmin.models import User, install_sample_user
+from notebook.connectors.base import get_api as _get_api, get_interpreter, Notebook
+from useradmin.models import install_sample_user, User
 
 LOG = logging.getLogger()
 
@@ -52,7 +52,9 @@ def escape_rows(rows, nulls_only=False, encoding=None):
       escaped_row = []
 
       for field in row:
-        if isinstance(field, numbers.Number):
+        if isinstance(field, float):
+          escaped_field = str(field)
+        elif isinstance(field, numbers.Number):
           if math.isnan(field) or math.isinf(field):
             escaped_field = json.dumps(field)
           else:
@@ -341,7 +343,7 @@ def import_saved_mapreduce_job(wf):
     files = json.loads(node.files)
     for filepath in files:
       snippet_properties['files'].append({'type': 'file', 'path': filepath})
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse files for mapreduce job design "%s".' % wf.name)
 
   snippet_properties['archives'] = []
@@ -349,7 +351,7 @@ def import_saved_mapreduce_job(wf):
     archives = json.loads(node.archives)
     for filepath in archives:
       snippet_properties['archives'].append(filepath)
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse archives for mapreduce job design "%s".' % wf.name)
 
   snippet_properties['hadoopProperties'] = []
@@ -358,7 +360,7 @@ def import_saved_mapreduce_job(wf):
     if properties:
       for prop in properties:
         snippet_properties['hadoopProperties'].append("%s=%s" % (prop.get('name'), prop.get('value')))
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse job properties for mapreduce job design "%s".' % wf.name)
 
   snippet_properties['app_jar'] = node.jar_path
@@ -398,7 +400,7 @@ def import_saved_shell_job(wf):
           snippet_properties['arguments'].append(param['value'])
         else:
           snippet_properties['env_var'].append(param['value'])
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse parameters for shell job design "%s".' % wf.name)
 
   snippet_properties['hadoopProperties'] = []
@@ -407,7 +409,7 @@ def import_saved_shell_job(wf):
     if properties:
       for prop in properties:
         snippet_properties['hadoopProperties'].append("%s=%s" % (prop.get('name'), prop.get('value')))
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse job properties for shell job design "%s".' % wf.name)
 
   snippet_properties['files'] = []
@@ -415,7 +417,7 @@ def import_saved_shell_job(wf):
     files = json.loads(node.files)
     for filepath in files:
       snippet_properties['files'].append({'type': 'file', 'path': filepath})
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse files for shell job design "%s".' % wf.name)
 
   snippet_properties['archives'] = []
@@ -423,7 +425,7 @@ def import_saved_shell_job(wf):
     archives = json.loads(node.archives)
     for archive in archives:
       snippet_properties['archives'].append(archive['name'])
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse archives for shell job design "%s".' % wf.name)
 
   snippet_properties['capture_output'] = node.capture_output
@@ -462,7 +464,7 @@ def import_saved_java_job(wf):
     if properties:
       for prop in properties:
         snippet_properties['hadoopProperties'].append("%s=%s" % (prop.get('name'), prop.get('value')))
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse job properties for Java job design "%s".' % wf.name)
 
   snippet_properties['files'] = []
@@ -470,7 +472,7 @@ def import_saved_java_job(wf):
     files = json.loads(node.files)
     for filepath in files:
       snippet_properties['files'].append({'type': 'file', 'path': filepath})
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse files for Java job design "%s".' % wf.name)
 
   snippet_properties['archives'] = []
@@ -478,7 +480,7 @@ def import_saved_java_job(wf):
     archives = json.loads(node.archives)
     for archive in archives:
       snippet_properties['archives'].append(archive['name'])
-  except ValueError as e:
+  except ValueError:
     LOG.warning('Failed to parse archives for Java job design "%s".' % wf.name)
 
   snippet_properties['capture_output'] = node.capture_output
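
The substantive change is in escape_rows(): floats are now converted to strings before the row is serialized, which is what preserves the value's textual precision in the results (the other hunks are import reordering and dropping unused "as e" bindings). The following is a minimal standalone sketch of the new per-field numeric handling; the escape_numeric_field helper name is hypothetical, since in models.py this logic lives inline inside escape_rows().

import json
import math
import numbers

def escape_numeric_field(field):
    # Hypothetical helper for illustration; mirrors the numeric branch of
    # escape_rows() after this commit.
    if isinstance(field, float):
        # New in this commit: floats (including NaN/inf, which now take this
        # branch) are stringified so their textual precision is preserved.
        return str(field)
    elif isinstance(field, numbers.Number):
        # Non-float numbers keep the previous handling.
        if math.isnan(field) or math.isinf(field):
            return json.dumps(field)
        return field
    return field

print(escape_numeric_field(30.67))   # '30.67' (string, precision kept)
print(escape_numeric_field(100))     # 100 (int passes through unchanged)
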

+ 31 - 4
desktop/libs/notebook/src/notebook/models_tests.py

@@ -16,22 +16,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-import json
 import logging
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import patch
 
 import pytest
 
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.models import Document2
 from notebook.conf import EXAMPLES
-from notebook.models import Analytics, install_custom_examples
+from notebook.models import Analytics, escape_rows, install_custom_examples
 from useradmin.models import User
 
 LOG = logging.getLogger()
 
 
+class TestEscapeRows:
+
+  def test_escape_rows_precision(self):
+    # Test data containing various types, including float
+    test_data = [
+      [1, 'Alice', 29.0],
+      [2, 'Bob', 30.67],
+      [3, 'Charlie', 25.5],
+      [4, 'David', 40.05],
+      [5, None, 29.10],
+      [6, 'Eve', 100]
+    ]
+
+    # Expected result after escaping
+    expected_result = [
+      [1, 'Alice', '29.0'],
+      [2, 'Bob', '30.67'],
+      [3, 'Charlie', '25.5'],
+      [4, 'David', '40.05'],
+      [5, 'NULL', '29.1'],
+      [6, 'Eve', 100]
+    ]
+
+    result = escape_rows(test_data)
+
+    # Assert that the result matches the expected output
+    assert result == expected_result
+
+
 @pytest.mark.django_db
 class TestAnalytics(object):
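
For a quick manual check outside the test suite, the helper can be called directly with rows like those in the new TestEscapeRows case. This assumes a configured Hue development environment where notebook.models is importable:

from notebook.models import escape_rows

rows = [[5, None, 29.10], [6, 'Eve', 100]]

# Per the new test expectations: floats come back as strings ('29.1' --
# str() drops the trailing zero), None becomes 'NULL', and ints pass
# through unchanged.
print(escape_rows(rows))   # [[5, 'NULL', '29.1'], [6, 'Eve', 100]]
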