Improve identifier for zookeeper based lockfiles.
[pyutils.git] / src / pyutils / files / lockfile.py
index 11bb1001156127eaa5ded750d6fadb74f77884fe..158a636653eefd7b66ee31c698a7dc909dc0e770 100644 (file)
@@ -2,7 +2,22 @@
 
 # © Copyright 2021-2022, Scott Gasch
 
-"""File-based locking helper."""
+"""This is a lockfile implementation I created for use with cronjobs
+on my machine to prevent multiple copies of a job from running in
+parallel.
+
+For local operations, when one job is running this code keeps a file
+on disk to indicate a lock is held.  Other copies will fail to start
+if they detect this lock until the lock is released.  There are
+provisions in the code for timing out locks, cleaning up a lock when a
+signal is received, gracefully retrying lock acquisition on failure,
+etc...
+
+Also allows for Zookeeper-based locks when lockfile path is prefixed
+with 'zk:' in order to synchronize processes across different
+machines.
+
+"""
 
 from __future__ import annotations
 
@@ -11,22 +26,25 @@ import datetime
 import json
 import logging
 import os
+import platform
 import signal
 import sys
 import warnings
 from dataclasses import dataclass
 from typing import Literal, Optional
 
-from pyutils import config, decorator_utils
-from pyutils.datetimez import datetime_utils
+import kazoo
 
-cfg = config.add_commandline_args(f'Lockfile ({__file__})', 'Args related to lockfiles')
+from pyutils import argparse_utils, config, decorator_utils, zookeeper
+from pyutils.datetimes import datetime_utils
+
+cfg = config.add_commandline_args(f"Lockfile ({__file__})", "Args related to lockfiles")
 cfg.add_argument(
-    '--lockfile_held_duration_warning_threshold_sec',
-    type=float,
-    default=60.0,
-    metavar='SECONDS',
-    help='If a lock is held for longer than this threshold we log a warning',
+    "--lockfile_held_duration_warning_threshold",
+    type=argparse_utils.valid_duration,
+    # timedelta's first positional argument is *days*; name the unit so the
+    # default stays at the previous 60 second threshold.
+    default=datetime.timedelta(seconds=60.0),
+    metavar="DURATION",
+    help="If a lock is held for longer than this threshold we log a warning",
 )
 logger = logging.getLogger(__name__)
 
@@ -38,7 +56,7 @@ class LockFileException(Exception):
 
 
 @dataclass
-class LockFileContents:
+class LocalLockFileContents:
     """The contents we'll write to each lock file."""
 
     pid: int
@@ -73,55 +91,94 @@ class LockFile(contextlib.AbstractContextManager):
         """C'tor.
 
         Args:
-            lockfile_path: path of the lockfile to acquire
+            lockfile_path: path of the lockfile to acquire; may begin
+                with zk: to indicate a path in zookeeper rather than
+                on the local filesystem.  Note that zookeeper-based
+                locks require an expiration_timestamp as the stale
+                detection semantics are skipped for non-local locks.
             do_signal_cleanup: handle SIGINT and SIGTERM events by
                 releasing the lock before exiting
             expiration_timestamp: when our lease on the lock should
                 expire (as seconds since the Epoch).  None means the
-                lock will not expire until we explicltly release it.
+                lock will not expire until we explicitly release it.
+                Note that this is required for zookeeper based locks.
             override_command: don't use argv to determine our commandline
                 rather use this instead if provided.
         """
         self.is_locked: bool = False
-        self.lockfile: str = lockfile_path
-        self.locktime: Optional[int] = None
+        self.lockfile: str = ""
+        self.zk_client: Optional[kazoo.client.KazooClient] = None
+        self.zk_lease: Optional[zookeeper.RenewableReleasableLease] = None
+
+        if lockfile_path.startswith("zk:"):
+            logger.debug("Lockfile is on Zookeeper.")
+            if expiration_timestamp is None:
+                raise Exception("Zookeeper locks require an expiration timestamp")
+            self.lockfile = lockfile_path[3:]
+            if not self.lockfile.startswith("/leases"):
+                self.lockfile = "/leases" + self.lockfile
+            self.zk_client = zookeeper.get_started_zk_client()
+        else:
+            logger.debug("Lockfile is local.")
+            self.lockfile = lockfile_path
+        self.locktime: Optional[float] = None
         self.override_command: Optional[str] = override_command
         if do_signal_cleanup:
             signal.signal(signal.SIGINT, self._signal)
             signal.signal(signal.SIGTERM, self._signal)
         self.expiration_timestamp = expiration_timestamp
 
-    def locked(self):
+    def locked(self) -> bool:
         """Is it locked currently?"""
+        # Reports only this object's own is_locked flag; it does not probe
+        # the lockfile on disk or Zookeeper.
         return self.is_locked
 
