Browse Source

HUE-311. Hue's conf.py should warn about variables without type= parameters that look like numbers

- Added more type safety checks.
- Force `bool' type to use `coerce_bool'.
- Numeric and boolean default must match the declared type.
- Converted conf_test to use nose style tests.
bc Wong 15 years ago
parent
commit
78918e3574
2 changed files with 108 additions and 77 deletions
  1. 41 17
      desktop/core/src/desktop/lib/conf.py
  2. 67 60
      desktop/core/src/desktop/lib/conf_test.py

+ 41 - 17
desktop/core/src/desktop/lib/conf.py

@@ -15,14 +15,6 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from desktop.lib.paths import get_desktop_root, get_build_dir
-
-import configobj
-import logging
-import os
-import textwrap
-import sys
-
 """
 """
 The application configuration framework. The user of the framework uses
 The application configuration framework. The user of the framework uses
 * Config
 * Config
@@ -69,12 +61,26 @@ application's conf.py. During startup, Desktop binds configuration files to your
 variables.
 variables.
 """
 """
 
 
+# The Config object unfortunately has a kwarg called "type", and everybody is
+# using it. So instead of breaking compatibility, we make a "pytype" alias.
+pytype = type
+
+from desktop.lib.paths import get_desktop_root, get_build_dir
+
+import configobj
+import logging
+import os
+import textwrap
+import sys
+
 # Magical object for use as a "symbol"
 # Magical object for use as a "symbol"
 _ANONYMOUS = ("_ANONYMOUS")
 _ANONYMOUS = ("_ANONYMOUS")
 
 
 # a BoundContainer(BoundConfig) object which has all of the application's configs as members
 # a BoundContainer(BoundConfig) object which has all of the application's configs as members
 GLOBAL_CONFIG = None
 GLOBAL_CONFIG = None
 
 
+LOG = logging.getLogger(__name__)
+
 __all__ = ["UnspecifiedConfigSection", "ConfigSection", "Config", "load_confs", "coerce_bool"]
 __all__ = ["UnspecifiedConfigSection", "ConfigSection", "Config", "load_confs", "coerce_bool"]
 
 
 class BoundConfig(object):
 class BoundConfig(object):
@@ -170,16 +176,29 @@ class Config(object):
     @param dynamic_default a lambda to use to calculate the default
     @param dynamic_default a lambda to use to calculate the default
     @param required whether this must be set
     @param required whether this must be set
     @param help     some text to print out for help
     @param help     some text to print out for help
-    @param type     a callable that coerces a string into the expected type.
+    @param type    a callable that coerces a string into the expected type.
                     str is the default. Should raise an exception in the case
                     str is the default. Should raise an exception in the case
                     that it cannot be coerced.
                     that it cannot be coerced.
     @param private  if True, does not emit help text
     @param private  if True, does not emit help text
     """
     """
+    if not callable(type):
+      raise ValueError("%s: The type argument '%s()' is not callable" % (key, type))
+
     if default is not None and dynamic_default is not None:
     if default is not None and dynamic_default is not None:
-      raise Exception("Cannot specify both dynamic_default and default for key %s" % key)
+      raise ValueError("Cannot specify both dynamic_default and default for key %s" % key)
 
 
     if dynamic_default is not None and not dynamic_default.__doc__ and not private:
     if dynamic_default is not None and not dynamic_default.__doc__ and not private:
-      raise Exception("Dynamic defaults must have __doc__ defined!")
+      raise ValueError("Dynamic default '%s' must have __doc__ defined!" % (key,))
+
+    if pytype(default) in (int, long, float, complex, bool) and \
+          not isinstance(type(default), pytype(default)):
+      raise ValueError("%s: '%s' does not match that of the default value %r (%s)"
+                      % (key, type, default, pytype(default)))
+
+    if type == bool:
+      LOG.warn("%s is of type bool. Resetting it as type 'coerce_bool'."
+               " Please fix it permanently" % (key,))
+      type = coerce_bool
 
 
     self.key = key
     self.key = key
     self.default_value = default
     self.default_value = default
@@ -460,11 +479,11 @@ def _configs_from_dir(conf_dir):
   for filename in sorted(os.listdir(conf_dir)):
   for filename in sorted(os.listdir(conf_dir)):
     if filename.startswith(".") or not filename.endswith('.ini'):
     if filename.startswith(".") or not filename.endswith('.ini'):
       continue
       continue
