Cleanup more contextlib.AbstractContextManagers and Literal[False]s.
authorScott Gasch <[email protected]>
Wed, 9 Feb 2022 04:21:02 +0000 (20:21 -0800)
committerScott Gasch <[email protected]>
Wed, 9 Feb 2022 04:21:02 +0000 (20:21 -0800)
ansi.py
file_utils.py
lockfile.py
stopwatch.py
string_utils.py
text_utils.py

diff --git a/ansi.py b/ansi.py
index 03f8fd27473c4e295cdafbb9e7122b4cec296453..a49760037e31c6c01e84a225e7147d588f8f5934 100755 (executable)
--- a/ansi.py
+++ b/ansi.py
@@ -4,20 +4,20 @@
 setting the text color, background, etc... using ANSI escape
 sequences."""
 
+import contextlib
 import difflib
 import io
 import logging
 import re
 import sys
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple
 
 from overrides import overrides
 
 import logging_utils
 import string_utils
 
-
 logger = logging.getLogger(__name__)
 
 # https://en.wikipedia.org/wiki/ANSI_escape_code
@@ -1885,10 +1885,9 @@ def bg(
     return bg_24bit(red, green, blue)
 
 
-class StdoutInterceptor(io.TextIOBase):
-    """An interceptor for data written to stdout.  Use as a context.
+class StdoutInterceptor(io.TextIOBase, contextlib.AbstractContextManager):
+    """An interceptor for data written to stdout.  Use as a context."""
 
-    """
     def __init__(self):
         super().__init__()
         self.saved_stdout: io.TextIO = None
@@ -1903,10 +1902,10 @@ class StdoutInterceptor(io.TextIOBase):
         sys.stdout = self
         return self
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         sys.stdout = self.saved_stdout
         print(self.buf)
-        return None
+        return False
 
 
 class ProgrammableColorizer(StdoutInterceptor):
@@ -1914,6 +1913,7 @@ class ProgrammableColorizer(StdoutInterceptor):
     something (usually add color to) the match.
 
     """
+
     def __init__(
         self,
         patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]],
@@ -1929,6 +1929,7 @@ class ProgrammableColorizer(StdoutInterceptor):
 
 
 if __name__ == '__main__':
+
     def main() -> None:
         import doctest
 
index d1e2eff1fac729b6251179ad7dfc584a0674215f..98e8c2670fd173fb11be51327fb6925a2b47f1eb 100644 (file)
@@ -2,6 +2,7 @@
 
 """Utilities for working with files."""
 
+import contextlib
 import datetime
 import errno
 import glob
@@ -12,21 +13,21 @@ import pathlib
 import re
 import time
 from os.path import exists, isfile, join
-from typing import List, Optional, TextIO
+from typing import Callable, List, Literal, Optional, TextIO
 from uuid import uuid4
 
 logger = logging.getLogger(__name__)
 
 
-def remove_newlines(x):
+def remove_newlines(x: str) -> str:
     return x.replace('\n', '')
 
 
-def strip_whitespace(x):
+def strip_whitespace(x: str) -> str:
     return x.strip()
 
 
-def remove_hash_comments(x):
+def remove_hash_comments(x: str) -> str:
     return re.sub(r'#.*$', '', x)
 
 
@@ -34,15 +35,16 @@ def slurp_file(
     filename: str,
     *,
     skip_blank_lines=False,
-    line_transformers=[],
+    line_transformers: Optional[List[Callable[[str], str]]] = None,
 ):
     ret = []
     if not file_is_readable(filename):
         raise Exception(f'{filename} can\'t be read.')
     with open(filename) as rf:
         for line in rf:
-            for transformation in line_transformers:
-                line = transformation(line)
+            if line_transformers is not None:
+                for transformation in line_transformers:
+                    line = transformation(line)
             if skip_blank_lines and line == '':
                 continue
             ret.append(line)
@@ -460,7 +462,7 @@ def get_files_recursive(directory: str):
             yield file_or_directory
 
 
-class FileWriter(object):
+class FileWriter(contextlib.AbstractContextManager):
     """A helper that writes a file to a temporary location and then moves
     it atomically to its ultimate destination on close.
 
@@ -477,14 +479,14 @@ class FileWriter(object):
         self.handle = open(self.tempfile, mode="w")
         return self.handle
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]:
+    def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
         if self.handle is not None:
             self.handle.close()
             cmd = f'/bin/mv -f {self.tempfile} {self.filename}'
             ret = os.system(cmd)
             if (ret >> 8) != 0:
-                raise Exception(f'{cmd} failed, exit value {ret>>8}')
-        return None
+                raise Exception(f'{cmd} failed, exit value {ret>>8}!')
+        return False
 
 
 if __name__ == '__main__':
index 6993cb84d5e88f8dd6fc1a9d28a849f0cfd28713..10fe10dde1215b87c664e869b477e970de54f9f8 100644 (file)
@@ -2,6 +2,9 @@
 
 """File-based locking helper."""
 
+from __future__ import annotations
+
+import contextlib
 import datetime
 import json
 import logging
@@ -10,7 +13,7 @@ import signal
 import sys
 import warnings
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional
 
 import config
 import datetime_utils
@@ -42,7 +45,7 @@ class LockFileContents:
     expiration_timestamp: Optional[float]
 
 
-class LockFile(object):
+class LockFile(contextlib.AbstractContextManager):
     """A file locking mechanism that has context-manager support so you
     can use it in a with statement.  e.g.
 
