From 2f5b47c8b30d1b7d86443391332be2f3805cdafd Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Tue, 8 Feb 2022 20:21:02 -0800 Subject: [PATCH] Cleanup more contextlib.AbstractContextManagers and Literal[False]s. --- ansi.py | 15 ++++++++------- file_utils.py | 24 +++++++++++++----------- lockfile.py | 28 +++++++++++++++++++--------- stopwatch.py | 9 +++++---- string_utils.py | 8 ++++---- text_utils.py | 8 +++++--- 6 files changed, 54 insertions(+), 38 deletions(-) diff --git a/ansi.py b/ansi.py index 03f8fd2..a497600 100755 --- a/ansi.py +++ b/ansi.py @@ -4,20 +4,20 @@ setting the text color, background, etc... using ANSI escape sequences.""" +import contextlib import difflib import io import logging import re import sys from abc import abstractmethod -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple from overrides import overrides import logging_utils import string_utils - logger = logging.getLogger(__name__) # https://en.wikipedia.org/wiki/ANSI_escape_code @@ -1885,10 +1885,9 @@ def bg( return bg_24bit(red, green, blue) -class StdoutInterceptor(io.TextIOBase): - """An interceptor for data written to stdout. Use as a context. +class StdoutInterceptor(io.TextIOBase, contextlib.AbstractContextManager): + """An interceptor for data written to stdout. Use as a context.""" - """ def __init__(self): super().__init__() self.saved_stdout: io.TextIO = None @@ -1903,10 +1902,10 @@ class StdoutInterceptor(io.TextIOBase): sys.stdout = self return self - def __exit__(self, *args) -> Optional[bool]: + def __exit__(self, *args) -> Literal[False]: sys.stdout = self.saved_stdout print(self.buf) - return None + return False class ProgrammableColorizer(StdoutInterceptor): @@ -1914,6 +1913,7 @@ class ProgrammableColorizer(StdoutInterceptor): something (usually add color to) the match. """ + def __init__( self, patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]], @@ -1929,6 +1929,7 @@ class ProgrammableColorizer(StdoutInterceptor): if __name__ == '__main__': + def main() -> None: import doctest diff --git a/file_utils.py b/file_utils.py index d1e2eff..98e8c26 100644 --- a/file_utils.py +++ b/file_utils.py @@ -2,6 +2,7 @@ """Utilities for working with files.""" +import contextlib import datetime import errno import glob @@ -12,21 +13,21 @@ import pathlib import re import time from os.path import exists, isfile, join -from typing import List, Optional, TextIO +from typing import Callable, List, Literal, Optional, TextIO from uuid import uuid4 logger = logging.getLogger(__name__) -def remove_newlines(x): +def remove_newlines(x: str) -> str: return x.replace('\n', '') -def strip_whitespace(x): +def strip_whitespace(x: str) -> str: return x.strip() -def remove_hash_comments(x): +def remove_hash_comments(x: str) -> str: return re.sub(r'#.*$', '', x) @@ -34,15 +35,16 @@ def slurp_file( filename: str, *, skip_blank_lines=False, - line_transformers=[], + line_transformers: Optional[List[Callable[[str], str]]] = None, ): ret = [] if not file_is_readable(filename): raise Exception(f'{filename} can\'t be read.') with open(filename) as rf: for line in rf: - for transformation in line_transformers: - line = transformation(line) + if line_transformers is not None: + for transformation in line_transformers: + line = transformation(line) if skip_blank_lines and line == '': continue ret.append(line) @@ -460,7 +462,7 @@ def get_files_recursive(directory: str): yield file_or_directory -class FileWriter(object): +class FileWriter(contextlib.AbstractContextManager): """A helper that writes a file to a temporary location and then moves it atomically to its ultimate destination on close. @@ -477,14 +479,14 @@ class FileWriter(object): self.handle = open(self.tempfile, mode="w") return self.handle - def __exit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: if self.handle is not None: self.handle.close() cmd = f'/bin/mv -f {self.tempfile} {self.filename}' ret = os.system(cmd) if (ret >> 8) != 0: - raise Exception(f'{cmd} failed, exit value {ret>>8}') - return None + raise Exception(f'{cmd} failed, exit value {ret>>8}!') + return False if __name__ == '__main__': diff --git a/lockfile.py b/lockfile.py index 6993cb8..10fe10d 100644 --- a/lockfile.py +++ b/lockfile.py @@ -2,6 +2,9 @@ """File-based locking helper.""" +from __future__ import annotations + +import contextlib import datetime import json import logging @@ -10,7 +13,7 @@ import signal import sys import warnings from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional import config import datetime_utils @@ -42,7 +45,7 @@ class LockFileContents: expiration_timestamp: Optional[float] -class LockFile(object): +class LockFile(contextlib.AbstractContextManager): """A file locking mechanism that has context-manager support so you can use it in a with statement. e.g. @@ -131,16 +134,18 @@ class LockFile(object): logger.warning(msg) raise LockFileException(msg) - def __exit__(self, _, value, traceback): + def __exit__(self, _, value, traceback) -> Literal[False]: if self.locktime: ts = datetime.datetime.now().timestamp() duration = ts - self.locktime if duration >= config.config['lockfile_held_duration_warning_threshold_sec']: - str_duration = datetime_utils.describe_duration_briefly(duration) + # Note: describe duration briefly only does 1s granularity... + str_duration = datetime_utils.describe_duration_briefly(int(duration)) msg = f'Held {self.lockfile} for {str_duration}' logger.warning(msg) warnings.warn(msg, stacklevel=2) self.release() + return False def __del__(self): if self.is_locked: @@ -176,16 +181,21 @@ class LockFile(object): try: os.kill(contents.pid, 0) except OSError: - msg = f'Lockfile {self.lockfile}\'s pid ({contents.pid}) is stale; force acquiring' - logger.warning(msg) + logger.warning( + 'Lockfile %s\'s pid (%d) is stale; force acquiring...', + self.lockfile, + contents.pid, + ) self.release() # Has the lock expiration expired? if contents.expiration_timestamp is not None: now = datetime.datetime.now().timestamp() if now > contents.expiration_timestamp: - msg = f'Lockfile {self.lockfile} expiration time has passed; force acquiring' - logger.warning(msg) + logger.warning( + 'Lockfile %s\'s expiration time has passed; force acquiring', + self.lockfile, + ) self.release() except Exception: - pass + pass # If the lockfile doesn't exist or disappears, good. diff --git a/stopwatch.py b/stopwatch.py index fa4f2b5..ae8e01a 100644 --- a/stopwatch.py +++ b/stopwatch.py @@ -2,11 +2,12 @@ """A simple stopwatch decorator / context for timing things.""" +import contextlib import time -from typing import Callable, Optional +from typing import Callable, Literal -class Timer(object): +class Timer(contextlib.AbstractContextManager): """ A stopwatch to time how long something takes (walltime). @@ -31,6 +32,6 @@ class Timer(object): self.end = 0.0 return lambda: self.end - self.start - def __exit__(self, *args) -> Optional[bool]: + def __exit__(self, *args) -> Literal[False]: self.end = time.perf_counter() - return None # don't suppress exceptions + return False diff --git a/string_utils.py b/string_utils.py index d75c6ba..adfb149 100644 --- a/string_utils.py +++ b/string_utils.py @@ -40,7 +40,7 @@ import string import unicodedata import warnings from itertools import zip_longest -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple from uuid import uuid4 import list_utils @@ -1208,7 +1208,7 @@ def sprintf(*args, **kwargs) -> str: return ret -class SprintfStdout(object): +class SprintfStdout(contextlib.AbstractContextManager): """ A context manager that captures outputs to stdout. @@ -1228,10 +1228,10 @@ class SprintfStdout(object): self.recorder.__enter__() return lambda: self.destination.getvalue() - def __exit__(self, *args) -> None: + def __exit__(self, *args) -> Literal[False]: self.recorder.__exit__(*args) self.destination.seek(0) - return None # don't suppress exceptions + return False def capitalize_first_letter(txt: str) -> str: diff --git a/text_utils.py b/text_utils.py index 4384a1e..7910990 100644 --- a/text_utils.py +++ b/text_utils.py @@ -3,11 +3,12 @@ """Utilities for dealing with "text".""" +import contextlib import logging import math import sys from collections import defaultdict -from typing import Dict, Generator, List, NamedTuple, Optional, Tuple +from typing import Dict, Generator, List, Literal, NamedTuple, Optional, Tuple from ansi import fg, reset @@ -261,7 +262,7 @@ def wrap_string(text: str, n: int) -> str: return out -class Indenter(object): +class Indenter(contextlib.AbstractContextManager): """ with Indenter(pad_count = 8) as i: i.print('test') @@ -289,10 +290,11 @@ class Indenter(object): self.level += 1 return self - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__(self, exc_type, exc_value, exc_tb) -> Literal[False]: self.level -= 1 if self.level < -1: self.level = -1 + return False def print(self, *arg, **kwargs): import string_utils -- 2.45.2