-    logging.debug("Loading configuration from: %s" % filename)
+    LOG.debug("Loading configuration from: %s" % filename)
     try:
     try:
       conf = configobj.ConfigObj(os.path.join(conf_dir, filename))
       conf = configobj.ConfigObj(os.path.join(conf_dir, filename))
     except configobj.ConfigObjError, ex:
     except configobj.ConfigObjError, ex:
-      logging.error("Error in configuration file '%s': %s" %
+      LOG.error("Error in configuration file '%s': %s" %
                     (os.path.join(conf_dir, filename), ex))
                     (os.path.join(conf_dir, filename), ex))
       raise
       raise
     conf['DEFAULT'] = dict(desktop_root=get_desktop_root(), build_dir=get_build_dir())
     conf['DEFAULT'] = dict(desktop_root=get_desktop_root(), build_dir=get_build_dir())
@@ -564,13 +583,18 @@ def is_anonymous(key):
   return key == _ANONYMOUS
   return key == _ANONYMOUS
 
 
 def coerce_bool(value):
 def coerce_bool(value):
-  if type(value) == bool:
+  if isinstance(value, bool):
     return value
     return value
-  if value in ("false", "False", "0", "no", "off", "", None):
+
+  try:
+    upper = value.upper()
+  except:
+    upper = value
+  if upper in ("FALSE", "0", "NO", "OFF", "NAY", "", None):
     return False
     return False
-  if value in ("true", "True", "1", "yes", "on"):
+  if upper in ("TRUE", "1", "YES", "ON", "YEA"):
     return True
     return True
-  raise Exception("Could not coerce boolean value: " + str(value))
+  raise Exception("Could not coerce %r to boolean value" % (value,))
 
 
 
 
 def validate_path(confvar, is_dir=None):
 def validate_path(confvar, is_dir=None):

+ 67 - 60
desktop/core/src/desktop/lib/conf_test.py

@@ -18,10 +18,10 @@
 import configobj
 import configobj
 from cStringIO import StringIO
 from cStringIO import StringIO
 import logging
 import logging
-import unittest
 import re
 import re
 
 
 from desktop.lib.conf import *
 from desktop.lib.conf import *
+from nose.tools import assert_true, assert_false, assert_equals, assert_raises
 
 
 def my_dynamic_default():
 def my_dynamic_default():
   """
   """
@@ -29,8 +29,8 @@ def my_dynamic_default():
   """
   """
   return 3 + 4
   return 3 + 4
 
 
-class ConfigTest(unittest.TestCase):
-  """Unit tests for this module."""
+class TestConfig(object):
+  """Unit tests for the configuration module."""
 
 
   # Some test configurations to load
   # Some test configurations to load
   CONF_ONE="""
   CONF_ONE="""
@@ -49,9 +49,10 @@ class ConfigTest(unittest.TestCase):
   host="philipscomputer"
   host="philipscomputer"
   """
   """
 
 
-  def setUp(self):
+  @classmethod
+  def setup_class(cls):
     logging.basicConfig(level=logging.DEBUG)
     logging.basicConfig(level=logging.DEBUG)
-    self.conf = ConfigSection(
+    cls.conf = ConfigSection(
       members=dict(
       members=dict(
         FOO           = Config("foo",
         FOO           = Config("foo",
                                help="A vanilla configuration param",
                                help="A vanilla configuration param",
@@ -81,102 +82,109 @@ class ConfigTest(unittest.TestCase):
                                        required=True),
                                        required=True),
                          PORT = Config("port", help="Thrift port for the NN",
                          PORT = Config("port", help="Thrift port for the NN",
                                        type=int, default=10090))))))
                                        type=int, default=10090))))))
-    self.conf = self.conf.bind(
-      load_confs([configobj.ConfigObj(infile=StringIO(self.CONF_ONE)),
-                  configobj.ConfigObj(infile=StringIO(self.CONF_TWO))]),
+    cls.conf = cls.conf.bind(
+      load_confs([configobj.ConfigObj(infile=StringIO(cls.CONF_ONE)),
+                  configobj.ConfigObj(infile=StringIO(cls.CONF_TWO))]),
       prefix='')
       prefix='')
 
 
