Diffstat (limited to 'test/unit/__init__.py')
-rw-r--r--  test/unit/__init__.py  602
1 file changed, 483 insertions(+), 119 deletions(-)
diff --git a/test/unit/__init__.py b/test/unit/__init__.py
index a1bfef8..372fb58 100644
--- a/test/unit/__init__.py
+++ b/test/unit/__init__.py
@@ -20,71 +20,288 @@ import copy
import logging
import errno
import sys
-from contextlib import contextmanager
-from collections import defaultdict
+from contextlib import contextmanager, closing
+from collections import defaultdict, Iterable
+import itertools
+from numbers import Number
from tempfile import NamedTemporaryFile
import time
+import eventlet
from eventlet.green import socket
from tempfile import mkdtemp
from shutil import rmtree
+from swift.common.utils import Timestamp
from test import get_config
-from swift.common.utils import config_true_value, LogAdapter
+from swift.common import swob, utils
+from swift.common.ring import Ring, RingData
from hashlib import md5
-from eventlet import sleep, Timeout
import logging.handlers
from httplib import HTTPException
-from numbers import Number
+from swift.common import storage_policy
+from swift.common.storage_policy import StoragePolicy, ECStoragePolicy
+import functools
+import cPickle as pickle
+from gzip import GzipFile
+import mock as mocklib
+import inspect
+
+EMPTY_ETAG = md5().hexdigest()
+
+# try not to import this module from swift
+if not os.path.basename(sys.argv[0]).startswith('swift'):
+ # never patch HASH_PATH_SUFFIX AGAIN!
+ utils.HASH_PATH_SUFFIX = 'endcap'
+
+
+def patch_policies(thing_or_policies=None, legacy_only=False,
+ with_ec_default=False, fake_ring_args=None):
+ if isinstance(thing_or_policies, (
+ Iterable, storage_policy.StoragePolicyCollection)):
+ return PatchPolicies(thing_or_policies, fake_ring_args=fake_ring_args)
+
+ if legacy_only:
+ default_policies = [
+ StoragePolicy(0, name='legacy', is_default=True),
+ ]
+ default_ring_args = [{}]
+ elif with_ec_default:
+ default_policies = [
+ ECStoragePolicy(0, name='ec', is_default=True,
+ ec_type='jerasure_rs_vand', ec_ndata=10,
+ ec_nparity=4, ec_segment_size=4096),
+ StoragePolicy(1, name='unu'),
+ ]
+ default_ring_args = [{'replicas': 14}, {}]
+ else:
+ default_policies = [
+ StoragePolicy(0, name='nulo', is_default=True),
+ StoragePolicy(1, name='unu'),
+ ]
+ default_ring_args = [{}, {}]
+
+ fake_ring_args = fake_ring_args or default_ring_args
+ decorator = PatchPolicies(default_policies, fake_ring_args=fake_ring_args)
+
+ if not thing_or_policies:
+ return decorator
+ else:
+ # it's a thing, we return the wrapped thing instead of the decorator
+ return decorator(thing_or_policies)
+
+
+class PatchPolicies(object):
+ """
+ Why not mock.patch? In my case, when used as a decorator on the class, it
+ seemed to patch setUp at the wrong time (i.e., when setUp ran, the global
+ wasn't patched yet).
+ """
+
+ def __init__(self, policies, fake_ring_args=None):
+ if isinstance(policies, storage_policy.StoragePolicyCollection):
+ self.policies = policies
+ else:
+ self.policies = storage_policy.StoragePolicyCollection(policies)
+ self.fake_ring_args = fake_ring_args or [None] * len(self.policies)
+
+ def _setup_rings(self):
+ """
+ Our tests tend to use the policies' rings like their own personal
+ playground - which can be a problem in the particular case of a
+ patched TestCase class, where the FakeRing objects are scoped in the
+ call to the patch_policies wrapper outside of the TestCase instance,
+ which can lead to state bleeding between tests.
+
+ To help tests get better isolation without having to think about it,
+ here we capture the args required to *build* new FakeRing instances,
+ so we can ensure each test method gets a clean ring setup.
+
+ The TestCase can always "tweak" these fresh rings in setUp - or, if
+ it prefers the same "reset" behavior with custom FakeRings, it can
+ pass its own fake_ring_args to patch_policies instead of setting the
+ object_ring on the policy definitions.
+ """
+ for policy, fake_ring_arg in zip(self.policies, self.fake_ring_args):
+ if fake_ring_arg is not None:
+ policy.object_ring = FakeRing(**fake_ring_arg)
+
+ def __call__(self, thing):
+ if isinstance(thing, type):
+ return self._patch_class(thing)
+ else:
+ return self._patch_method(thing)
+
+ def _patch_class(self, cls):
+ """
+ Creating a new class that inherits from the decorated class is the
+ more common way I've seen class decorators done - but it seems to
+ cause infinite recursion when super is called from inside methods in
+ the decorated class.
+ """
+
+ orig_setUp = cls.setUp
+ orig_tearDown = cls.tearDown
+
+ def setUp(cls_self):
+ self._orig_POLICIES = storage_policy._POLICIES
+ if not getattr(cls_self, '_policies_patched', False):
+ storage_policy._POLICIES = self.policies
+ self._setup_rings()
+ cls_self._policies_patched = True
+
+ orig_setUp(cls_self)
+
+ def tearDown(cls_self):
+ orig_tearDown(cls_self)
+ storage_policy._POLICIES = self._orig_POLICIES
+
+ cls.setUp = setUp
+ cls.tearDown = tearDown
+
+ return cls
+
+ def _patch_method(self, f):
+ @functools.wraps(f)
+ def mywrapper(*args, **kwargs):
+ self._orig_POLICIES = storage_policy._POLICIES
+ try:
+ storage_policy._POLICIES = self.policies
+ self._setup_rings()
+ return f(*args, **kwargs)
+ finally:
+ storage_policy._POLICIES = self._orig_POLICIES
+ return mywrapper
+
+ def __enter__(self):
+ self._orig_POLICIES = storage_policy._POLICIES
+ storage_policy._POLICIES = self.policies
+
+ def __exit__(self, *args):
+ storage_policy._POLICIES = self._orig_POLICIES
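Since the class grows an __enter__/__exit__ pair above, an instance can also scope the override to a with-block; a minimal sketch (note that, unlike the decorator paths, __enter__ does not call _setup_rings):

    with PatchPolicies([StoragePolicy(0, name='solo', is_default=True)]):
        # inside the block the patched collection is in effect
        assert storage_policy._POLICIES.default.name == 'solo'
    # the original collection is restored on exit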
-class FakeRing(object):
- def __init__(self, replicas=3, max_more_nodes=0):
+class FakeRing(Ring):
+
+ def __init__(self, replicas=3, max_more_nodes=0, part_power=0,
+ base_port=1000):
+ """
+ :param part_power: make part calculation based on the path
+
+ If you set a part_power when you set up your FakeRing, the parts you
+ get out of ring methods will actually be based on the path - otherwise
+ we exercise the real ring code but ignore the result and return 1.
+ """
+ self._base_port = base_port
+ self.max_more_nodes = max_more_nodes
+ self._part_shift = 32 - part_power
# 9 total nodes (6 more past the initial 3) is the cap, no matter if
# this is set higher, or R^2 for R replicas
- self.replicas = replicas
- self.max_more_nodes = max_more_nodes
- self.devs = {}
+ self.set_replicas(replicas)
+ self._reload()
+
+ def _reload(self):
+ self._rtime = time.time()
def set_replicas(self, replicas):
self.replicas = replicas
- self.devs = {}
+ self._devs = []
+ for x in range(self.replicas):
+ ip = '10.0.0.%s' % x
+ port = self._base_port + x
+ self._devs.append({
+ 'ip': ip,
+ 'replication_ip': ip,
+ 'port': port,
+ 'replication_port': port,
+ 'device': 'sd' + (chr(ord('a') + x)),
+ 'zone': x % 3,
+ 'region': x % 2,
+ 'id': x,
+ })
@property
def replica_count(self):
return self.replicas
- def get_part(self, account, container=None, obj=None):
- return 1
-
- def get_nodes(self, account, container=None, obj=None):
- devs = []
- for x in xrange(self.replicas):
- devs.append(self.devs.get(x))
- if devs[x] is None:
- self.devs[x] = devs[x] = \
- {'ip': '10.0.0.%s' % x,
- 'port': 1000 + x,
- 'device': 'sd' + (chr(ord('a') + x)),
- 'zone': x % 3,
- 'region': x % 2,
- 'id': x}
- return 1, devs
-
- def get_part_nodes(self, part):
- return self.get_nodes('blah')[1]
+ def _get_part_nodes(self, part):
+ return [dict(node, index=i) for i, node in enumerate(list(self._devs))]
def get_more_nodes(self, part):
# replicas^2 is the true cap
for x in xrange(self.replicas, min(self.replicas + self.max_more_nodes,
self.replicas * self.replicas)):
yield {'ip': '10.0.0.%s' % x,
- 'port': 1000 + x,
+ 'replication_ip': '10.0.0.%s' % x,
+ 'port': self._base_port + x,
+ 'replication_port': self._base_port + x,
'device': 'sda',
'zone': x % 3,
'region': x % 2,
'id': x}
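A sketch of what the fake ring hands back with the defaults above; get_nodes itself is inherited from the real Ring, and the device values follow the set_replicas template:

    ring = FakeRing(replicas=3, max_more_nodes=2)
    part, nodes = ring.get_nodes('a', 'c', 'o')
    # nodes: three devs at 10.0.0.0-10.0.0.2, ports 1000-1002, index 0-2
    handoffs = list(ring.get_more_nodes(part))
    # handoffs: two more devs (10.0.0.3, 10.0.0.4), capped at replicas ** 2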
+def write_fake_ring(path, *devs):
+ """
+ Pretty much just a two-node, two-replica, 2-part-power ring...
+ """
+ dev1 = {'id': 0, 'zone': 0, 'device': 'sda1', 'ip': '127.0.0.1',
+ 'port': 6000}
+ dev2 = {'id': 1, 'zone': 0, 'device': 'sdb1', 'ip': '127.0.0.1',
+ 'port': 6000}
+
+ dev1_updates, dev2_updates = devs or ({}, {})
+
+ dev1.update(dev1_updates)
+ dev2.update(dev2_updates)
+
+ replica2part2dev_id = [[0, 1, 0, 1], [1, 0, 1, 0]]
+ devs = [dev1, dev2]
+ part_shift = 30
+ with closing(GzipFile(path, 'wb')) as f:
+ pickle.dump(RingData(replica2part2dev_id, devs, part_shift), f)
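A round-trip sketch for the helper above; the temporary path is illustrative, and Ring, mkdtemp and rmtree are already imported at the top of this module:

    tempdir = mkdtemp()
    ring_path = os.path.join(tempdir, 'object.ring.gz')
    write_fake_ring(ring_path, {'port': 6010}, {'port': 6020})
    ring = Ring(ring_path)  # loads the pickled RingData written above
    assert ring.replica_count == 2
    rmtree(tempdir)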
+
+
+class FabricatedRing(Ring):
+ """
+ When a FakeRing just won't do - you can fabricate one to meet
+ your test's needs.
+ """
+
+ def __init__(self, replicas=6, devices=8, nodes=4, port=6000,
+ part_power=4):
+ self.devices = devices
+ self.nodes = nodes
+ self.port = port
+ self.replicas = replicas
+ self.part_power = part_power
+ self._part_shift = 32 - self.part_power
+ self._reload()
+
+ def _reload(self, *args, **kwargs):
+ self._rtime = time.time() * 2
+ if hasattr(self, '_replica2part2dev_id'):
+ return
+ self._devs = [{
+ 'region': 1,
+ 'zone': 1,
+ 'weight': 1.0,
+ 'id': i,
+ 'device': 'sda%d' % i,
+ 'ip': '10.0.0.%d' % (i % self.nodes),
+ 'replication_ip': '10.0.0.%d' % (i % self.nodes),
+ 'port': self.port,
+ 'replication_port': self.port,
+ } for i in range(self.devices)]
+
+ self._replica2part2dev_id = [
+ [None] * 2 ** self.part_power
+ for i in range(self.replicas)
+ ]
+ dev_ids = itertools.cycle(range(self.devices))
+ for p in range(2 ** self.part_power):
+ for r in range(self.replicas):
+ self._replica2part2dev_id[r][p] = next(dev_ids)
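Illustrative numbers for the defaults above, assuming the replicas argument is honored as fixed earlier:

    ring = FabricatedRing()            # 8 devices over 4 ips, 2 ** 4 parts
    part, nodes = ring.get_nodes('a')  # 6 replicas -> 6 nodes per part
    # dev ids are dealt round-robin: part 0 holds ids 0-5, part 1 holds
    # 6, 7, 0, 1, 2, 3, and so on through the cycle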
+
+
class FakeMemcache(object):
def __init__(self):
@@ -152,24 +369,13 @@ def tmpfile(content):
xattr_data = {}
-def _get_inode(fd_or_name):
- try:
- if isinstance(fd_or_name, int):
- fd = fd_or_name
- else:
- try:
- fd = fd_or_name.fileno()
- except AttributeError:
- fd = None
- if fd is None:
- ino = os.stat(fd_or_name).st_ino
- else:
- ino = os.fstat(fd).st_ino
- except OSError as err:
- ioerr = IOError()
- ioerr.errno = err.errno
- raise ioerr
- return ino
+def _get_inode(fd):
+ if not isinstance(fd, int):
+ try:
+ fd = fd.fileno()
+ except AttributeError:
+ return os.stat(fd).st_ino
+ return os.fstat(fd).st_ino
def _setxattr(fd, k, v):
@@ -183,9 +389,7 @@ def _getxattr(fd, k):
inode = _get_inode(fd)
data = xattr_data.get(inode, {}).get(k)
if not data:
- e = IOError("Fake IOError")
- e.errno = errno.ENODATA
- raise e
+ raise IOError(errno.ENODATA, "Fake IOError")
return data
import xattr
@@ -214,6 +418,22 @@ def temptree(files, contents=''):
rmtree(tempdir)
+def with_tempdir(f):
+ """
+ Decorator to give a single test a tempdir as the last argument to its
+ test method.
+ """
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ tempdir = mkdtemp()
+ args = list(args)
+ args.append(tempdir)
+ try:
+ return f(*args, **kwargs)
+ finally:
+ rmtree(tempdir)
+ return wrapped
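A usage sketch for the decorator above (the test class is hypothetical):

    class TestWithTempdir(unittest.TestCase):
        @with_tempdir
        def test_makes_a_file(self, tempdir):
            path = os.path.join(tempdir, 'data')
            with open(path, 'w') as f:
                f.write('x')
            self.assertTrue(os.path.exists(path))
        # the tempdir is rmtree'd for us after the method returns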
+
+
class NullLoggingHandler(logging.Handler):
def emit(self, record):
@@ -239,8 +459,8 @@ class UnmockTimeModule(object):
logging.time = UnmockTimeModule()
-class FakeLogger(logging.Logger):
- # a thread safe logger
+class FakeLogger(logging.Logger, object):
+ # a thread safe fake logger
def __init__(self, *args, **kwargs):
self._clear()
@@ -250,42 +470,57 @@ class FakeLogger(logging.Logger):
self.facility = kwargs['facility']
self.statsd_client = None
self.thread_locals = None
+ self.parent = None
+
+ store_in = {
+ logging.ERROR: 'error',
+ logging.WARNING: 'warning',
+ logging.INFO: 'info',
+ logging.DEBUG: 'debug',
+ logging.CRITICAL: 'critical',
+ }
+
+ def _log(self, level, msg, *args, **kwargs):
+ store_name = self.store_in[level]
+ cargs = [msg]
+ if any(args):
+ cargs.extend(args)
+ captured = dict(kwargs)
+ if 'exc_info' in kwargs and \
+ not isinstance(kwargs['exc_info'], tuple):
+ captured['exc_info'] = sys.exc_info()
+ self.log_dict[store_name].append((tuple(cargs), captured))
+ super(FakeLogger, self)._log(level, msg, *args, **kwargs)
def _clear(self):
self.log_dict = defaultdict(list)
- self.lines_dict = defaultdict(list)
-
- def _store_in(store_name):
- def stub_fn(self, *args, **kwargs):
- self.log_dict[store_name].append((args, kwargs))
- return stub_fn
-
- def _store_and_log_in(store_name):
- def stub_fn(self, *args, **kwargs):
- self.log_dict[store_name].append((args, kwargs))
- self._log(store_name, args[0], args[1:], **kwargs)
- return stub_fn
+ self.lines_dict = {'critical': [], 'error': [], 'info': [],
+ 'warning': [], 'debug': []}
def get_lines_for_level(self, level):
+ if level not in self.lines_dict:
+ raise KeyError(
+ "Invalid log level '%s'; valid levels are %s" %
+ (level,
+ ', '.join("'%s'" % lvl for lvl in sorted(self.lines_dict))))
return self.lines_dict[level]
- error = _store_and_log_in('error')
- info = _store_and_log_in('info')
- warning = _store_and_log_in('warning')
- warn = _store_and_log_in('warning')
- debug = _store_and_log_in('debug')
+ def all_log_lines(self):
+ return dict((level, msgs) for level, msgs in self.lines_dict.items()
+ if len(msgs) > 0)
- def exception(self, *args, **kwargs):
- self.log_dict['exception'].append((args, kwargs,
- str(sys.exc_info()[1])))
- print 'FakeLogger Exception: %s' % self.log_dict
+ def _store_in(store_name):
+ def stub_fn(self, *args, **kwargs):
+ self.log_dict[store_name].append((args, kwargs))
+ return stub_fn
# mock out the StatsD logging methods:
+ update_stats = _store_in('update_stats')
increment = _store_in('increment')
decrement = _store_in('decrement')
timing = _store_in('timing')
timing_since = _store_in('timing_since')
- update_stats = _store_in('update_stats')
+ transfer_rate = _store_in('transfer_rate')
set_statsd_prefix = _store_in('set_statsd_prefix')
def get_increments(self):
@@ -328,7 +563,7 @@ class FakeLogger(logging.Logger):
print 'WARNING: unable to format log message %r %% %r' % (
record.msg, record.args)
raise
- self.lines_dict[record.levelno].append(line)
+ self.lines_dict[record.levelname.lower()].append(line)
def handle(self, record):
self._handle(record)
@@ -345,19 +580,40 @@ class DebugLogger(FakeLogger):
def __init__(self, *args, **kwargs):
FakeLogger.__init__(self, *args, **kwargs)
- self.formatter = logging.Formatter("%(server)s: %(message)s")
+ self.formatter = logging.Formatter(
+ "%(server)s %(levelname)s: %(message)s")
def handle(self, record):
self._handle(record)
print self.formatter.format(record)
- def write(self, *args):
- print args
+
+class DebugLogAdapter(utils.LogAdapter):
+
+ def _send_to_logger(name):
+ def stub_fn(self, *args, **kwargs):
+ return getattr(self.logger, name)(*args, **kwargs)
+ return stub_fn
+
+ # delegate to FakeLogger's mocks
+ update_stats = _send_to_logger('update_stats')
+ increment = _send_to_logger('increment')
+ decrement = _send_to_logger('decrement')
+ timing = _send_to_logger('timing')
+ timing_since = _send_to_logger('timing_since')
+ transfer_rate = _send_to_logger('transfer_rate')
+ set_statsd_prefix = _send_to_logger('set_statsd_prefix')
+
+ def __getattribute__(self, name):
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ return getattr(self.__dict__['logger'], name)
def debug_logger(name='test'):
"""get a named adapted debug logger"""
- return LogAdapter(DebugLogger(), name)
+ return DebugLogAdapter(DebugLogger(), name)
original_syslog_handler = logging.handlers.SysLogHandler
@@ -374,7 +630,8 @@ def fake_syslog_handler():
logging.handlers.SysLogHandler = FakeLogger
-if config_true_value(get_config('unit_test').get('fake_syslog', 'False')):
+if utils.config_true_value(
+ get_config('unit_test').get('fake_syslog', 'False')):
fake_syslog_handler()
@@ -447,17 +704,66 @@ def mock(update):
delattr(module, attr)
+class SlowBody(object):
+ """
+ This will work with our fake_http_connect: if you hand in one of
+ these instead of a string, it will make reads take longer by the
+ given amount. It should be a little easier to extend than the
+ current slow kwarg - which inserts whitespace in the response.
+ It should also be easy to detect whether the body inside FakeConn
+ is one of these (or a subclass), if we ever wanted to do something
+ smarter than just duck-typing enough of the str/buffer API to get
+ by.
+ """
+
+ def __init__(self, body, slowness):
+ self.body = body
+ self.slowness = slowness
+
+ def slowdown(self):
+ eventlet.sleep(self.slowness)
+
+ def __getitem__(self, s):
+ return SlowBody(self.body[s], self.slowness)
+
+ def __len__(self):
+ return len(self.body)
+
+ def __radd__(self, other):
+ self.slowdown()
+ return other + self.body
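A sketch of handing one of these to the factory below; the 0.5s delay is arbitrary:

    # every concatenation with the body sleeps first via __radd__, so a
    # caller accumulating resp.read() chunks sees each read slowed down
    connect = fake_http_connect(200, body=SlowBody('slow response body', 0.5))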
+
+
def fake_http_connect(*code_iter, **kwargs):
class FakeConn(object):
def __init__(self, status, etag=None, body='', timestamp='1',
- expect_status=None, headers=None):
- self.status = status
- if expect_status is None:
- self.expect_status = self.status
+ headers=None, expect_headers=None, connection_id=None,
+ give_send=None):
+ # connect exception
+ if isinstance(status, (Exception, eventlet.Timeout)):
+ raise status
+ if isinstance(status, tuple):
+ self.expect_status = list(status[:-1])
+ self.status = status[-1]
+ self.explicit_expect_list = True
else:
- self.expect_status = expect_status
+ self.expect_status, self.status = ([], status)
+ self.explicit_expect_list = False
+ if not self.expect_status:
+ # when a swift backend service returns a status before reading
+ # from the body (mostly an error response) eventlet.wsgi will
+ # respond with that status line immediately instead of 100
+ # Continue, even if the client sent the Expect 100 header.
+ # BufferedHttp and the proxy both see these error statuses
+ # when they call getexpect, so our FakeConn tries to act like
+ # our backend services and return certain types of responses
+ # as expect statuses just like a real backend server would do.
+ if self.status in (507, 412, 409):
+ self.expect_status = [status]
+ else:
+ self.expect_status = [100, 100]
self.reason = 'Fake'
self.host = '1.2.3.4'
self.port = '1234'
@@ -466,30 +772,41 @@ def fake_http_connect(*code_iter, **kwargs):
self.etag = etag
self.body = body
self.headers = headers or {}
+ self.expect_headers = expect_headers or {}
self.timestamp = timestamp
+ self.connection_id = connection_id
+ self.give_send = give_send
if 'slow' in kwargs and isinstance(kwargs['slow'], list):
try:
self._next_sleep = kwargs['slow'].pop(0)
except IndexError:
self._next_sleep = None
+ # be nice to trixy bits with node_iters
+ eventlet.sleep()
def getresponse(self):
- if kwargs.get('raise_exc'):
+ if self.expect_status and self.explicit_expect_list:
+ raise Exception('Test did not consume all fake '
+ 'expect status: %r' % (self.expect_status,))
+ if isinstance(self.status, (Exception, eventlet.Timeout)):
+ raise self.status
+ exc = kwargs.get('raise_exc')
+ if exc:
+ if isinstance(exc, (Exception, eventlet.Timeout)):
+ raise exc
raise Exception('test')
if kwargs.get('raise_timeout_exc'):
- raise Timeout()
+ raise eventlet.Timeout()
return self
def getexpect(self):
- if self.expect_status == -2:
- raise HTTPException()
- if self.expect_status == -3:
- return FakeConn(507)
- if self.expect_status == -4:
- return FakeConn(201)
- if self.expect_status == 412:
- return FakeConn(412)
- return FakeConn(100)
+ expect_status = self.expect_status.pop(0)
+ if isinstance(expect_status, (Exception, eventlet.Timeout)):
+ raise expect_status
+ headers = dict(self.expect_headers)
+ if expect_status == 409:
+ headers['X-Backend-Timestamp'] = self.timestamp
+ return FakeConn(expect_status, headers=headers)
def getheaders(self):
etag = self.etag
@@ -499,19 +816,23 @@ def fake_http_connect(*code_iter, **kwargs):
else:
etag = '"68b329da9893e34099c7d8ad5cb9c940"'
- headers = {'content-length': len(self.body),
- 'content-type': 'x-application/test',
- 'x-timestamp': self.timestamp,
- 'last-modified': self.timestamp,
- 'x-object-meta-test': 'testing',
- 'x-delete-at': '9876543210',
- 'etag': etag,
- 'x-works': 'yes'}
+ headers = swob.HeaderKeyDict({
+ 'content-length': len(self.body),
+ 'content-type': 'x-application/test',
+ 'x-timestamp': self.timestamp,
+ 'x-backend-timestamp': self.timestamp,
+ 'last-modified': self.timestamp,
+ 'x-object-meta-test': 'testing',
+ 'x-delete-at': '9876543210',
+ 'etag': etag,
+ 'x-works': 'yes',
+ })
if self.status // 100 == 2:
headers['x-account-container-count'] = \
kwargs.get('count', 12345)
if not self.timestamp:
- del headers['x-timestamp']
+ # when timestamp is None, HeaderKeyDict raises KeyError
+ headers.pop('x-timestamp', None)
try:
if container_ts_iter.next() is False:
headers['x-container-timestamp'] = '1'
@@ -538,34 +859,45 @@ def fake_http_connect(*code_iter, **kwargs):
if am_slow:
if self.sent < 4:
self.sent += 1
- sleep(value)
+ eventlet.sleep(value)
return ' '
rv = self.body[:amt]
self.body = self.body[amt:]
return rv
def send(self, amt=None):
+ if self.give_send:
+ self.give_send(self.connection_id, amt)
am_slow, value = self.get_slow()
if am_slow:
if self.received < 4:
self.received += 1
- sleep(value)
+ eventlet.sleep(value)
def getheader(self, name, default=None):
- return dict(self.getheaders()).get(name.lower(), default)
+ return swob.HeaderKeyDict(self.getheaders()).get(name, default)
+
+ def close(self):
+ pass
timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
- if isinstance(kwargs.get('headers'), list):
+ if isinstance(kwargs.get('headers'), (list, tuple)):
headers_iter = iter(kwargs['headers'])
else:
headers_iter = iter([kwargs.get('headers', {})] * len(code_iter))
+ if isinstance(kwargs.get('expect_headers'), (list, tuple)):
+ expect_headers_iter = iter(kwargs['expect_headers'])
+ else:
+ expect_headers_iter = iter([kwargs.get('expect_headers', {})] *
+ len(code_iter))
x = kwargs.get('missing_container', [False] * len(code_iter))
if not isinstance(x, (tuple, list)):
x = [x] * len(code_iter)
container_ts_iter = iter(x)
code_iter = iter(code_iter)
+ conn_id_and_code_iter = enumerate(code_iter)
static_body = kwargs.get('body', None)
body_iter = kwargs.get('body_iter', None)
if body_iter:
@@ -573,21 +905,22 @@ def fake_http_connect(*code_iter, **kwargs):
def connect(*args, **ckwargs):
if kwargs.get('slow_connect', False):
- sleep(0.1)
+ eventlet.sleep(0.1)
if 'give_content_type' in kwargs:
if len(args) >= 7 and 'Content-Type' in args[6]:
kwargs['give_content_type'](args[6]['Content-Type'])
else:
kwargs['give_content_type']('')
+ i, status = conn_id_and_code_iter.next()
if 'give_connect' in kwargs:
- kwargs['give_connect'](*args, **ckwargs)
- status = code_iter.next()
- if isinstance(status, tuple):
- status, expect_status = status
- else:
- expect_status = status
+ give_conn_fn = kwargs['give_connect']
+ argspec = inspect.getargspec(give_conn_fn)
+ if argspec.keywords or 'connection_id' in argspec.args:
+ ckwargs['connection_id'] = i
+ give_conn_fn(*args, **ckwargs)
etag = etag_iter.next()
headers = headers_iter.next()
+ expect_headers = expect_headers_iter.next()
timestamp = timestamps_iter.next()
if status <= 0:
@@ -597,8 +930,39 @@ def fake_http_connect(*code_iter, **kwargs):
else:
body = body_iter.next()
return FakeConn(status, etag, body=body, timestamp=timestamp,
- expect_status=expect_status, headers=headers)
+ headers=headers, expect_headers=expect_headers,
+ connection_id=i, give_send=kwargs.get('give_send'))
connect.code_iter = code_iter
return connect
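A sketch of driving the factory directly; per the tuple handling above, leading values are consumed by getexpect and the final value becomes the response status:

    connect = fake_http_connect((100, 201), (507, 201))
    conn = connect('10.0.0.1', '6000', 'sda', '0', 'PUT', '/a/c/o')
    assert conn.getexpect().status == 100  # first expect of connection 0
    assert conn.getresponse().status == 201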
+
+
+@contextmanager
+def mocked_http_conn(*args, **kwargs):
+ requests = []
+
+ def capture_requests(ip, port, method, path, headers, qs, ssl):
+ req = {
+ 'ip': ip,
+ 'port': port,
+ 'method': method,
+ 'path': path,
+ 'headers': headers,
+ 'qs': qs,
+ 'ssl': ssl,
+ }
+ requests.append(req)
+ kwargs.setdefault('give_connect', capture_requests)
+ fake_conn = fake_http_connect(*args, **kwargs)
+ fake_conn.requests = requests
+ with mocklib.patch('swift.common.bufferedhttp.http_connect_raw',
+ new=fake_conn):
+ yield fake_conn
+ left_over_status = list(fake_conn.code_iter)
+ if left_over_status:
+ raise AssertionError('left over status %r' % left_over_status)
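A minimal consumption sketch; do_backend_request is a hypothetical stand-in for code under test that talks to a backend through swift.common.bufferedhttp:

    with mocked_http_conn(200) as fake_conn:
        # the code under test must consume every fake status, otherwise
        # the left-over-status check above raises on exit
        do_backend_request()
    assert fake_conn.requests[0]['method'] == 'GET'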
+
+
+def make_timestamp_iter():
+ return iter(Timestamp(t) for t in itertools.count(int(time.time())))
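And a usage sketch for the timestamp helper:

    ts_iter = make_timestamp_iter()
    t0 = ts_iter.next()  # Timestamp(int(now))
    t1 = ts_iter.next()  # exactly one second later
    assert t1 > t0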