diff --git a/pymongo/connection.py b/pymongo/connection.py index bad1879..9b38687 100644 --- a/pymongo/connection.py +++ b/pymongo/connection.py @@ -81,6 +81,31 @@ def _partition_node(node): return host, port +class _SocketHolder(object): + def __init__(self, pool): + self.sock = None + self.reusable = False + self.pid = pool.pid + self.sockets = pool.sockets + self.pool_size = pool.pool_size + + def attach(self, sock): + self.discard() + self.sock = sock + self.reusable = False + + def discard(self, reuse=True): + if self.sock is not None: + sock, self.sock = self.sock, None + if reuse and self.reusable and self.pid == os.getpid() and len(self.sockets) < self.pool_size: + # There's a race condition here, but we deliberately + # ignore it. It means that if the pool_size is 10 we + # might actually keep slightly more than that. + self.sockets.append(sock) + + def __del__(self): + self.discard() + class _Pool(threading.local): """A simple connection pool. @@ -89,17 +114,29 @@ class _Pool(threading.local): """ # Non thread-locals - __slots__ = ["sockets", "socket_factory", "pool_size", "pid"] + __slots__ = ["pid", "pool_size", "network_timeout", "sockets", "forklock"] # thread-local default - sock = None + holder = None - def __init__(self, pool_size, network_timeout): + @property + def sock(self): + if self.holder is not None: + return self.holder.sock + return None + + def __new__(cls, pool_size, network_timeout): + # __new__ is called only once, so we can initialize non thread-locals + self = threading.local.__new__(cls, pool_size, network_timeout) self.pid = os.getpid() self.pool_size = pool_size self.network_timeout = network_timeout - if not hasattr(self, "sockets"): - self.sockets = [] + self.sockets = [] + self.forklock = threading.Lock() + return self + + def __init__(self, pool_size, network_timeout): + pass def connect(self, host, port): """Connect to Mongo and return a new (connected) socket. @@ -129,31 +166,44 @@ class _Pool(threading.local): pid = os.getpid() if pid != self.pid: - self.sock = None - self.sockets = [] - self.pid = pid + self.forklock.acquire() + try: + if pid != self.pid: + # Only once + self.sockets = [] + self.pid = pid + finally: + self.forklock.release() - if self.sock is not None and self.sock[0] == pid: - return (self.sock[1], True) + holder = self.holder + if holder is None or pid != holder.pid: + self.holder = holder = _SocketHolder(self) + + if holder.reusable and holder.sock is not None: + holder.reusable = False + return (holder.sock, True) try: - self.sock = (pid, self.sockets.pop()) - return (self.sock[1], True) + holder.attach(self.sockets.pop()) + return (holder.sock, True) except IndexError: - self.sock = (pid, self.connect(host, port)) - return (self.sock[1], False) + holder.attach(self.connect(host, port)) + return (holder.sock, False) def return_socket(self): - if self.sock is not None and self.sock[0] == os.getpid(): - # There's a race condition here, but we deliberately - # ignore it. It means that if the pool_size is 10 we - # might actually keep slightly more than that. - if len(self.sockets) < self.pool_size: - self.sockets.append(self.sock[1]) - else: - self.sock[1].close() - self.sock = None + if self.holder is not None: + self.holder.discard() + + + def discard_socket(self): + if self.holder is not None: + self.holder.discard(False) + + + def socket_reusable(self, reusable): + if self.holder is not None: + self.holder.reusable = reusable class Connection(common.BaseObject): @@ -623,12 +673,14 @@ class Connection(common.BaseObject): "%s:%d: %s" % (host, port, str(why))) t = time.time() if t - self.__last_checkout > 1: - if _closed(sock): - self.disconnect() + while from_pool and _closed(sock): + self.__pool.discard_socket() sock, from_pool = self.__pool.get_socket(host, port) self.__last_checkout = t if self.__auth_credentials and not from_pool: + self.__pool.socket_reusable(True) self.__authenticate_socket() + self.__pool.socket_reusable(False) return sock def disconnect(self): @@ -743,7 +795,9 @@ class Connection(common.BaseObject): if with_last_error: response = self.__receive_message_on_socket(1, request_id, sock) + self.__pool.socket_reusable(True) return self.__check_response_to_last_error(response) + self.__pool.socket_reusable(True) return None except (ConnectionFailure, socket.error), e: self.disconnect() @@ -801,7 +855,9 @@ class Connection(common.BaseObject): try: if "network_timeout" in kwargs: sock.settimeout(kwargs["network_timeout"]) - return self.__send_and_receive(message, sock) + response = self.__send_and_receive(message, sock) + self.__pool.socket_reusable(True) + return response except (ConnectionFailure, socket.error), e: self.disconnect() raise AutoReconnect(str(e))