-  def testDynamicDefault(self):
-    self.assertEquals(7, self.conf.DYNAMIC_DEF.get())
-
-  def testLoad(self):
-    self.assertEquals(123, self.conf.FOO.get())
-    self.assertEquals(456, self.conf.BAR.get())
-    self.assertEquals(345, self.conf.REQ.get())
-
-    self.assertEquals(None, self.conf.OPT_NOT_THERE.get())
-    self.assertRaises(KeyError, self.conf.REQ_NOT_THERE.get)
-
-  def testListValues(self):
-    self.assertEquals(["a","b","c"], self.conf.LIST.get())
-
-  def testSections(self):
-    self.assertEquals(2, len(self.conf.CLUSTERS))
-    self.assertEquals(['clustera', 'clusterb'], sorted(self.conf.CLUSTERS.keys()))
-    self.assertTrue("clustera" in self.conf.CLUSTERS)
-    self.assertEquals("localhost", self.conf.CLUSTERS['clustera'].HOST.get())
-    self.assertEquals(10090, self.conf.CLUSTERS['clustera'].PORT.get())
-
-  def testFullKeyName(self):
-    self.assertEquals(self.conf.REQ.get_fully_qualifying_key(), 'req')
-    self.assertEquals(self.conf.CLUSTERS.get_fully_qualifying_key(), 'clusters')
-    self.assertEquals(self.conf.CLUSTERS['clustera'].get_fully_qualifying_key(),
+  def test_type_safety(self):
+    assert_raises(ValueError, Config, key="test_type", type=42)
+    assert_raises(ValueError, Config, key="test_type", type=str, default=42)
+    assert_raises(ValueError, Config, key="test_type", default=False)
+    bool_conf = Config("bool_conf", type=bool)
+    assert_true(bool_conf.type == coerce_bool)
+
+  def test_dynamic_default(self):
+    assert_equals(7, self.conf.DYNAMIC_DEF.get())
+
+  def test_load(self):
+    assert_equals(123, self.conf.FOO.get())
+    assert_equals(456, self.conf.BAR.get())
+    assert_equals(345, self.conf.REQ.get())
+
+    assert_equals(None, self.conf.OPT_NOT_THERE.get())
+    assert_raises(KeyError, self.conf.REQ_NOT_THERE.get)
+
+  def test_list_values(self):
+    assert_equals(["a","b","c"], self.conf.LIST.get())
+
+  def test_sections(self):
+    assert_equals(2, len(self.conf.CLUSTERS))
+    assert_equals(['clustera', 'clusterb'], sorted(self.conf.CLUSTERS.keys()))
+    assert_true("clustera" in self.conf.CLUSTERS)
+    assert_equals("localhost", self.conf.CLUSTERS['clustera'].HOST.get())
+    assert_equals(10090, self.conf.CLUSTERS['clustera'].PORT.get())
+
+  def test_full_key_name(self):
+    assert_equals(self.conf.REQ.get_fully_qualifying_key(), 'req')
+    assert_equals(self.conf.CLUSTERS.get_fully_qualifying_key(), 'clusters')
+    assert_equals(self.conf.CLUSTERS['clustera'].get_fully_qualifying_key(),
                       'clusters.clustera')
                       'clusters.clustera')
-    self.assertEquals(self.conf.CLUSTERS['clustera'].HOST.get_fully_qualifying_key(),
+    assert_equals(self.conf.CLUSTERS['clustera'].HOST.get_fully_qualifying_key(),
                       'clusters.clustera.host')
                       'clusters.clustera.host')
 
 
-  def testSetForTesting(self):
+  def test_set_for_testing(self):
     # Test base case
     # Test base case
-    self.assertEquals(123, self.conf.FOO.get())
+    assert_equals(123, self.conf.FOO.get())
     # Override with 456
     # Override with 456
     close_foo = self.conf.FOO.set_for_testing(456)
     close_foo = self.conf.FOO.set_for_testing(456)
     try:
     try:
-      self.assertEquals(456, self.conf.FOO.get())
+      assert_equals(456, self.conf.FOO.get())
       # Check nested overriding
       # Check nested overriding
       close_foo2 = self.conf.FOO.set_for_testing(789)
       close_foo2 = self.conf.FOO.set_for_testing(789)
       try:
       try:
-        self.assertEquals(789, self.conf.FOO.get())
+        assert_equals(789, self.conf.FOO.get())
       finally:
       finally:
         close_foo2()
         close_foo2()
 
 
       # Check that we pop the stack appropriately.
       # Check that we pop the stack appropriately.
