Adds support for zookeeper-based lockfiles.
[pyutils.git] / src / pyutils / files / lockfile.py
index e148c3dcbc5518c1a3db520233516ae37939d121..26adfb3c31ff6700ca9c826119fd8e3977633622 100644 (file)
@@ -4,11 +4,19 @@
 
 """This is a lockfile implementation I created for use with cronjobs
 on my machine to prevent multiple copies of a job from running in
-parallel.  When one job is running this code keeps a file on disk to
-indicate a lock is held.  Other copies will fail to start if they
-detect this lock until the lock is released.  There are provisions in
-the code for timing out locks, cleaning up a lock when a signal is
-received, gracefully retrying lock acquisition on failure, etc...
+parallel.
+
+For local operations, when one job is running this code keeps a file
+on disk to indicate a lock is held.  Other copies will fail to start
+if they detect this lock until the lock is released.  There are
+provisions in the code for timing out locks, cleaning up a lock when a
+signal is received, gracefully retrying lock acquisition on failure,
+etc...
+
+Also allows for Zookeeper-based locks when lockfile path is prefixed
+with 'zk:' in order to synchronize processes across different
+machines.
+
 """
 
 from __future__ import annotations
@@ -24,7 +32,9 @@ import warnings
 from dataclasses import dataclass
 from typing import Literal, Optional
 
-from pyutils import argparse_utils, config, decorator_utils
+import kazoo
+
+from pyutils import argparse_utils, config, decorator_utils, zookeeper
 from pyutils.datetimes import datetime_utils
 
 cfg = config.add_commandline_args(f"Lockfile ({__file__})", "Args related to lockfiles")
@@ -45,7 +55,7 @@ class LockFileException(Exception):
 
 
 @dataclass
-class LockFileContents:
+class LocalLockFileContents:
     """The contents we'll write to each lock file."""
 
     pid: int
@@ -80,17 +90,36 @@ class LockFile(contextlib.AbstractContextManager):
         """C'tor.
 
         Args:
-            lockfile_path: path of the lockfile to acquire
+            lockfile_path: path of the lockfile to acquire; may begin
+                with zk: to indicate a path in zookeeper rather than
+                on the local filesystem.  Note that zookeeper-based
+                locks require an expiration_timestamp as the stale
+                detection semantics are skipped for non-local locks.
             do_signal_cleanup: handle SIGINT and SIGTERM events by
                 releasing the lock before exiting
             expiration_timestamp: when our lease on the lock should
                 expire (as seconds since the Epoch).  None means the
                 lock will not expire until we explicltly release it.
+                Note that this is required for zookeeper based locks.
             override_command: don't use argv to determine our commandline
                 rather use this instead if provided.
         """
         self.is_locked: bool = False
-        self.lockfile: str = lockfile_path
+        self.lockfile: str = ""
+        self.zk_client: Optional[kazoo.client.KazooClient] = None
+        self.zk_lease: Optional[zookeeper.RenewableReleasableLease] = None
+
+        if lockfile_path.startswith("zk:"):
+            logger.debug("Lockfile is on Zookeeper.")
+            if expiration_timestamp is None:
+                raise Exception("Zookeeper locks require an expiration timestamp")
+            self.lockfile = lockfile_path[3:]
+            if not self.lockfile.startswith("/leases"):
+                self.lockfile = "/leases" + self.lockfile
+            self.zk_client = zookeeper.get_started_zk_client()
+        else:
+            logger.debug("Lockfile is local.")
+            self.lockfile = lockfile_path
         self.locktime: Optional[float] = None
         self.override_command: Optional[str] = override_command
         if do_signal_cleanup:
@@ -102,9 +131,31 @@ class LockFile(contextlib.AbstractContextManager):
         """Is it locked currently?"""
         return self.is_locked
 
-    def available(self) -> bool:
-        """Is it available currently?"""
-        return not os.path.exists(self.lockfile)
+    def _try_acquire_local_filesystem_lock(self) -> bool:
+        """Attempt to create the lockfile.  These flags cause os.open
+        to raise an OSError if the file already exists.
+        """
+        try:
+            logger.debug("Trying to acquire local lock %s.", self.lockfile)
+            fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+            with os.fdopen(fd, "a") as f:
+                contents = self._construct_local_lockfile_contents()
+                logger.debug(contents)
+                f.write(contents)
+            return True
+        except OSError:
+            logger.warning("Couldn't acquire local lock %s.", self.lockfile)
+            return False
+
+    def _try_acquire_zk_lock(self) -> bool:
+        assert self.expiration_timestamp
+        self.zk_lease = zookeeper.RenewableReleasableLease(
+            self.zk_client,
+            self.lockfile,
+            datetime.timedelta(seconds=self.expiration_timestamp),
+            f"Pyutils lockfile pid={os.getpid()}",
+        )
+        return self.zk_lease
 
     def try_acquire_lock_once(self) -> bool:
         """Attempt to acquire the lock with no blocking.
@@ -112,24 +163,18 @@ class LockFile(contextlib.AbstractContextManager):
         Returns:
             True if the lock was acquired and False otherwise.
         """
