Add a listener for state changes to zookeeper elections to, hopefully,
[python_utils.git] / zookeeper.py
index f96d414ad03530ec382becaf88c7884ae206704f..9decfef82a0961fbda0c5f1f0689b639bebe276d 100644 (file)
@@ -13,7 +13,6 @@ import os
 import platform
 import sys
 import threading
-import time
 from typing import Callable, Optional
 
 from kazoo.client import KazooClient
@@ -26,8 +25,8 @@ logger = logging.getLogger(__name__)
 
 
 # On module load, grab what we presume to be our process' program name.
-# This is used, by default, as part of internal zookeeper paths (e.g.
-# to name a lease or election).
+# This is used, by default, to construct internal zookeeper paths (e.g.
+# to identify a lease or election).
 PROGRAM_NAME: str = os.path.basename(sys.argv[0])
 
 
@@ -36,16 +35,24 @@ def obtain_lease(
     *,
     lease_id: str = PROGRAM_NAME,
     contender_id: str = platform.node(),
-    initial_duration: datetime.timedelta = datetime.timedelta(minutes=5),
+    duration: datetime.timedelta = datetime.timedelta(minutes=5),
     also_pass_lease: bool = False,
     also_pass_zk_client: bool = False,
 ):
-    """Obtain the named lease before invoking a function and skip
-    invoking the function if the lease cannot be obtained.
+    """Obtain an exclusive lease identified by the lease_id name
+    before invoking a function or skip invoking the function if the
+    lease cannot be obtained.
+
+    Args:
+        lease_id: string identifying the lease to obtain
+        contender_id: string identifying who's attempting to obtain it
+        duration: how long should the lease be held, if obtained?
+        also_pass_lease: pass the lease into the user function
+        also_pass_zk_client: pass our zk client into the user function
 
     >>> @obtain_lease(
     ...         lease_id='zookeeper_doctest',
-    ...         initial_duration=datetime.timedelta(seconds=10),
+    ...         duration=datetime.timedelta(seconds=5),
     ... )
     ... def f(name: str) -> int:
     ...     print(f'Hello, {name}')
@@ -59,20 +66,21 @@ def obtain_lease(
     if not lease_id.startswith('/leases/'):
         lease_id = f'/leases/{lease_id}'
         lease_id = file_utils.fix_multiple_slashes(lease_id)
-    zk = KazooClient(
-        hosts=scott_secrets.ZOOKEEPER_NODES,
-        use_ssl=True,
-        verify_certs=False,
-        keyfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
-        keyfile_password=scott_secrets.ZOOKEEPER_CLIENT_PASS,
-        certfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
-    )
-    zk.start()
-    logger.debug('We have an active zookeeper connection.')
 
     def wrapper(func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper2(*args, **kwargs):
+            zk = KazooClient(
+                hosts=scott_secrets.ZOOKEEPER_NODES,
+                use_ssl=True,
+                verify_certs=False,
+                keyfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
+                keyfile_password=scott_secrets.ZOOKEEPER_CLIENT_PASS,
+                certfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
+            )
+            zk.start()
+            logger.debug('We have an active zookeeper connection.')
+
             logger.debug(
                 'Trying to obtain %s for contender %s now...',
                 lease_id,
@@ -80,7 +88,7 @@ def obtain_lease(
             )
             lease = zk.NonBlockingLease(
                 lease_id,
-                initial_duration,
+                duration,
                 contender_id,
             )
             if lease:
@@ -133,7 +141,12 @@ def run_for_election(
     the wrapper to also return and effectively cede leadership.
 
     Because the user's code is run in a separate thread, it may
-    not return anything.
+    not return anything; whatever it returns will be dropped.
+
+    Args:
+        election_id: global string identifier for the election
+        contender_id: string identifying who is running for leader
+        also_pass_zk_client: pass the zk client into the user code
 
     >>> @run_for_election(
     ...         election_id='zookeeper_doctest',
@@ -163,55 +176,72 @@ def run_for_election(
     if not election_id.startswith('/elections/'):
         election_id = f'/elections/{election_id}'
         election_id = file_utils.fix_multiple_slashes(election_id)
-    zk = KazooClient(
-        hosts=scott_secrets.ZOOKEEPER_NODES,
-        use_ssl=True,
-        verify_certs=False,
-        keyfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
-        keyfile_password=scott_secrets.ZOOKEEPER_CLIENT_PASS,
-        certfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
-    )
-    zk.start()
-    logger.debug('We have an active zookeeper connection.')
 
-    def wrapper(func: Callable) -> Callable:
-        @functools.wraps(func)
-        def runit(func, *args, **kwargs):
-            stop_event = threading.Event()
-            stop_event.clear()
+    class wrapper:
+        """Runs the wrapped function on its own thread once elected."""
+
+        def __init__(self, func: Callable) -> None:
+            functools.update_wrapper(self, func)
+            self.func = func
+            self.zk = KazooClient(
+                hosts=scott_secrets.ZOOKEEPER_NODES,
+                use_ssl=True,
+                verify_certs=False,
+                keyfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
+                keyfile_password=scott_secrets.ZOOKEEPER_CLIENT_PASS,
+                certfile=scott_secrets.ZOOKEEPER_CLIENT_CERT,
+            )
+            self.zk.start()
+            logger.debug('We have an active zookeeper connection.')
+            self.stop_event = threading.Event()
+            self.stop_event.clear()
+
+        def zk_listener(self, state: KazooState) -> None:
+            logger.debug('Listener received state %s.', state)
+            if state != KazooState.CONNECTED:
+                logger.debug(
+                    'Bad connection to zookeeper (state=%s); bailing out.',
+                    state,
+                )
+                self.stop_event.set()
+
+        def runit(self, *args, **kwargs) -> None:
+            # Possibly augment args if requested; always pass stop_event
             if also_pass_zk_client:
-                args = (*args, zk)
-            args = (*args, stop_event)
+                args = (*args, self.zk)
+            args = (*args, self.stop_event)
+
             logger.debug('Invoking user code on separate thread.')
             thread = threading.Thread(
-                target=func,
+                target=self.func,
                 args=args,
                 kwargs=kwargs,
             )
             thread.start()
 
+            # Watch the state (fail safe for listener) and the thread.
             while True:
-                state = zk.client_state
+                state = self.zk.client_state
                 if state != KazooState.CONNECTED:
                     logger.error(
                         'Bad connection to zookeeper (state=%s); bailing out.',
                         state,
                     )
-                    stop_event.set()
+                    self.stop_event.set()
+                    logger.debug('Waiting for user thread to tear down...')
                     thread.join()
+                    return
 
+                thread.join(timeout=5.0)
                 if not thread.is_alive():
                     logger.info('Child thread exited, I\'m exiting too.')
                     return
-                time.sleep(5.0)
 
-        @functools.wraps(runit)
-        def wrapper2(*args, **kwargs):
-            election = zk.Election(election_id, contender_id)
-            election.run(runit, func, *args, **kwargs)
-            zk.stop()
-
-        return wrapper2
+        def __call__(self, *args, **kwargs):
+            election = self.zk.Election(election_id, contender_id)
+            self.zk.add_listener(self.zk_listener)
+            election.run(self.runit, *args, **kwargs)
+            self.zk.stop()
 
     if f is None:
         return wrapper