-      self.assertEquals(456, self.conf.FOO.get())
+      assert_equals(456, self.conf.FOO.get())
       # Check default values
       # Check default values
       close_foo3 = self.conf.FOO.set_for_testing(present=False)
       close_foo3 = self.conf.FOO.set_for_testing(present=False)
       try:
       try:
-        self.assertEquals(None, self.conf.FOO.get())
+        assert_equals(None, self.conf.FOO.get())
       finally:
       finally:
         close_foo3()
         close_foo3()
     finally:
     finally:
       close_foo()
       close_foo()
     # Check that it got set back correctly
     # Check that it got set back correctly
-    self.assertEquals(123, self.conf.FOO.get())
+    assert_equals(123, self.conf.FOO.get())
 
 
     # Test something inside an unspecified config setting with a default
     # Test something inside an unspecified config setting with a default
     close = self.conf.CLUSTERS['clustera'].PORT.set_for_testing(123)
     close = self.conf.CLUSTERS['clustera'].PORT.set_for_testing(123)
     try:
     try:
-      self.assertEquals(123, self.conf.CLUSTERS['clustera'].PORT.get())
+      assert_equals(123, self.conf.CLUSTERS['clustera'].PORT.get())
     finally:
     finally:
       close()
       close()
-    self.assertEquals(10090, self.conf.CLUSTERS['clustera'].PORT.get())
+    assert_equals(10090, self.conf.CLUSTERS['clustera'].PORT.get())
 
 
     # Test something inside a config section that wasn't provided in conf file
     # Test something inside a config section that wasn't provided in conf file
-    self.assertEquals("baz_default", self.conf.SOME_SECTION.BAZ.get())
+    assert_equals("baz_default", self.conf.SOME_SECTION.BAZ.get())
     close = self.conf.SOME_SECTION.BAZ.set_for_testing("hello")
     close = self.conf.SOME_SECTION.BAZ.set_for_testing("hello")
     try:
     try:
-      self.assertEquals("hello", self.conf.SOME_SECTION.BAZ.get())
+      assert_equals("hello", self.conf.SOME_SECTION.BAZ.get())
     finally:
     finally:
       close()
       close()
-    self.assertEquals("baz_default", self.conf.SOME_SECTION.BAZ.get())
+    assert_equals("baz_default", self.conf.SOME_SECTION.BAZ.get())
 
 
 
 
   def test_coerce_bool(self):
   def test_coerce_bool(self):
-    self.assertEquals(False, coerce_bool(False))
-    self.assertEquals(False, coerce_bool("False"))
-    self.assertEquals(False, coerce_bool("false"))
-    self.assertEquals(False, coerce_bool("0"))
-    self.assertEquals(True, coerce_bool("True"))
-    self.assertEquals(True, coerce_bool("true"))
-    self.assertEquals(True, coerce_bool("1"))
-    self.assertEquals(True, coerce_bool(True))
-    self.assertRaises(Exception, coerce_bool, tuple("foo"))
-
-  def testPrintHelp(self):    
+    assert_equals(False, coerce_bool(False))
+    assert_equals(False, coerce_bool("FaLsE"))
+    assert_equals(False, coerce_bool("no"))
+    assert_equals(False, coerce_bool("0"))
+    assert_equals(True, coerce_bool("TrUe"))
+    assert_equals(True, coerce_bool("YES"))
+    assert_equals(True, coerce_bool("1"))
+    assert_equals(True, coerce_bool(True))
+    assert_raises(Exception, coerce_bool, tuple("foo"))
+
+  def test_print_help(self):
     out = StringIO()
     out = StringIO()
     self.conf.print_help(out=out, skip_header=True)
     self.conf.print_help(out=out, skip_header=True)
     out = out.getvalue().strip()
     out = out.getvalue().strip()
-    self.assertFalse("dontseeme" in out)
-    self.assertEquals(re.sub("^    (?m)", "", """
+    assert_false("dontseeme" in out)
+    assert_equals(re.sub("^    (?m)", "", """
     Key: bar (optional)
     Key: bar (optional)
       Default: 456
       Default: 456
       Config with default
       Config with default
@@ -214,4 +222,3 @@ class ConfigTest(unittest.TestCase):
     Key: req (required)
     Key: req (required)
       A required config
       A required config
     """).strip(), out)
     """).strip(), out)
-