-        logger.debug("Trying to acquire %s.", self.lockfile)
-        try:
-            # Attempt to create the lockfile.  These flags cause
-            # os.open to raise an OSError if the file already
-            # exists.
-            fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
-            with os.fdopen(fd, "a") as f:
-                contents = self._get_lockfile_contents()
-                logger.debug(contents)
-                f.write(contents)
+        success = False
+        if self.zk_client:
+            if self._try_acquire_zk_lock():
+                success = True
+        else:
+            success = self._try_acquire_local_filesystem_lock()
+
+        if success:
             self.locktime = datetime.datetime.now().timestamp()
             logger.debug("Success; I own %s.", self.lockfile)
             self.is_locked = True
-            return True
-        except OSError:
-            pass
-        logger.warning("Couldn't acquire %s.", self.lockfile)
-        return False
+        return success
 
     def acquire_with_retries(
         self,
@@ -167,21 +212,28 @@ class LockFile(contextlib.AbstractContextManager):
 
     def release(self) -> None:
         """Release the lock"""
-        try:
-            os.unlink(self.lockfile)
-        except Exception as e:
-            logger.exception(e)
+
+        if not self.zk_client:
+            try:
+                os.unlink(self.lockfile)
+            except Exception as e:
+                logger.exception(e)
+        else:
+            if self.zk_lease:
+                self.zk_lease.release()
+            self.zk_client.stop()
         self.is_locked = False
 
     def __enter__(self):
         if self.acquire_with_retries():
             return self
-        raw_contents = self._read_lockfile()
-        if raw_contents:
-            contents = LockFileContents(**json.loads(raw_contents))
-            msg = f"Couldn't acquire {self.lockfile} after several attempts.  It's held by pid={contents.pid} ({contents.commandline}).  Giving up."
-        else:
-            msg = "Couldn't acquire lockfile; giving up."
+
+        msg = "Couldn't acquire lockfile; giving up."
+        if not self.zk_client:
+            raw_contents = self._read_lockfile()
+            if raw_contents:
+                contents = LocalLockFileContents(**json.loads(raw_contents))
+                msg = f"Couldn't acquire {self.lockfile} after several attempts.  It's held by pid={contents.pid} ({contents.commandline}).  Giving up."
         logger.warning(msg)
         raise LockFileException(msg)
 
@@ -211,52 +263,55 @@ class LockFile(contextlib.AbstractContextManager):
         if self.is_locked:
             self.release()
 
-    def _get_lockfile_contents(self) -> str:
-        if self.override_command:
-            cmd = self.override_command
-        else:
-            cmd = " ".join(sys.argv)
-        contents = LockFileContents(
-            pid=os.getpid(),
-            commandline=cmd,
-            expiration_timestamp=self.expiration_timestamp,
-        )
-        return json.dumps(contents.__dict__)
+    def _construct_local_lockfile_contents(self) -> str:
+        if not self.zk_client:
+            if self.override_command:
+                cmd = self.override_command
+            else:
+                cmd = " ".join(sys.argv)
+            contents = LocalLockFileContents(
+                pid=os.getpid(),
+                commandline=cmd,
+                expiration_timestamp=self.expiration_timestamp,
+            )
+            return json.dumps(contents.__dict__)
+        raise Exception("Non-local lockfiles should not call this?!")
 
     def _read_lockfile(self) -> Optional[str]:
-        try:
-            with open(self.lockfile, "r") as rf:
-                lines = rf.readlines()
-                return lines[0]
-        except Exception as e:
-            logger.exception(e)
+        if not self.zk_client:
+            try:
+                with open(self.lockfile, "r") as rf:
+                    lines = rf.readlines()
+                    return lines[0]
+            except Exception as e:
+                logger.exception(e)
         return None
 
     def _detect_stale_lockfile(self) -> None:
-        raw_contents = self._read_lockfile()
-        if not raw_contents:
-            return
-
-        contents = LockFileContents(**json.loads(raw_contents))
-        logger.debug('Blocking lock contents="%s"', contents)
-
-        # Does the PID exist still?
-        try:
-            os.kill(contents.pid, 0)
-        except OSError:
-            logger.warning(
-                "Lockfile %s's pid (%d) is stale; force acquiring...",
-                self.lockfile,
-                contents.pid,
-            )
-            self.release()
-
-        # Has the lock expiration expired?
-        if contents.expiration_timestamp is not None:
-            now = datetime.datetime.now().timestamp()
-            if now > contents.expiration_timestamp:
+        if not self.zk_client:
+            raw_contents = self._read_lockfile()
+            if not raw_contents:
+                return
+            contents = LocalLockFileContents(**json.loads(raw_contents))
+            logger.debug('Blocking lock contents="%s"', contents)
+
+            # Does the PID exist still?
+            try:
+                os.kill(contents.pid, 0)
+            except OSError:
                 logger.warning(
-                    "Lockfile %s's expiration time has passed; force acquiring",
+                    "Lockfile %s's pid (%d) is stale; force acquiring...",
                     self.lockfile,
+                    contents.pid,
                 )
                 self.release()
+
+            # Has the lock expiration expired?
+            if contents.expiration_timestamp is not None:
+                now = datetime.datetime.now().timestamp()
+                if now > contents.expiration_timestamp:
+                    logger.warning(
+                        "Lockfile %s's expiration time has passed; force acquiring",
+                        self.lockfile,
+                    )
+                    self.release()