@@ -131,16 +134,18 @@ class LockFile(object):
         logger.warning(msg)
         raise LockFileException(msg)
 
-    def __exit__(self, _, value, traceback):
+    def __exit__(self, _, value, traceback) -> Literal[False]:
         if self.locktime:
             ts = datetime.datetime.now().timestamp()
             duration = ts - self.locktime
             if duration >= config.config['lockfile_held_duration_warning_threshold_sec']:
-                str_duration = datetime_utils.describe_duration_briefly(duration)
+                # Note: describe duration briefly only does 1s granularity...
+                str_duration = datetime_utils.describe_duration_briefly(int(duration))
                 msg = f'Held {self.lockfile} for {str_duration}'
                 logger.warning(msg)
                 warnings.warn(msg, stacklevel=2)
         self.release()
+        return False
 
     def __del__(self):
         if self.is_locked:
@@ -176,16 +181,21 @@ class LockFile(object):
                     try:
                         os.kill(contents.pid, 0)
                     except OSError:
-                        msg = f'Lockfile {self.lockfile}\'s pid ({contents.pid}) is stale; force acquiring'
-                        logger.warning(msg)
+                        logger.warning(
+                            'Lockfile %s\'s pid (%d) is stale; force acquiring...',
+                            self.lockfile,
+                            contents.pid,
+                        )
                         self.release()
 
                     # Has the lock expiration expired?
                     if contents.expiration_timestamp is not None:
                         now = datetime.datetime.now().timestamp()
                         if now > contents.expiration_timestamp:
-                            msg = f'Lockfile {self.lockfile} expiration time has passed; force acquiring'
-                            logger.warning(msg)
+                            logger.warning(
+                                'Lockfile %s\'s expiration time has passed; force acquiring',
+                                self.lockfile,
+                            )
                             self.release()
         except Exception:
-            pass
+            pass  # If the lockfile doesn't exist or disappears, good.
index fa4f2b52f54998bb029f7e14433e5451abd67812..ae8e01aa18452ed932ef589b9047c7f0c68ab58a 100644 (file)
@@ -2,11 +2,12 @@
 
 """A simple stopwatch decorator / context for timing things."""
 
+import contextlib
 import time
-from typing import Callable, Optional
+from typing import Callable, Literal
 
 
-class Timer(object):
+class Timer(contextlib.AbstractContextManager):
     """
     A stopwatch to time how long something takes (walltime).
 
@@ -31,6 +32,6 @@ class Timer(object):
         self.end = 0.0
         return lambda: self.end - self.start
 
-    def __exit__(self, *args) -> Optional[bool]:
+    def __exit__(self, *args) -> Literal[False]:
         self.end = time.perf_counter()
-        return None  # don't suppress exceptions
+        return False
index d75c6ba1aca2c559ed4254d535747c54f4719bf5..adfb149204b525327fc854fc56d6ab5645e321bc 100644 (file)
@@ -40,7 +40,7 @@ import string
 import unicodedata
 import warnings
 from itertools import zip_longest
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple
 from uuid import uuid4
 
 import list_utils
@@ -1208,7 +1208,7 @@ def sprintf(*args, **kwargs) -> str:
     return ret
 
 
-class SprintfStdout(object):
+class SprintfStdout(contextlib.AbstractContextManager):
     """
     A context manager that captures outputs to stdout.
 
@@ -1228,10 +1228,10 @@ class SprintfStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination.getvalue()
 
-    def __exit__(self, *args) -> None:
+    def __exit__(self, *args) -> Literal[False]:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None  # don't suppress exceptions
+        return False
 
 
 def capitalize_first_letter(txt: str) -> str:
index 4384a1e6134810982e9227d2bb1dfdb517627f72..7910990f2a58ec318ba6a68a7ba9e87ff415c77b 100644 (file)
@@ -3,11 +3,12 @@
 
 """Utilities for dealing with "text"."""
 
+import contextlib
 import logging
 import math
 import sys
 from collections import defaultdict
-from typing import Dict, Generator, List, NamedTuple, Optional, Tuple
+from typing import Dict, Generator, List, Literal, NamedTuple, Optional, Tuple
 
 from ansi import fg, reset
 
@@ -261,7 +262,7 @@ def wrap_string(text: str, n: int) -> str:
     return out
 
 
-class Indenter(object):
+class Indenter(contextlib.AbstractContextManager):
     """
     with Indenter(pad_count = 8) as i:
         i.print('test')
@@ -289,10 +290,11 @@ class Indenter(object):
         self.level += 1
         return self
 
-    def __exit__(self, exc_type, exc_value, exc_tb):
+    def __exit__(self, exc_type, exc_value, exc_tb) -> Literal[False]:
         self.level -= 1
         if self.level < -1:
             self.level = -1
+        return False
 
     def print(self, *arg, **kwargs):
         import string_utils