# utils.py
#
# Copyright (C) 2018 Martin Owens
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.
#
"""
Provide some utilities for tests.
"""
  20. from logging import Handler, getLogger, root
  21. from collections import defaultdict
  22. # FLAG: do not report failures from here in tracebacks
  23. # pylint: disable=invalid-name
  24. __unittest = True
  25. class LoggingRecorder(Handler):
  26. """Record any logger output for testing"""
  27. def __init__(self, *args, **kwargs):
  28. self.logs = defaultdict(list)
  29. super(LoggingRecorder, self).__init__(*args, **kwargs)
  30. def __getitem__(self, name):
  31. return self.logs[name.upper()]
  32. def emit(self, record):
  33. """Save the log message to the right level"""
  34. # We have no idea why record's getMessage is prefixed
  35. msg = str(record.getMessage())
  36. if msg.startswith('u"'):
  37. msg = msg[1:]
  38. if msg and (msg[0] == msg[-1] and msg[0] in '"\''):
  39. msg = msg[1:-1]
  40. self[record.levelname].append(msg)
  41. return True
  42. class LoggingMixin(object):
  43. """Provide logger capture"""
  44. log_name = None
  45. def setUp(self):
  46. """Make a fresh logger for each test function"""
  47. super(LoggingMixin, self).setUp()
  48. named = getLogger(self.log_name)
  49. for handler in root.handlers[:]:
  50. root.removeHandler(handler)
  51. for handler in named.handlers[:]:
  52. named.removeHandler(handler)
  53. self.log_handler = LoggingRecorder(level='DEBUG')
  54. named.addHandler(self.log_handler)
  55. def tearDown(self):
  56. """Warn about untested logs"""
  57. for level in self.log_handler.logs:
  58. for msg in self.log_handler[level]:
  59. raise ValueError("Uncaught log: {}: {}\n".format(level, msg))
  60. def assertLog(self, level, msg):
  61. """Checks that the logger has emitted the given log"""
  62. logs = self.log_handler[level]
  63. self.assertTrue(logs, 'Logger hasn\'t emitted "{}"'.format(msg))
  64. if len(logs) == 1:
  65. self.assertEqual(msg, logs[0])
  66. else:
  67. self.assertIn(msg, logs)
  68. logs.remove(msg)
  69. def assertNoLog(self, level, msg):
  70. """Checks that the logger has NOT emitted the given log"""
  71. self.assertNotIn(msg, self.log_handler[level])