s3test_utils.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Licensed to Cloudera, Inc. under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. Cloudera, Inc. licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from __future__ import absolute_import
  17. import logging
  18. import os
  19. import random
  20. import string
  21. import unittest
  22. from nose.plugins.skip import SkipTest
  23. import aws
  24. from contextlib import contextmanager
  25. from aws.s3 import parse_uri, join
  26. def get_test_bucket():
  27. return os.environ.get('TEST_S3_BUCKET', '')
  28. def generate_id(size=6, chars=string.ascii_uppercase + string.digits):
  29. return ''.join(random.choice(chars) for x in range(size))
  30. class S3TestBase(unittest.TestCase):
  31. integration = True
  32. @classmethod
  33. def setUpClass(cls):
  34. cls.bucket_name = get_test_bucket()
  35. cls._should_skip = False
  36. if not cls.bucket_name:
  37. cls._should_skip = True
  38. cls._skip_msg = 'TEST_S3_BUCKET environment variable isn\'t set'
  39. return
  40. cls.path_prefix = 'test-hue/%s' % generate_id(size=16)
  41. cls.s3_connection = aws.get_client('default').get_s3_connection()
  42. cls.bucket = cls.s3_connection.get_bucket(cls.bucket_name, validate=True)
  43. @classmethod
  44. def shouldSkip(cls):
  45. return cls._should_skip
  46. def setUp(self):
  47. if self.shouldSkip():
  48. raise SkipTest(self._skip_msg)
  49. @classmethod
  50. def tearDownClass(cls):
  51. if not cls.shouldSkip():
  52. cls.clean_up(cls.get_test_path())
  53. @classmethod
  54. def get_test_path(cls, path=None):
  55. base_path = join('s3a://', cls.bucket_name, cls.path_prefix)
  56. if path:
  57. return join(base_path, path)
  58. return base_path
  59. @classmethod
  60. def get_key(cls, path, validate=False):
  61. bucket_name, key_name = parse_uri(path)[:2]
  62. bucket = cls.s3_connection.get_bucket(bucket_name)
  63. return bucket.get_key(key_name, validate=validate)
  64. @classmethod
  65. def clean_up(cls, *paths):
  66. for path in paths:
  67. key = cls.get_key(path, validate=False)
  68. try:
  69. listing = key.bucket.list(prefix=key.name)
  70. key.bucket.delete_keys(listing)
  71. except:
  72. pass
  73. @classmethod
  74. @contextmanager
  75. def cleaning(cls, *paths):
  76. try:
  77. yield paths
  78. finally:
  79. cls.clean_up(*paths)