-    def available(self):
-        """Is it available currently?"""
-        return not os.path.exists(self.lockfile)
-
-    def try_acquire_lock_once(self) -> bool:
-        """Attempt to acquire the lock with no blocking.
-
-        Returns:
-            True if the lock was acquired and False otherwise.
+    def _try_acquire_local_filesystem_lock(self) -> bool:
+        """Attempt to create the lockfile.  These flags cause os.open
+        to raise an OSError if the file already exists.
+
+        Returns:
+            True if the lockfile was created and its contents written,
+            False if an OSError was raised along the way.
         """
-        logger.debug("Trying to acquire %s.", self.lockfile)
         try:
-            # Attempt to create the lockfile.  These flags cause
-            # os.open to raise an OSError if the file already
-            # exists.
+            logger.debug("Trying to acquire local lock %s.", self.lockfile)
+            # O_CREAT | O_EXCL: create the file, failing if it already exists.
             fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
             with os.fdopen(fd, "a") as f:
-                contents = self._get_lockfile_contents()
+                contents = self._construct_local_lockfile_contents()
                 logger.debug(contents)
                 f.write(contents)
-            logger.debug('Success; I own %s.', self.lockfile)
-            self.is_locked = True
             return True
         except OSError:
-            pass
-        logger.warning('Couldn\'t acquire %s.', self.lockfile)
-        return False
+            # Any OSError, not only "file exists", counts as a failed attempt.
+            logger.warning("Couldn't acquire local lock %s.", self.lockfile)
+            return False
+
+    def _try_acquire_zk_lock(self) -> bool:
+        """Attempt to obtain the Zookeeper lease backing this lock.
+
+        Returns:
+            True if the lease was obtained and False otherwise.
+
+        Raises:
+            LockFileException: if no expiration timestamp was given.
+        """
+        if self.expiration_timestamp is None:
+            raise LockFileException("Zookeeper locks require an expiration timestamp")
+
+        # expiration_timestamp is an absolute time (seconds since the
+        # Epoch) whereas the lease takes a duration, so convert it to
+        # the time remaining from now.
+        remaining = self.expiration_timestamp - datetime.datetime.now().timestamp()
+        if remaining <= 0:
+            logger.warning(
+                "Expiration time for %s has already passed; not acquiring.",
+                self.lockfile,
+            )
+            return False
+
+        identifier = f"Lockfile for pid={os.getpid()} on machine {platform.node()}"
+        if self.override_command:
+            identifier += f" running {self.override_command}"
+        self.zk_lease = zookeeper.RenewableReleasableLease(
+            self.zk_client,
+            self.lockfile,
+            datetime.timedelta(seconds=remaining),
+            identifier,
+        )
+        # The caller only tests the lease's truthiness (presumably "was
+        # the lease obtained" -- confirm against RenewableReleasableLease);
+        # hand back a real bool to match the annotation.
+        return bool(self.zk_lease)
+
+    def try_acquire_lock_once(self) -> bool:
+        """Make a single, non-blocking attempt to take the lock.
+
+        Returns:
+            True if we now hold the lock and False if the attempt failed.
+        """
+        if self.zk_client:
+            acquired = True if self._try_acquire_zk_lock() else False
+        else:
+            acquired = self._try_acquire_local_filesystem_lock()
+
+        if not acquired:
+            return acquired
+
+        # Record when we took the lock and that we now own it.
+        self.locktime = datetime.datetime.now().timestamp()
+        logger.debug("Success; I own %s.", self.lockfile)
+        self.is_locked = True
+        return acquired
 
     def acquire_with_retries(
         self,
@@ -157,19 +214,30 @@ class LockFile(contextlib.AbstractContextManager):
             self._detect_stale_lockfile()
         return _try_acquire_lock_with_retries()
 
-    def release(self):
+    def release(self) -> None:
         """Release the lock"""
-        try:
-            os.unlink(self.lockfile)
-        except Exception as e:
-            logger.exception(e)
+
+        if not self.zk_client:
+            # Local lock: removing the lockfile releases it.  Failures are
+            # logged, not raised.
+            try:
+                os.unlink(self.lockfile)
+            except Exception as e:
+                logger.exception(e)
+        else:
+            # Zookeeper lock: release the lease if one was created, then
+            # stop the client.
+            # NOTE(review): zk_client stays set after stop(), so reusing this
+            # object to reacquire presumably needs a restarted client -- confirm.
+            if self.zk_lease:
+                self.zk_lease.release()
+            self.zk_client.stop()
         self.is_locked = False
 
     def __enter__(self):
         if self.acquire_with_retries():
-            self.locktime = datetime.datetime.now().timestamp()
             return self
-        msg = f"Couldn't acquire {self.lockfile}; giving up."
+
+        msg = "Couldn't acquire lockfile; giving up."
+        if not self.zk_client:
+            raw_contents = self._read_lockfile()
+            if raw_contents:
+                contents = LocalLockFileContents(**json.loads(raw_contents))
+                msg = f"Couldn't acquire {self.lockfile} after several attempts.  It's held by pid={contents.pid} ({contents.commandline}).  Giving up."
         logger.warning(msg)
         raise LockFileException(msg)
 
@@ -179,11 +247,13 @@ class LockFile(contextlib.AbstractContextManager):
             duration = ts - self.locktime
             if (
                 duration
-                >= config.config['lockfile_held_duration_warning_threshold_sec']
+                >= config.config[
+                    "lockfile_held_duration_warning_threshold"
+                ].total_seconds()
             ):
                 # Note: describe duration briefly only does 1s granularity...
                 str_duration = datetime_utils.describe_duration_briefly(int(duration))
-                msg = f'Held {self.lockfile} for {str_duration}'
+                msg = f"Held {self.lockfile} for {str_duration}"
                 logger.warning(msg)
                 warnings.warn(msg, stacklevel=2)
         self.release()
@@ -197,47 +267,55 @@ class LockFile(contextlib.AbstractContextManager):
         if self.is_locked:
             self.release()
 
-    def _get_lockfile_contents(self) -> str:
-        if self.override_command:
-            cmd = self.override_command
-        else:
-            cmd = ' '.join(sys.argv)
-        contents = LockFileContents(
-            pid=os.getpid(),
-            commandline=cmd,
-            expiration_timestamp=self.expiration_timestamp,
-        )
-        return json.dumps(contents.__dict__)
+    def _construct_local_lockfile_contents(self) -> str:
+        """Build the JSON text written into a local lockfile.
+
+        Returns:
+            A JSON object holding our pid, the commandline (or the
+            override command, if one was given) and the expiration
+            timestamp.
+
+        Raises:
+            Exception: if this is a Zookeeper-based lock.
+        """
+        if self.zk_client:
+            raise Exception("Non-local lockfiles should not call this?!")
+
+        commandline = self.override_command or " ".join(sys.argv)
+        record = LocalLockFileContents(
+            pid=os.getpid(),
+            commandline=commandline,
+            expiration_timestamp=self.expiration_timestamp,
+        )
+        return json.dumps(vars(record))
+
+    def _read_lockfile(self) -> Optional[str]:
+        """Read the first line of the local lockfile.
+
+        Returns:
+            The first line of the lockfile, or None if this is a
+            Zookeeper lock or the file is missing, empty or unreadable.
+        """
+        if self.zk_client:
+            return None
+        try:
+            with open(self.lockfile, "r") as rf:
+                # An empty file yields "", normalized to None here
+                # instead of indexing into an empty list of lines.
+                return rf.readline() or None
+        except FileNotFoundError:
+            # No lockfile simply means nobody holds the lock; that is
+            # not worth a logged traceback.
+            logger.debug("Lockfile %s does not exist.", self.lockfile)
+        except Exception as e:
+            logger.exception(e)
+        return None
 
     def _detect_stale_lockfile(self) -> None:
-        try:
-            with open(self.lockfile, 'r') as rf:
-                lines = rf.readlines()
-                if len(lines) == 1:
-                    line = lines[0]
-                    line_dict = json.loads(line)
-                    contents = LockFileContents(**line_dict)
-                    logger.debug('Blocking lock contents="%s"', contents)
-
-                    # Does the PID exist still?
-                    try:
-                        os.kill(contents.pid, 0)
-                    except OSError:
-                        logger.warning(
-                            'Lockfile %s\'s pid (%d) is stale; force acquiring...',
-                            self.lockfile,
-                            contents.pid,
-                        )
-                        self.release()
-
-                    # Has the lock expiration expired?
-                    if contents.expiration_timestamp is not None:
-                        now = datetime.datetime.now().timestamp()
-                        if now > contents.expiration_timestamp:
-                            logger.warning(
-                                'Lockfile %s\'s expiration time has passed; force acquiring',
-                                self.lockfile,
-                            )
-                            self.release()
-        except Exception:
-            pass  # If the lockfile doesn't exist or disappears, good.
+        if not self.zk_client:
+            raw_contents = self._read_lockfile()
+            if not raw_contents:
+                return
+            contents = LocalLockFileContents(**json.loads(raw_contents))
+            logger.debug('Blocking lock contents="%s"', contents)
+
+            # Does the PID exist still?
+            try:
+                os.kill(contents.pid, 0)
+            except OSError:
+                logger.warning(
+                    "Lockfile %s's pid (%d) is stale; force acquiring...",
+                    self.lockfile,
+                    contents.pid,
+                )
+                self.release()
+
+            # Has the lock expiration expired?
+            if contents.expiration_timestamp is not None:
+                now = datetime.datetime.now().timestamp()
+                if now > contents.expiration_timestamp:
+                    logger.warning(
+                        "Lockfile %s's expiration time has passed; force acquiring",
+                        self.lockfile,
+                    )
+                    self.release()