From: Scott Date: Thu, 27 Jan 2022 05:34:26 +0000 (-0800) Subject: Ran black code formatter on everything. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=36fea7f15ed17150691b5b3ead75450e575229ef;p=python_utils.git Ran black code formatter on everything. --- diff --git a/acl.py b/acl.py index 2b34767..adec643 100644 --- a/acl.py +++ b/acl.py @@ -19,6 +19,7 @@ class Order(enum.Enum): """A helper to express the order of evaluation for allows/denies in an Access Control List. """ + UNDEFINED = 0 ALLOW_DENY = 1 DENY_ALLOW = 2 @@ -28,17 +29,16 @@ class SimpleACL(ABC): """A simple Access Control List interface.""" def __init__( - self, - *, - order_to_check_allow_deny: Order, - default_answer: bool + self, *, order_to_check_allow_deny: Order, default_answer: bool ): if order_to_check_allow_deny not in ( - Order.ALLOW_DENY, Order.DENY_ALLOW + Order.ALLOW_DENY, + Order.DENY_ALLOW, ): raise Exception( - 'order_to_check_allow_deny must be Order.ALLOW_DENY or ' + - 'Order.DENY_ALLOW') + 'order_to_check_allow_deny must be Order.ALLOW_DENY or ' + + 'Order.DENY_ALLOW' + ) self.order_to_check_allow_deny = order_to_check_allow_deny self.default_answer = default_answer @@ -64,8 +64,8 @@ class SimpleACL(ABC): return True logger.debug( - f'{x} was not explicitly allowed or denied; ' + - f'using default answer ({self.default_answer})' + f'{x} was not explicitly allowed or denied; ' + + f'using default answer ({self.default_answer})' ) return self.default_answer @@ -82,15 +82,18 @@ class SimpleACL(ABC): class SetBasedACL(SimpleACL): """An ACL that allows or denies based on membership in a set.""" - def __init__(self, - *, - allow_set: Optional[Set[Any]] = None, - deny_set: Optional[Set[Any]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + allow_set: Optional[Set[Any]] = None, + deny_set: Optional[Set[Any]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: super().__init__( order_to_check_allow_deny=order_to_check_allow_deny, - default_answer=default_answer + default_answer=default_answer, ) self.allow_set = allow_set self.deny_set = deny_set @@ -112,52 +115,55 @@ class AllowListACL(SetBasedACL): """Convenience subclass for a list that only allows known items. i.e. a 'allowlist' """ - def __init__(self, - *, - allow_set: Optional[Set[Any]]) -> None: + + def __init__(self, *, allow_set: Optional[Set[Any]]) -> None: super().__init__( - allow_set = allow_set, - order_to_check_allow_deny = Order.ALLOW_DENY, - default_answer = False) + allow_set=allow_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=False, + ) class DenyListACL(SetBasedACL): """Convenience subclass for a list that only disallows known items. i.e. a 'blocklist' """ - def __init__(self, - *, - deny_set: Optional[Set[Any]]) -> None: + + def __init__(self, *, deny_set: Optional[Set[Any]]) -> None: super().__init__( - deny_set = deny_set, - order_to_check_allow_deny = Order.ALLOW_DENY, - default_answer = True) + deny_set=deny_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=True, + ) class BlockListACL(SetBasedACL): """Convenience subclass for a list that only disallows known items. i.e. a 'blocklist' """ - def __init__(self, - *, - deny_set: Optional[Set[Any]]) -> None: + + def __init__(self, *, deny_set: Optional[Set[Any]]) -> None: super().__init__( - deny_set = deny_set, - order_to_check_allow_deny = Order.ALLOW_DENY, - default_answer = True) + deny_set=deny_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=True, + ) class PredicateListBasedACL(SimpleACL): """An ACL that allows or denies by applying predicates.""" - def __init__(self, - *, - allow_predicate_list: Sequence[Callable[[Any], bool]] = None, - deny_predicate_list: Sequence[Callable[[Any], bool]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + allow_predicate_list: Sequence[Callable[[Any], bool]] = None, + deny_predicate_list: Sequence[Callable[[Any], bool]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: super().__init__( order_to_check_allow_deny=order_to_check_allow_deny, - default_answer=default_answer + default_answer=default_answer, ) self.allow_predicate_list = allow_predicate_list self.deny_predicate_list = deny_predicate_list @@ -177,12 +183,15 @@ class PredicateListBasedACL(SimpleACL): class StringWildcardBasedACL(PredicateListBasedACL): """An ACL that allows or denies based on string glob (*, ?) patterns.""" - def __init__(self, - *, - allowed_patterns: Optional[List[str]] = None, - denied_patterns: Optional[List[str]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + allowed_patterns: Optional[List[str]] = None, + denied_patterns: Optional[List[str]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: allow_predicates = [] if allowed_patterns is not None: for pattern in allowed_patterns: @@ -207,12 +216,15 @@ class StringWildcardBasedACL(PredicateListBasedACL): class StringREBasedACL(PredicateListBasedACL): """An ACL that allows or denies by applying regexps.""" - def __init__(self, - *, - allowed_regexs: Optional[List[re.Pattern]] = None, - denied_regexs: Optional[List[re.Pattern]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + allowed_regexs: Optional[List[re.Pattern]] = None, + denied_regexs: Optional[List[re.Pattern]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: allow_predicates = None if allowed_regexs is not None: allow_predicates = [] @@ -237,14 +249,17 @@ class StringREBasedACL(PredicateListBasedACL): class AnyCompoundACL(SimpleACL): """An ACL that allows if any of its subacls allow.""" - def __init__(self, - *, - subacls: Optional[List[SimpleACL]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + subacls: Optional[List[SimpleACL]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: super().__init__( - order_to_check_allow_deny = order_to_check_allow_deny, - default_answer = default_answer + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, ) self.subacls = subacls @@ -263,14 +278,17 @@ class AnyCompoundACL(SimpleACL): class AllCompoundACL(SimpleACL): """An ACL that allows if all of its subacls allow.""" - def __init__(self, - *, - subacls: Optional[List[SimpleACL]] = None, - order_to_check_allow_deny: Order, - default_answer: bool) -> None: + + def __init__( + self, + *, + subacls: Optional[List[SimpleACL]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: super().__init__( - order_to_check_allow_deny = order_to_check_allow_deny, - default_answer = default_answer + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, ) self.subacls = subacls diff --git a/ansi.py b/ansi.py index 950a0c4..5fde4af 100755 --- a/ansi.py +++ b/ansi.py @@ -1731,13 +1731,15 @@ def _find_color_by_name(name: str) -> Tuple[int, int, int]: @logging_utils.squelch_repeated_log_messages(1) -def fg(name: Optional[str] = "", - red: Optional[int] = None, - green: Optional[int] = None, - blue: Optional[int] = None, - *, - force_16color: bool = False, - force_216color: bool = False) -> str: +def fg( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, + *, + force_16color: bool = False, + force_216color: bool = False, +) -> str: """Return the ANSI escape sequence to change the foreground color being printed. Target colors may be indicated by name or R/G/B. Result will use the 16 or 216 color scheme if force_16color or @@ -1762,7 +1764,7 @@ def fg(name: Optional[str] = "", rgb[1], rgb[2], force_16color=force_16color, - force_216color=force_216color + force_216color=force_216color, ) if red is None: @@ -1791,14 +1793,16 @@ def _rgb_to_yiq(rgb: Tuple[int, int, int]) -> int: def _contrast(rgb: Tuple[int, int, int]) -> Tuple[int, int, int]: if _rgb_to_yiq(rgb) < 128: - return (0xff, 0xff, 0xff) + return (0xFF, 0xFF, 0xFF) return (0, 0, 0) -def pick_contrasting_color(name: Optional[str] = "", - red: Optional[int] = None, - green: Optional[int] = None, - blue: Optional[int] = None) -> Tuple[int, int, int]: +def pick_contrasting_color( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, +) -> Tuple[int, int, int]: """This method will return a red, green, blue tuple representing a contrasting color given the red, green, blue of a background color or a color name of the background color. @@ -1827,11 +1831,7 @@ def guess_name(name: str) -> str: best_guess = None max_ratio = None for possibility in COLOR_NAMES_TO_RGB: - r = difflib.SequenceMatcher( - None, - name, - possibility - ).ratio() + r = difflib.SequenceMatcher(None, name, possibility).ratio() if max_ratio is None or r > max_ratio: max_ratio = r best_guess = possibility @@ -1840,13 +1840,15 @@ def guess_name(name: str) -> str: return best_guess -def bg(name: Optional[str] = "", - red: Optional[int] = None, - green: Optional[int] = None, - blue: Optional[int] = None, - *, - force_16color: bool = False, - force_216color: bool = False) -> str: +def bg( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, + *, + force_16color: bool = False, + force_216color: bool = False, +) -> str: """Returns an ANSI color code for changing the current background color. @@ -1868,7 +1870,7 @@ def bg(name: Optional[str] = "", rgb[1], rgb[2], force_16color=force_16color, - force_216color=force_216color + force_216color=force_216color, ) if red is None: red = 0 @@ -1912,8 +1914,8 @@ class StdoutInterceptor(io.TextIOBase): class ProgrammableColorizer(StdoutInterceptor): def __init__( - self, - patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]] + self, + patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]], ): super().__init__() self.patterns = [_ for _ in patterns] @@ -1926,8 +1928,10 @@ class ProgrammableColorizer(StdoutInterceptor): if __name__ == '__main__': + def main() -> None: import doctest + doctest.testmod() name = " ".join(sys.argv[1:]) @@ -1938,6 +1942,9 @@ if __name__ == '__main__': _ = pick_contrasting_color(possibility) xf = fg(None, _[0], _[1], _[2]) xb = bg(None, _[0], _[1], _[2]) - print(f'{f}{xb}{possibility}{reset()}\t\t\t' - f'{b}{xf}{possibility}{reset()}') + print( + f'{f}{xb}{possibility}{reset()}\t\t\t' + f'{b}{xf}{possibility}{reset()}' + ) + main() diff --git a/argparse_utils.py b/argparse_utils.py index 1c61b24..8c254ae 100644 --- a/argparse_utils.py +++ b/argparse_utils.py @@ -17,12 +17,7 @@ logger = logging.getLogger(__name__) class ActionNoYes(argparse.Action): def __init__( - self, - option_strings, - dest, - default=None, - required=False, - help=None + self, option_strings, dest, default=None, required=False, help=None ): if default is None: msg = 'You must provide a default with Yes/No action' @@ -47,14 +42,13 @@ class ActionNoYes(argparse.Action): const=None, default=default, required=required, - help=help + help=help, ) @overrides def __call__(self, parser, namespace, values, option_strings=None): - if ( - option_strings.startswith('--no-') or - option_strings.startswith('--no_') + if option_strings.startswith('--no-') or option_strings.startswith( + '--no_' ): setattr(namespace, self.dest, False) else: @@ -89,6 +83,7 @@ def valid_bool(v: Any) -> bool: if isinstance(v, bool): return v from string_utils import to_bool + try: return to_bool(v) except Exception: @@ -110,6 +105,7 @@ def valid_ip(ip: str) -> str: """ from string_utils import extract_ip_v4 + s = extract_ip_v4(ip.strip()) if s is not None: return s @@ -136,6 +132,7 @@ def valid_mac(mac: str) -> str: """ from string_utils import extract_mac_address + s = extract_mac_address(mac) if s is not None: return s @@ -206,6 +203,7 @@ def valid_date(txt: str) -> datetime.date: -ANYTHING- """ from string_utils import to_date + date = to_date(txt) if date is not None: return date @@ -228,6 +226,7 @@ def valid_datetime(txt: str) -> datetime.datetime: -ANYTHING- """ from string_utils import to_datetime + dt = to_datetime(txt) if dt is not None: return dt @@ -250,6 +249,7 @@ def valid_duration(txt: str) -> datetime.timedelta: """ from datetime_utils import parse_duration + try: secs = parse_duration(txt) except Exception as e: @@ -260,5 +260,6 @@ def valid_duration(txt: str) -> datetime.timedelta: if __name__ == '__main__': import doctest + doctest.ELLIPSIS_MARKER = '-ANYTHING-' doctest.testmod() diff --git a/arper.py b/arper.py index 29a8a12..39aecf9 100644 --- a/arper.py +++ b/arper.py @@ -131,7 +131,10 @@ class Arper(persistent.Persistent): mac = mac.lower() ip = ip.strip() cached_state[mac] = ip - if len(cached_state) > config.config['arper_min_entries_to_be_valid']: + if ( + len(cached_state) + > config.config['arper_min_entries_to_be_valid'] + ): return cls(cached_state) else: msg = f'{cache_file} is invalid: only {len(cached_state)} entries. Deleting it.' @@ -144,8 +147,12 @@ class Arper(persistent.Persistent): @overrides def save(self) -> bool: if len(self.state) > config.config['arper_min_entries_to_be_valid']: - logger.debug(f'Persisting state to {config.config["arper_cache_location"]}') - with file_utils.FileWriter(config.config['arper_cache_location']) as wf: + logger.debug( + f'Persisting state to {config.config["arper_cache_location"]}' + ) + with file_utils.FileWriter( + config.config['arper_cache_location'] + ) as wf: for (mac, ip) in self.state.items(): mac = mac.lower() print(f'{mac}, {ip}', file=wf) diff --git a/base_presence.py b/base_presence.py index f846e65..f774dbc 100755 --- a/base_presence.py +++ b/base_presence.py @@ -85,7 +85,9 @@ class PresenceDetection(object): delta = now - self.last_update if ( delta.total_seconds() - > config.config['presence_tolerable_staleness_seconds'].total_seconds() + > config.config[ + 'presence_tolerable_staleness_seconds' + ].total_seconds() ): logger.debug( f"It's been {delta.total_seconds()}s since last update; refreshing now." @@ -144,7 +146,9 @@ class PresenceDetection(object): warnings.warn(msg, stacklevel=2) self.dark_locations.add(Location.HOUSE) - def read_persisted_macs_file(self, filename: str, location: Location) -> None: + def read_persisted_macs_file( + self, filename: str, location: Location + ) -> None: if location is Location.UNKNOWN: return with open(filename, "r") as rf: @@ -173,9 +177,9 @@ class PresenceDetection(object): logger.exception(e) continue mac = mac.strip() - (self.location_ts_by_mac[location])[mac] = datetime.datetime.fromtimestamp( - int(ts.strip()) - ) + (self.location_ts_by_mac[location])[ + mac + ] = datetime.datetime.fromtimestamp(int(ts.strip())) ip_name = ip_name.strip() match = re.match(r"(\d+\.\d+\.\d+\.\d+) +\(([^\)]+)\)", ip_name) if match is not None: @@ -188,7 +192,9 @@ class PresenceDetection(object): def is_anyone_in_location_now(self, location: Location) -> bool: self.maybe_update() if location in self.dark_locations: - raise Exception(f"Can't see {location} right now; answer undefined.") + raise Exception( + f"Can't see {location} right now; answer undefined." + ) for person in Person: if person is not None: loc = self.where_is_person_now(person) @@ -201,9 +207,7 @@ class PresenceDetection(object): def where_is_person_now(self, name: Person) -> Location: self.maybe_update() if len(self.dark_locations) > 0: - msg = ( - f"Can't see {self.dark_locations} right now; answer confidence impacted" - ) + msg = f"Can't see {self.dark_locations} right now; answer confidence impacted" logger.warning(msg) warnings.warn(msg, stacklevel=2) logger.debug(f'Looking for {name}...') @@ -223,21 +227,28 @@ class PresenceDetection(object): if mac not in self.names_by_mac: continue mac_name = self.names_by_mac[mac] - logger.debug(f'Looking for {name}... check for mac {mac} ({mac_name})') + logger.debug( + f'Looking for {name}... check for mac {mac} ({mac_name})' + ) for location in self.location_ts_by_mac: if mac in self.location_ts_by_mac[location]: ts = (self.location_ts_by_mac[location])[mac] - logger.debug(f'Seen {mac} ({mac_name}) at {location} since {ts}') + logger.debug( + f'Seen {mac} ({mac_name}) at {location} since {ts}' + ) tiebreaks[location] = ts - (most_recent_location, first_seen_ts) = dict_utils.item_with_max_value( - tiebreaks - ) + ( + most_recent_location, + first_seen_ts, + ) = dict_utils.item_with_max_value(tiebreaks) bonus = credit v = votes.get(most_recent_location, 0) votes[most_recent_location] = v + bonus logger.debug(f'{name}: {location} gets {bonus} votes.') - credit = int(credit * 0.2) # Note: list most important devices first + credit = int( + credit * 0.2 + ) # Note: list most important devices first if credit <= 0: credit = 1 if len(votes) > 0: diff --git a/bootstrap.py b/bootstrap.py index 738fcea..c3b70db 100644 --- a/bootstrap.py +++ b/bootstrap.py @@ -17,18 +17,19 @@ logger = logging.getLogger(__name__) args = config.add_commandline_args( f'Bootstrap ({__file__})', - 'Args related to python program bootstrapper and Swiss army knife') + 'Args related to python program bootstrapper and Swiss army knife', +) args.add_argument( '--debug_unhandled_exceptions', action=ActionNoYes, default=False, - help='Break into pdb on top level unhandled exceptions.' + help='Break into pdb on top level unhandled exceptions.', ) args.add_argument( '--show_random_seed', action=ActionNoYes, default=False, - help='Should we display (and log.debug) the global random seed?' + help='Should we display (and log.debug) the global random seed?', ) args.add_argument( '--set_random_seed', @@ -36,13 +37,13 @@ args.add_argument( nargs=1, default=None, metavar='SEED_INT', - help='Override the global random seed with a particular number.' + help='Override the global random seed with a particular number.', ) args.add_argument( '--dump_all_objects', action=ActionNoYes, default=False, - help='Should we dump the Python import tree before main?' + help='Should we dump the Python import tree before main?', ) args.add_argument( '--audit_import_events', @@ -70,18 +71,17 @@ def handle_uncaught_exception(exc_type, exc_value, exc_tb): sys.__excepthook__(exc_type, exc_value, exc_tb) return else: - if ( - not sys.stderr.isatty() or - not sys.stdin.isatty() - ): + if not sys.stderr.isatty() or not sys.stdin.isatty(): # stdin or stderr is redirected, just do the normal thing original_hook(exc_type, exc_value, exc_tb) else: # a terminal is attached and stderr is not redirected, maybe debug. import traceback + traceback.print_exception(exc_type, exc_value, exc_tb) if config.config['debug_unhandled_exceptions']: import pdb + logger.info("Invoking the debugger...") pdb.pm() else: @@ -91,6 +91,7 @@ def handle_uncaught_exception(exc_type, exc_value, exc_tb): class ImportInterceptor(object): def __init__(self): import collect.trie + self.module_by_filename_cache = {} self.repopulate_modules_by_filename() self.tree = collect.trie.Trie() @@ -120,7 +121,9 @@ class ImportInterceptor(object): loading_module = self.module_by_filename_cache[filename] else: self.repopulate_modules_by_filename() - loading_module = self.module_by_filename_cache.get(filename, 'unknown') + loading_module = self.module_by_filename_cache.get( + filename, 'unknown' + ) path = self.tree_node_by_module.get(loading_module, []) path.extend([loaded_module]) @@ -215,6 +218,7 @@ def initialize(entry_point): seed, etc... before running main. """ + @functools.wraps(entry_point) def initialize_wrapper(*args, **kwargs): # Hook top level unhandled exceptions, maybe invoke debugger. @@ -225,8 +229,8 @@ def initialize(entry_point): # parse configuration (based on cmdline flags, environment vars # etc...) if ( - '__globals__' in entry_point.__dict__ and - '__file__' in entry_point.__globals__ + '__globals__' in entry_point.__dict__ + and '__file__' in entry_point.__globals__ ): config.parse(entry_point.__globals__['__file__']) else: @@ -240,6 +244,7 @@ def initialize(entry_point): # Allow programs that don't bother to override the random seed # to be replayed via the commandline. import random + random_seed = config.config['set_random_seed'] if random_seed is not None: random_seed = random_seed[0] @@ -256,6 +261,7 @@ def initialize(entry_point): logger.debug(f'Starting {entry_point.__name__} (program entry point)') ret = None import stopwatch + with stopwatch.Timer() as t: ret = entry_point(*args, **kwargs) logger.debug( @@ -272,13 +278,15 @@ def initialize(entry_point): walltime = t() (utime, stime, cutime, cstime, elapsed_time) = os.times() - logger.debug('\n' - f'user: {utime}s\n' - f'system: {stime}s\n' - f'child user: {cutime}s\n' - f'child system: {cstime}s\n' - f'machine uptime: {elapsed_time}s\n' - f'walltime: {walltime}s') + logger.debug( + '\n' + f'user: {utime}s\n' + f'system: {stime}s\n' + f'child user: {cutime}s\n' + f'child system: {cstime}s\n' + f'machine uptime: {elapsed_time}s\n' + f'walltime: {walltime}s' + ) # If it doesn't return cleanly, call attention to the return value. if ret is not None and ret != 0: @@ -286,4 +294,5 @@ def initialize(entry_point): else: logger.debug(f'Exit {ret}') sys.exit(ret) + return initialize_wrapper diff --git a/camera_utils.py b/camera_utils.py index 799efd3..03ac621 100644 --- a/camera_utils.py +++ b/camera_utils.py @@ -56,7 +56,8 @@ def sanity_check_image(hsv: np.ndarray) -> SanityCheckImageMetadata: hs_zero_count += 1 logger.debug(f"hszero#={hs_zero_count}, weird_orange={weird_orange_count}") return SanityCheckImageMetadata( - hs_zero_count > (num_pixels * 0.75), weird_orange_count > (num_pixels * 0.75) + hs_zero_count > (num_pixels * 0.75), + weird_orange_count > (num_pixels * 0.75), ) @@ -73,7 +74,9 @@ def fetch_camera_image_from_video_server( response = requests.get(url, stream=False, timeout=10.0) if response.ok: raw = response.content - logger.debug(f'Read {len(response.content)} byte image from HTTP server') + logger.debug( + f'Read {len(response.content)} byte image from HTTP server' + ) tmp = np.frombuffer(raw, dtype="uint8") logger.debug( f'Translated raw content into {tmp.shape} {type(tmp)} with element type {type(tmp[0])}.' @@ -169,7 +172,9 @@ def _fetch_camera_image( camera_name, width=width, quality=quality ) if raw is None: - logger.debug("Reading from video server failed; trying direct RTSP stream") + logger.debug( + "Reading from video server failed; trying direct RTSP stream" + ) raw = fetch_camera_image_from_rtsp_stream(camera_name, width=width) if raw is not None and len(raw) > 0: tmp = np.frombuffer(raw, dtype="uint8") @@ -180,7 +185,9 @@ def _fetch_camera_image( jpg=jpg, hsv=hsv, ) - msg = "Failed to retieve image from both video server and direct RTSP stream" + msg = ( + "Failed to retieve image from both video server and direct RTSP stream" + ) logger.warning(msg) warnings.warn(msg, stacklevel=2) return RawJpgHsv(None, None, None) diff --git a/config.py b/config.py index dc0042d..a608cf5 100644 --- a/config.py +++ b/config.py @@ -89,7 +89,7 @@ args = argparse.ArgumentParser( description=None, formatter_class=argparse.ArgumentDefaultsHelpFormatter, fromfile_prefix_chars="@", - epilog=f'{program_name} uses config.py ({__file__}) for global, cross-module configuration setup and parsing.' + epilog=f'{program_name} uses config.py ({__file__}) for global, cross-module configuration setup and parsing.', ) # Keep track of if we've been called and prevent being called more @@ -138,10 +138,10 @@ group.add_argument( default=False, action='store_true', help=( - 'If present, config will raise an exception if it doesn\'t recognize an argument. The ' + - 'default behavior is to ignore this so as to allow interoperability with programs that ' + - 'want to use their own argparse calls to parse their own, separate commandline args.' - ) + 'If present, config will raise an exception if it doesn\'t recognize an argument. The ' + + 'default behavior is to ignore this so as to allow interoperability with programs that ' + + 'want to use their own argparse calls to parse their own, separate commandline args.' + ), ) @@ -210,6 +210,7 @@ def parse(entry_module: Optional[str]) -> Dict[str, Any]: f'Initialized from environment: {var} = {value}' ) from string_utils import to_bool + if len(chunks) == 1 and to_bool(value): sys.argv.append(var) elif len(chunks) > 1: @@ -238,16 +239,22 @@ def parse(entry_module: Optional[str]) -> Dict[str, Any]: if loadfile is not None: if saw_other_args: - msg = f'Augmenting commandline arguments with those from {loadfile}.' + msg = ( + f'Augmenting commandline arguments with those from {loadfile}.' + ) print(msg, file=sys.stderr) saved_messages.append(msg) if not os.path.exists(loadfile): - print(f'ERROR: --config_loadfile argument must be a file, {loadfile} not found.', - file=sys.stderr) + print( + f'ERROR: --config_loadfile argument must be a file, {loadfile} not found.', + file=sys.stderr, + ) sys.exit(-1) with open(loadfile, 'r') as rf: newargs = rf.readlines() - newargs = [arg.strip('\n') for arg in newargs if 'config_savefile' not in arg] + newargs = [ + arg.strip('\n') for arg in newargs if 'config_savefile' not in arg + ] sys.argv += newargs # Parse (possibly augmented, possibly completely overwritten) @@ -264,16 +271,16 @@ def parse(entry_module: Optional[str]) -> Dict[str, Any]: raise Exception( f'Encountered unrecognized config argument(s) {unknown} with --config_rejects_unrecognized_arguments enabled; halting.' ) - saved_messages.append(f'Config encountered unrecognized commandline arguments: {unknown}') + saved_messages.append( + f'Config encountered unrecognized commandline arguments: {unknown}' + ) sys.argv = sys.argv[:1] + unknown # Check for savefile and populate it if requested. savefile = config['config_savefile'] if savefile and len(savefile) > 0: with open(savefile, 'w') as wf: - wf.write( - "\n".join(original_argv[1:]) - ) + wf.write("\n".join(original_argv[1:])) # Also dump the config on stderr if requested. if config['config_dump']: diff --git a/conversion_utils.py b/conversion_utils.py index d2225fd..4326840 100644 --- a/conversion_utils.py +++ b/conversion_utils.py @@ -23,12 +23,15 @@ class Converter(object): each have the potential to overflow, underflow, or introduce floating point errors. Caveat emptor. """ - def __init__(self, - name: str, - category: str, - to_canonical: Callable, # convert to canonical unit - from_canonical: Callable, # convert from canonical unit - unit: str) -> None: + + def __init__( + self, + name: str, + category: str, + to_canonical: Callable, # convert to canonical unit + from_canonical: Callable, # convert from canonical unit + unit: str, + ) -> None: self.name = name self.category = category self.to_canonical_f = to_canonical @@ -47,52 +50,56 @@ class Converter(object): # A catalog of converters. conversion_catalog = { - "Second": Converter("Second", - "time", - lambda s: s, - lambda s: s, - "s"), - "Minute": Converter("Minute", - "time", - lambda m: (m * constants.SECONDS_PER_MINUTE), - lambda s: (s / constants.SECONDS_PER_MINUTE), - "m"), - "Hour": Converter("Hour", - "time", - lambda h: (h * constants.SECONDS_PER_HOUR), - lambda s: (s / constants.SECONDS_PER_HOUR), - "h"), - "Day": Converter("Day", - "time", - lambda d: (d * constants.SECONDS_PER_DAY), - lambda s: (s / constants.SECONDS_PER_DAY), - "d"), - "Week": Converter("Week", - "time", - lambda w: (w * constants.SECONDS_PER_WEEK), - lambda s: (s / constants.SECONDS_PER_WEEK), - "w"), - "Fahrenheit": Converter("Fahrenheit", - "temperature", - lambda f: (f - 32.0) * 0.55555555, - lambda c: c * 1.8 + 32.0, - "°F"), - "Celsius": Converter("Celsius", - "temperature", - lambda c: c, - lambda c: c, - "°C"), - "Kelvin": Converter("Kelvin", - "temperature", - lambda k: k - 273.15, - lambda c: c + 273.15, - "°K"), + "Second": Converter("Second", "time", lambda s: s, lambda s: s, "s"), + "Minute": Converter( + "Minute", + "time", + lambda m: (m * constants.SECONDS_PER_MINUTE), + lambda s: (s / constants.SECONDS_PER_MINUTE), + "m", + ), + "Hour": Converter( + "Hour", + "time", + lambda h: (h * constants.SECONDS_PER_HOUR), + lambda s: (s / constants.SECONDS_PER_HOUR), + "h", + ), + "Day": Converter( + "Day", + "time", + lambda d: (d * constants.SECONDS_PER_DAY), + lambda s: (s / constants.SECONDS_PER_DAY), + "d", + ), + "Week": Converter( + "Week", + "time", + lambda w: (w * constants.SECONDS_PER_WEEK), + lambda s: (s / constants.SECONDS_PER_WEEK), + "w", + ), + "Fahrenheit": Converter( + "Fahrenheit", + "temperature", + lambda f: (f - 32.0) * 0.55555555, + lambda c: c * 1.8 + 32.0, + "°F", + ), + "Celsius": Converter( + "Celsius", "temperature", lambda c: c, lambda c: c, "°C" + ), + "Kelvin": Converter( + "Kelvin", + "temperature", + lambda k: k - 273.15, + lambda c: c + 273.15, + "°K", + ), } -def convert(magnitude: Number, - from_thing: str, - to_thing: str) -> float: +def convert(magnitude: Number, from_thing: str, to_thing: str) -> float: src = conversion_catalog.get(from_thing, None) dst = conversion_catalog.get(to_thing, None) if src is None or dst is None: @@ -102,9 +109,9 @@ def convert(magnitude: Number, return _convert(magnitude, src, dst) -def _convert(magnitude: Number, - from_unit: Converter, - to_unit: Converter) -> float: +def _convert( + magnitude: Number, from_unit: Converter, to_unit: Converter +) -> float: canonical = from_unit.to_canonical(magnitude) converted = to_unit.from_canonical(canonical) return float(converted) @@ -332,4 +339,5 @@ def c_to_f(temp_c: float) -> float: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/datetime_utils.py b/datetime_utils.py index 7c5516b..9794720 100644 --- a/datetime_utils.py +++ b/datetime_utils.py @@ -27,18 +27,16 @@ def is_timezone_aware(dt: datetime.datetime) -> bool: True """ - return ( - dt.tzinfo is not None and - dt.tzinfo.utcoffset(dt) is not None - ) + return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None def is_timezone_naive(dt: datetime.datetime) -> bool: return not is_timezone_aware(dt) -def replace_timezone(dt: datetime.datetime, - tz: datetime.tzinfo) -> datetime.datetime: +def replace_timezone( + dt: datetime.datetime, tz: datetime.tzinfo +) -> datetime.datetime: """ Replaces the timezone on a datetime object directly (leaving the year, month, day, hour, minute, second, micro, etc... alone). @@ -57,13 +55,20 @@ def replace_timezone(dt: datetime.datetime, """ return datetime.datetime( - dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, dt.microsecond, - tzinfo=tz + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + dt.microsecond, + tzinfo=tz, ) -def replace_time_timezone(t: datetime.time, - tz: datetime.tzinfo) -> datetime.time: +def replace_time_timezone( + t: datetime.time, tz: datetime.tzinfo +) -> datetime.time: """ Replaces the timezone on a datetime.time directly without performing any translation. @@ -80,8 +85,9 @@ def replace_time_timezone(t: datetime.time, return t.replace(tzinfo=tz) -def translate_timezone(dt: datetime.datetime, - tz: datetime.tzinfo) -> datetime.datetime: +def translate_timezone( + dt: datetime.datetime, tz: datetime.tzinfo +) -> datetime.datetime: """ Translates dt into a different timezone by adjusting the year, month, day, hour, minute, second, micro, etc... appropriately. The returned @@ -125,12 +131,7 @@ def date_to_datetime(date: datetime.date) -> datetime.datetime: datetime.datetime(2021, 12, 25, 0, 0) """ - return datetime.datetime( - date.year, - date.month, - date.day, - 0, 0, 0, 0 - ) + return datetime.datetime(date.year, date.month, date.day, 0, 0, 0, 0) def time_to_datetime_today(time: datetime.time) -> datetime.datetime: @@ -169,8 +170,9 @@ def time_to_datetime_today(time: datetime.time) -> datetime.datetime: return datetime.datetime.combine(now, time, tz) -def date_and_time_to_datetime(date: datetime.date, - time: datetime.time) -> datetime.datetime: +def date_and_time_to_datetime( + date: datetime.date, time: datetime.time +) -> datetime.datetime: """ Given a date and time, merge them and return a datetime. @@ -193,7 +195,7 @@ def date_and_time_to_datetime(date: datetime.date, def datetime_to_date_and_time( - dt: datetime.datetime + dt: datetime.datetime, ) -> Tuple[datetime.date, datetime.time]: """Return the component date and time objects of a datetime. @@ -235,6 +237,7 @@ def datetime_to_time(dt: datetime.datetime) -> datetime.time: class TimeUnit(enum.Enum): """An enum to represent units with which we can compute deltas.""" + MONDAYS = 0 TUESDAYS = 1 WEDNESDAYS = 2 @@ -265,9 +268,7 @@ class TimeUnit(enum.Enum): def n_timeunits_from_base( - count: int, - unit: TimeUnit, - base: datetime.datetime + count: int, unit: TimeUnit, base: datetime.datetime ) -> datetime.datetime: """Return a datetime that is N units before/after a base datetime. e.g. 3 Wednesdays from base datetime, 2 weeks from base date, 10 @@ -348,10 +349,8 @@ def n_timeunits_from_base( if base.year != old_year: skips = holidays.US(years=base.year).keys() if ( - base.weekday() < 5 and - datetime.date(base.year, - base.month, - base.day) not in skips + base.weekday() < 5 + and datetime.date(base.year, base.month, base.day) not in skips ): count -= 1 return base @@ -396,13 +395,17 @@ def n_timeunits_from_base( base.tzinfo, ) - if unit not in set([TimeUnit.MONDAYS, - TimeUnit.TUESDAYS, - TimeUnit.WEDNESDAYS, - TimeUnit.THURSDAYS, - TimeUnit.FRIDAYS, - TimeUnit.SATURDAYS, - TimeUnit.SUNDAYS]): + if unit not in set( + [ + TimeUnit.MONDAYS, + TimeUnit.TUESDAYS, + TimeUnit.WEDNESDAYS, + TimeUnit.THURSDAYS, + TimeUnit.FRIDAYS, + TimeUnit.SATURDAYS, + TimeUnit.SUNDAYS, + ] + ): raise ValueError(unit) # N weekdays from base (e.g. 4 wednesdays from today) @@ -420,14 +423,14 @@ def n_timeunits_from_base( def get_format_string( - *, - date_time_separator=" ", - include_timezone=True, - include_dayname=False, - use_month_abbrevs=False, - include_seconds=True, - include_fractional=False, - twelve_hour=True, + *, + date_time_separator=" ", + include_timezone=True, + include_dayname=False, + use_month_abbrevs=False, + include_seconds=True, + include_fractional=False, + twelve_hour=True, ) -> str: """ Helper to return a format string without looking up the documentation @@ -502,20 +505,21 @@ def datetime_to_string( include_dayname=include_dayname, include_seconds=include_seconds, include_fractional=include_fractional, - twelve_hour=twelve_hour) + twelve_hour=twelve_hour, + ) return dt.strftime(fstring).strip() def string_to_datetime( - txt: str, - *, - date_time_separator=" ", - include_timezone=True, - include_dayname=False, - use_month_abbrevs=False, - include_seconds=True, - include_fractional=False, - twelve_hour=True, + txt: str, + *, + date_time_separator=" ", + include_timezone=True, + include_dayname=False, + use_month_abbrevs=False, + include_seconds=True, + include_fractional=False, + twelve_hour=True, ) -> Tuple[datetime.datetime, str]: """A nice way to convert a string into a datetime. Returns both the datetime and the format string used to parse it. Also consider @@ -534,11 +538,9 @@ def string_to_datetime( include_dayname=include_dayname, include_seconds=include_seconds, include_fractional=include_fractional, - twelve_hour=twelve_hour) - return ( - datetime.datetime.strptime(txt, fstring), - fstring + twelve_hour=twelve_hour, ) + return (datetime.datetime.strptime(txt, fstring), fstring) def timestamp() -> str: @@ -705,7 +707,7 @@ def parse_duration(duration: str) -> int: return seconds -def describe_duration(seconds: int, *, include_seconds = False) -> str: +def describe_duration(seconds: int, *, include_seconds=False) -> str: """ Describe a duration represented as a count of seconds nicely. @@ -816,4 +818,5 @@ def describe_timedelta_briefly(delta: datetime.timedelta) -> str: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/decorator_utils.py b/decorator_utils.py index eb5a0c9..daae64e 100644 --- a/decorator_utils.py +++ b/decorator_utils.py @@ -48,6 +48,7 @@ def timed(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_timer @@ -75,10 +76,13 @@ def invocation_logged(func: Callable) -> Callable: print(msg) logger.info(msg) return ret + return wrapper_invocation_logged -def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable: +def rate_limited( + n_calls: int, *, per_period_in_seconds: float = 1.0 +) -> Callable: """Limit invocation of a wrapped function to n calls per period. Thread safe. In testing this was relatively fair with multiple threads using it though that hasn't been measured. @@ -152,7 +156,9 @@ def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callabl ) cv.notify() return ret + return wrapper_wrapper_rate_limited + return wrapper_rate_limited @@ -188,6 +194,7 @@ def debug_args(func: Callable) -> Callable: print(msg) logger.info(msg) return value + return wrapper_debug_args @@ -213,10 +220,13 @@ def debug_count_calls(func: Callable) -> Callable: @functools.wraps(func) def wrapper_debug_count_calls(*args, **kwargs): wrapper_debug_count_calls.num_calls += 1 - msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}" + msg = ( + f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}" + ) print(msg) logger.info(msg) return func(*args, **kwargs) + wrapper_debug_count_calls.num_calls = 0 return wrapper_debug_count_calls @@ -251,6 +261,7 @@ def delay( True """ + def decorator_delay(func: Callable) -> Callable: @functools.wraps(func) def wrapper_delay(*args, **kwargs): @@ -266,6 +277,7 @@ def delay( ) time.sleep(seconds) return retval + return wrapper_delay if _func is None: @@ -350,6 +362,7 @@ def memoized(func: Callable) -> Callable: True """ + @functools.wraps(func) def wrapper_memoized(*args, **kwargs): cache_key = args + tuple(kwargs.items()) @@ -362,6 +375,7 @@ def memoized(func: Callable) -> Callable: else: logger.debug(f"Returning memoized value for {func.__name__}") return wrapper_memoized.cache[cache_key] + wrapper_memoized.cache = dict() return wrapper_memoized @@ -416,7 +430,9 @@ def retry_predicate( mdelay *= backoff retval = f(*args, **kwargs) return retval + return f_retry + return deco_retry @@ -475,6 +491,7 @@ def deprecated(func): when the function is used. """ + @functools.wraps(func) def wrapper_deprecated(*args, **kwargs): msg = f"Call to deprecated function {func.__qualname__}" @@ -482,6 +499,7 @@ def deprecated(func): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) print(msg, file=sys.stderr) return func(*args, **kwargs) + return wrapper_deprecated @@ -655,6 +673,7 @@ def timeout( """ if use_signals is None: import thread_utils + use_signals = thread_utils.is_current_thread_main_thread() def decorate(function): @@ -708,6 +727,7 @@ class non_reentrant_code(object): self._entered = True f(*args, **kwargs) self._entered = False + return _gatekeeper @@ -724,6 +744,7 @@ class rlocked(object): self._entered = True f(*args, **kwargs) self._entered = False + return _gatekeeper @@ -742,7 +763,9 @@ def call_with_sample_rate(sample_rate: float) -> Callable: logger.debug( f"@call_with_sample_rate skipping a call to {f.__name__}" ) + return _call_with_sample_rate + return decorator @@ -751,6 +774,7 @@ def decorate_matching_methods_with(decorator, acl=None): prefix. If prefix is None (default), decorate all methods in the class. """ + def decorate_the_class(cls): for name, m in inspect.getmembers(cls, inspect.isfunction): if acl is None: @@ -759,10 +783,11 @@ def decorate_matching_methods_with(decorator, acl=None): if acl(name): setattr(cls, name, decorator(m)) return cls + return decorate_the_class if __name__ == '__main__': import doctest - doctest.testmod() + doctest.testmod() diff --git a/deferred_operand.py b/deferred_operand.py index 4b12279..22bcb83 100644 --- a/deferred_operand.py +++ b/deferred_operand.py @@ -91,7 +91,9 @@ class DeferredOperand(ABC, Generic[T]): return DeferredOperand.resolve(self) is DeferredOperand.resolve(other) def is_not(self, other): - return DeferredOperand.resolve(self) is not DeferredOperand.resolve(other) + return DeferredOperand.resolve(self) is not DeferredOperand.resolve( + other + ) def __abs__(self): return abs(DeferredOperand.resolve(self)) @@ -152,4 +154,5 @@ class DeferredOperand(ABC, Generic[T]): return getattr(DeferredOperand.resolve(self), method_name)( *args, **kwargs ) + return method diff --git a/dict_utils.py b/dict_utils.py index b1464c6..79c86ed 100644 --- a/dict_utils.py +++ b/dict_utils.py @@ -198,7 +198,9 @@ def min_key(d: Dict[Any, Any]) -> Any: return min(d.keys()) -def parallel_lists_to_dict(keys: List[Any], values: List[Any]) -> Dict[Any, Any]: +def parallel_lists_to_dict( + keys: List[Any], values: List[Any] +) -> Dict[Any, Any]: """Given two parallel lists (keys and values), create and return a dict. @@ -209,7 +211,9 @@ def parallel_lists_to_dict(keys: List[Any], values: List[Any]) -> Dict[Any, Any] """ if len(keys) != len(values): - raise Exception("Parallel keys and values lists must have the same length") + raise Exception( + "Parallel keys and values lists must have the same length" + ) return dict(zip(keys, values)) diff --git a/directory_filter.py b/directory_filter.py index 03602d1..8d03ff6 100644 --- a/directory_filter.py +++ b/directory_filter.py @@ -30,9 +30,11 @@ class DirectoryFileFilter(object): >>> os.remove(testfile) """ + def __init__(self, directory: str): super().__init__() import file_utils + if not file_utils.does_directory_exist(directory): raise ValueError(directory) self.directory = directory @@ -49,12 +51,15 @@ class DirectoryFileFilter(object): def _update_file(self, filename: str, mtime: Optional[float] = None): import file_utils + assert file_utils.does_file_exist(filename) if mtime is None: mtime = file_utils.get_file_raw_mtime(filename) if self.mtime_by_filename.get(filename, 0) != mtime: md5 = file_utils.get_file_md5(filename) - logger.debug(f'Computed/stored {filename}\'s MD5 at ts={mtime} ({md5})') + logger.debug( + f'Computed/stored {filename}\'s MD5 at ts={mtime} ({md5})' + ) self.mtime_by_filename[filename] = mtime self.md5_by_filename[filename] = md5 @@ -97,12 +102,14 @@ class DirectoryAllFilesFilter(DirectoryFileFilter): >>> os.remove(testfile) """ + def __init__(self, directory: str): self.all_md5s = set() super().__init__(directory) def _update_file(self, filename: str, mtime: Optional[float] = None): import file_utils + assert file_utils.does_file_exist(filename) if mtime is None: mtime = file_utils.get_file_raw_mtime(filename) @@ -122,4 +129,5 @@ class DirectoryAllFilesFilter(DirectoryFileFilter): if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/exceptions.py b/exceptions.py index 59aa262..82a82a5 100644 --- a/exceptions.py +++ b/exceptions.py @@ -3,6 +3,7 @@ # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. + class PreconditionException(AssertionError): pass diff --git a/exec_utils.py b/exec_utils.py index c1dbdcb..0163107 100644 --- a/exec_utils.py +++ b/exec_utils.py @@ -12,7 +12,9 @@ from typing import List, Optional logger = logging.getLogger(__file__) -def cmd_showing_output(command: str, ) -> int: +def cmd_showing_output( + command: str, +) -> int: """Kick off a child process. Capture and print all output that it produces on stdout and stderr. Wait for the subprocess to exit and return the exit value as the return code of this function. @@ -119,25 +121,30 @@ def run_silently(command: str, timeout_seconds: Optional[float] = None) -> None: def cmd_in_background( - command: str, *, silent: bool = False + command: str, *, silent: bool = False ) -> subprocess.Popen: args = shlex.split(command) if silent: - subproc = subprocess.Popen(args, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) + subproc = subprocess.Popen( + args, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) else: subproc = subprocess.Popen(args, stdin=subprocess.DEVNULL) def kill_subproc() -> None: try: if subproc.poll() is None: - logger.info("At exit handler: killing {}: {}".format(subproc, command)) + logger.info( + "At exit handler: killing {}: {}".format(subproc, command) + ) subproc.terminate() subproc.wait(timeout=10.0) except BaseException as be: logger.exception(be) + atexit.register(kill_subproc) return subproc @@ -152,4 +159,5 @@ def cmd_list(command: List[str]) -> str: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/executors.py b/executors.py index 3786954..46812c2 100644 --- a/executors.py +++ b/executors.py @@ -160,7 +160,9 @@ class ProcessExecutor(BaseExecutor): self.adjust_task_count(+1) pickle = make_cloud_pickle(function, *args, **kwargs) result = self._process_executor.submit(self.run_cloud_pickle, pickle) - result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start)) + result.add_done_callback( + lambda _: self.histogram.add_item(time.time() - start) + ) return result @overrides @@ -256,7 +258,9 @@ class RemoteExecutorStatus: self.finished_bundle_timings_per_worker: Dict[ RemoteWorkerRecord, List[float] ] = {} - self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {} + self.in_flight_bundles_by_worker: Dict[ + RemoteWorkerRecord, Set[str] + ] = {} self.bundle_details_by_uuid: Dict[str, BundleDetails] = {} self.finished_bundle_timings: List[float] = [] self.last_periodic_dump: Optional[float] = None @@ -266,7 +270,9 @@ class RemoteExecutorStatus: # as a memory fence for modifications to bundle. self.lock: threading.Lock = threading.Lock() - def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None: + def record_acquire_worker( + self, worker: RemoteWorkerRecord, uuid: str + ) -> None: with self.lock: self.record_acquire_worker_already_locked(worker, uuid) @@ -284,7 +290,9 @@ class RemoteExecutorStatus: with self.lock: self.record_bundle_details_already_locked(details) - def record_bundle_details_already_locked(self, details: BundleDetails) -> None: + def record_bundle_details_already_locked( + self, details: BundleDetails + ) -> None: assert self.lock.locked() self.bundle_details_by_uuid[details.uuid] = details @@ -295,7 +303,9 @@ class RemoteExecutorStatus: was_cancelled: bool, ) -> None: with self.lock: - self.record_release_worker_already_locked(worker, uuid, was_cancelled) + self.record_release_worker_already_locked( + worker, uuid, was_cancelled + ) def record_release_worker_already_locked( self, @@ -367,7 +377,11 @@ class RemoteExecutorStatus: ret += f' ...{in_flight} bundles currently in flight:\n' for bundle_uuid in self.in_flight_bundles_by_worker[worker]: details = self.bundle_details_by_uuid.get(bundle_uuid, None) - pid = str(details.pid) if (details and details.pid != 0) else "TBD" + pid = ( + str(details.pid) + if (details and details.pid != 0) + else "TBD" + ) if self.start_per_bundle[bundle_uuid] is not None: sec = ts - self.start_per_bundle[bundle_uuid] ret += f' (pid={pid}): {details} for {sec:.1f}s so far ' @@ -398,7 +412,10 @@ class RemoteExecutorStatus: assert self.lock.locked() self.total_bundles_submitted = total_bundles_submitted ts = time.time() - if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0: + if ( + self.last_periodic_dump is None + or ts - self.last_periodic_dump > 5.0 + ): print(self) self.last_periodic_dump = ts @@ -412,7 +429,9 @@ class RemoteWorkerSelectionPolicy(ABC): pass @abstractmethod - def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]: + def acquire_worker( + self, machine_to_avoid=None + ) -> Optional[RemoteWorkerRecord]: pass @@ -425,7 +444,9 @@ class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy): return False @overrides - def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]: + def acquire_worker( + self, machine_to_avoid=None + ) -> Optional[RemoteWorkerRecord]: grabbag = [] for worker in self.workers: for x in range(0, worker.count): @@ -482,7 +503,9 @@ class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy): class RemoteExecutor(BaseExecutor): def __init__( - self, workers: List[RemoteWorkerRecord], policy: RemoteWorkerSelectionPolicy + self, + workers: List[RemoteWorkerRecord], + policy: RemoteWorkerSelectionPolicy, ) -> None: super().__init__() self.workers = workers @@ -562,7 +585,9 @@ class RemoteExecutor(BaseExecutor): break for uuid in bundle_uuids: - bundle = self.status.bundle_details_by_uuid.get(uuid, None) + bundle = self.status.bundle_details_by_uuid.get( + uuid, None + ) if ( bundle is not None and bundle.src_bundle is None @@ -653,7 +678,9 @@ class RemoteExecutor(BaseExecutor): logger.critical(msg) raise Exception(msg) - def release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None: + def release_worker( + self, bundle: BundleDetails, *, was_cancelled=True + ) -> None: worker = bundle.worker assert worker is not None logger.debug(f'Released worker {worker}') @@ -737,14 +764,14 @@ class RemoteExecutor(BaseExecutor): # Send input code / data to worker machine if it's not local. if hostname not in machine: try: - cmd = ( - f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}' - ) + cmd = f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}' start_ts = time.time() logger.info(f"{bundle}: Copying work to {worker} via {cmd}.") run_silently(cmd) xfer_latency = time.time() - start_ts - logger.debug(f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s.") + logger.debug( + f"{bundle}: Copying to {worker} took {xfer_latency:.1f}s." + ) except Exception as e: self.release_worker(bundle) if is_original: @@ -777,7 +804,9 @@ class RemoteExecutor(BaseExecutor): f' /home/scott/lib/python_modules/remote_worker.py' f' --code_file {bundle.code_file} --result_file {bundle.result_file}"' ) - logger.debug(f'{bundle}: Executing {cmd} in the background to kick off work...') + logger.debug( + f'{bundle}: Executing {cmd} in the background to kick off work...' + ) p = cmd_in_background(cmd, silent=True) bundle.pid = p.pid logger.debug( @@ -906,7 +935,9 @@ class RemoteExecutor(BaseExecutor): # Re-raise the exception; the code in wait_for_process may # decide to emergency_retry_nasty_bundle here. raise Exception(e) - logger.debug(f'Removing local (master) {code_file} and {result_file}.') + logger.debug( + f'Removing local (master) {code_file} and {result_file}.' + ) os.remove(f'{result_file}') os.remove(f'{code_file}') diff --git a/file_utils.py b/file_utils.py index 5d9a0be..12aadca 100644 --- a/file_utils.py +++ b/file_utils.py @@ -366,7 +366,9 @@ def get_file_mtime_timedelta(filename: str) -> Optional[datetime.timedelta]: return get_file_timestamp_timedelta(filename, lambda x: x.st_mtime) -def describe_file_timestamp(filename: str, extractor, *, brief=False) -> Optional[str]: +def describe_file_timestamp( + filename: str, extractor, *, brief=False +) -> Optional[str]: from datetime_utils import describe_duration, describe_duration_briefly age = get_file_timestamp_age_seconds(filename, extractor) diff --git a/function_utils.py b/function_utils.py index 5b70981..f027f8c 100644 --- a/function_utils.py +++ b/function_utils.py @@ -19,6 +19,7 @@ def function_identifier(f: Callable) -> str: if f.__module__ == '__main__': from pathlib import Path import __main__ + module = __main__.__file__ module = Path(module).stem return f'{module}:{f.__name__}' diff --git a/google_assistant.py b/google_assistant.py index 49c08d3..75ca643 100644 --- a/google_assistant.py +++ b/google_assistant.py @@ -21,14 +21,14 @@ parser.add_argument( type=str, default="http://kiosk.house:3000", metavar="URL", - help="How to contact the Google Assistant bridge" + help="How to contact the Google Assistant bridge", ) parser.add_argument( "--google_assistant_username", type=str, metavar="GOOGLE_ACCOUNT", default="scott.gasch", - help="The user account for talking to Google Assistant" + help="The user account for talking to Google Assistant", ) @@ -105,7 +105,9 @@ def ask_google(cmd: str, *, recognize_speech=True) -> GoogleResponse: audio_transcription=audio_transcription, ) else: - message = f'HTTP request to {url} with {payload} failed; code {r.status_code}' + message = ( + f'HTTP request to {url} with {payload} failed; code {r.status_code}' + ) logger.error(message) return GoogleResponse( success=False, diff --git a/histogram.py b/histogram.py index 3391b0b..4aa4749 100644 --- a/histogram.py +++ b/histogram.py @@ -15,6 +15,7 @@ class SimpleHistogram(Generic[T]): def __init__(self, buckets: List[Tuple[T, T]]): from math_utils import RunningMedian + self.buckets = {} for start_end in buckets: if self._get_bucket(start_end[0]) is not None: @@ -28,9 +29,9 @@ class SimpleHistogram(Generic[T]): @staticmethod def n_evenly_spaced_buckets( - min_bound: T, - max_bound: T, - n: int, + min_bound: T, + max_bound: T, + n: int, ) -> List[Tuple[T, T]]: ret = [] stride = int((max_bound - min_bound) / n) @@ -66,8 +67,7 @@ class SimpleHistogram(Generic[T]): all_true = all_true and self.add_item(item) return all_true - def __repr__(self, - label_formatter='%10s') -> str: + def __repr__(self, label_formatter='%10s') -> str: from text_utils import bar_graph max_population: Optional[int] = None @@ -82,18 +82,23 @@ class SimpleHistogram(Generic[T]): if max_population is None: return txt - for bucket in sorted(self.buckets, key=lambda x : x[0]): + for bucket in sorted(self.buckets, key=lambda x: x[0]): pop = self.buckets[bucket] start = bucket[0] end = bucket[1] bar = bar_graph( (pop / max_population), - include_text = False, - width = 58, - left_end = "", - right_end = "") + include_text=False, + width=58, + left_end="", + right_end="", + ) label = f'{label_formatter}..{label_formatter}' % (start, end) - txt += f'{label:20}: ' + bar + f"({pop/self.count*100.0:5.2f}% n={pop})\n" + txt += ( + f'{label:20}: ' + + bar + + f"({pop/self.count*100.0:5.2f}% n={pop})\n" + ) if start == last_bucket_start: break return txt diff --git a/id_generator.py b/id_generator.py index bcd3a83..d4c7016 100644 --- a/id_generator.py +++ b/id_generator.py @@ -34,4 +34,5 @@ def get(name: str, *, start=0) -> int: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/letter_compress.py b/letter_compress.py index b5d3264..9b4cf19 100644 --- a/letter_compress.py +++ b/letter_compress.py @@ -34,10 +34,12 @@ def compress(uncompressed: str) -> bytes: compressed = bitstring.BitArray() for (n, letter) in enumerate(uncompressed): if 'a' <= letter <= 'z': - bits = ord(letter) - ord('a') + 1 # 1..26 + bits = ord(letter) - ord('a') + 1 # 1..26 else: if letter not in special_characters: - raise Exception(f'"{uncompressed}" contains uncompressable char="{letter}"') + raise Exception( + f'"{uncompressed}" contains uncompressable char="{letter}"' + ) bits = special_characters[letter] compressed.append(f"uint:5={bits}") while len(compressed) % 8 != 0: @@ -100,4 +102,5 @@ def decompress(kompressed: bytes) -> str: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/lockfile.py b/lockfile.py index ebd9115..2bbe6f4 100644 --- a/lockfile.py +++ b/lockfile.py @@ -16,14 +16,14 @@ import decorator_utils cfg = config.add_commandline_args( - f'Lockfile ({__file__})', - 'Args related to lockfiles') + f'Lockfile ({__file__})', 'Args related to lockfiles' +) cfg.add_argument( '--lockfile_held_duration_warning_threshold_sec', type=float, default=10.0, metavar='SECONDS', - help='If a lock is held for longer than this threshold we log a warning' + help='If a lock is held for longer than this threshold we log a warning', ) logger = logging.getLogger(__name__) @@ -50,13 +50,14 @@ class LockFile(object): # some logic for detecting stale locks. """ + def __init__( - self, - lockfile_path: str, - *, - do_signal_cleanup: bool = True, - expiration_timestamp: Optional[float] = None, - override_command: Optional[str] = None, + self, + lockfile_path: str, + *, + do_signal_cleanup: bool = True, + expiration_timestamp: Optional[float] = None, + override_command: Optional[str] = None, ) -> None: self.is_locked = False self.lockfile = lockfile_path @@ -93,16 +94,15 @@ class LockFile(object): return False def acquire_with_retries( - self, - *, - initial_delay: float = 1.0, - backoff_factor: float = 2.0, - max_attempts = 5 + self, + *, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + max_attempts=5, ) -> bool: - - @decorator_utils.retry_if_false(tries = max_attempts, - delay_sec = initial_delay, - backoff = backoff_factor) + @decorator_utils.retry_if_false( + tries=max_attempts, delay_sec=initial_delay, backoff=backoff_factor + ) def _try_acquire_lock_with_retries() -> bool: success = self.try_acquire_lock_once() if not success and os.path.exists(self.lockfile): @@ -132,8 +132,13 @@ class LockFile(object): if self.locktime: ts = datetime.datetime.now().timestamp() duration = ts - self.locktime - if duration >= config.config['lockfile_held_duration_warning_threshold_sec']: - str_duration = datetime_utils.describe_duration_briefly(duration) + if ( + duration + >= config.config['lockfile_held_duration_warning_threshold_sec'] + ): + str_duration = datetime_utils.describe_duration_briefly( + duration + ) msg = f'Held {self.lockfile} for {str_duration}' logger.warning(msg) warnings.warn(msg, stacklevel=2) @@ -153,9 +158,9 @@ class LockFile(object): else: cmd = ' '.join(sys.argv) contents = LockFileContents( - pid = os.getpid(), - commandline = cmd, - expiration_timestamp = self.expiration_timestamp, + pid=os.getpid(), + commandline=cmd, + expiration_timestamp=self.expiration_timestamp, ) return json.dumps(contents.__dict__) diff --git a/logging_utils.py b/logging_utils.py index a15ccd6..bf8d8b0 100644 --- a/logging_utils.py +++ b/logging_utils.py @@ -22,7 +22,9 @@ import pytz import argparse_utils import config -cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to logging') +cfg = config.add_commandline_args( + f'Logging ({__file__})', 'Args related to logging' +) cfg.add_argument( '--logging_config_file', type=argparse_utils.valid_filename, @@ -231,7 +233,9 @@ class SquelchRepeatedMessagesFilter(logging.Filter): if id1 not in squelched_logging_counts: return True threshold = squelched_logging_counts[id1] - logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}' + logsite = ( + f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}' + ) count = self.counters[logsite] self.counters[logsite] += 1 return count < threshold @@ -440,8 +444,12 @@ def initialize_logging(logger=None) -> logging.Logger: if config.config['logging_syslog']: if sys.platform not in ('win32', 'cygwin'): if config.config['logging_syslog_facility']: - facility_name = 'LOG_' + config.config['logging_syslog_facility'] - facility = SysLogHandler.__dict__.get(facility_name, SysLogHandler.LOG_USER) + facility_name = ( + 'LOG_' + config.config['logging_syslog_facility'] + ) + facility = SysLogHandler.__dict__.get( + facility_name, SysLogHandler.LOG_USER + ) handler = SysLogHandler(facility=facility, address='/dev/log') handler.setFormatter( MillisecondAwareFormatter( @@ -525,7 +533,9 @@ def initialize_logging(logger=None) -> logging.Logger: level_name = logging._levelToName.get( default_logging_level, str(default_logging_level) ) - logger.debug(f'Initialized global logging; default logging level is {level_name}.') + logger.debug( + f'Initialized global logging; default logging level is {level_name}.' + ) if ( config.config['logging_clear_preexisting_handlers'] and preexisting_handlers_count > 0 @@ -654,17 +664,23 @@ class OutputMultiplexer(object): self.logger = logger if filenames is not None: - self.f = [open(filename, 'wb', buffering=0) for filename in filenames] + self.f = [ + open(filename, 'wb', buffering=0) for filename in filenames + ] else: if destination_bitv & OutputMultiplexer.FILENAMES: - raise ValueError("Filenames argument is required if bitv & FILENAMES") + raise ValueError( + "Filenames argument is required if bitv & FILENAMES" + ) self.f = None if handles is not None: self.h = [handle for handle in handles] else: if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES: - raise ValueError("Handle argument is required if bitv & FILEHANDLES") + raise ValueError( + "Handle argument is required if bitv & FILEHANDLES" + ) self.h = None self.set_destination_bitv(destination_bitv) @@ -674,9 +690,13 @@ class OutputMultiplexer(object): def set_destination_bitv(self, destination_bitv: int): if destination_bitv & self.Destination.FILENAMES and self.f is None: - raise ValueError("Filename argument is required if bitv & FILENAMES") + raise ValueError( + "Filename argument is required if bitv & FILENAMES" + ) if destination_bitv & self.Destination.FILEHANDLES and self.h is None: - raise ValueError("Handle argument is required if bitv & FILEHANDLES") + raise ValueError( + "Handle argument is required if bitv & FILEHANDLES" + ) self.destination_bitv = destination_bitv def print(self, *args, **kwargs): @@ -699,12 +719,18 @@ class OutputMultiplexer(object): end = "\n" if end == '\n': buf += '\n' - if self.destination_bitv & self.Destination.FILENAMES and self.f is not None: + if ( + self.destination_bitv & self.Destination.FILENAMES + and self.f is not None + ): for _ in self.f: _.write(buf.encode('utf-8')) _.flush() - if self.destination_bitv & self.Destination.FILEHANDLES and self.h is not None: + if ( + self.destination_bitv & self.Destination.FILEHANDLES + and self.h is not None + ): for _ in self.h: _.write(buf) _.flush() @@ -755,7 +781,10 @@ class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator): handles=None, ): super().__init__( - destination_bitv, logger=logger, filenames=filenames, handles=handles + destination_bitv, + logger=logger, + filenames=filenames, + handles=handles, ) def __enter__(self): diff --git a/logical_search.py b/logical_search.py index 3ebaee5..85f9461 100644 --- a/logical_search.py +++ b/logical_search.py @@ -403,4 +403,5 @@ class Node(object): if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/math_utils.py b/math_utils.py index e0e3f6c..3216d4a 100644 --- a/math_utils.py +++ b/math_utils.py @@ -39,7 +39,7 @@ class RunningMedian(object): def get_median(self): if len(self.lowers) == len(self.highers): - return (-self.lowers[0] + self.highers[0])/2 + return (-self.lowers[0] + self.highers[0]) / 2 elif len(self.lowers) > len(self.highers): return -self.lowers[0] else: @@ -143,12 +143,12 @@ def is_prime(n: int) -> bool: # This is checked so that we can skip middle five numbers in below # loop - if (n % 2 == 0 or n % 3 == 0): + if n % 2 == 0 or n % 3 == 0: return False i = 5 while i * i <= n: - if (n % i == 0 or n % (i + 2) == 0): + if n % i == 0 or n % (i + 2) == 0: return False i = i + 6 return True @@ -156,4 +156,5 @@ def is_prime(n: int) -> bool: if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/orb_utils.py b/orb_utils.py index 5bc6d1a..fe82e2e 100644 --- a/orb_utils.py +++ b/orb_utils.py @@ -16,20 +16,18 @@ parser.add_argument( default="/Users/scott/orb_color", metavar="FILENAME", type=str, - help="The location of the orb file on whatever machine is hosting it." + help="The location of the orb file on whatever machine is hosting it.", ) parser.add_argument( "--orb_utils_user_machine", default="scott@cheetah.house", metavar="USER@HOSTNAME", type=str, - help="The username/machine combo that is hosting the orb." + help="The username/machine combo that is hosting the orb.", ) def make_orb(color: str) -> None: user_machine = config.config['orb_utils_user_machine'] orbfile_path = config.config['orb_utils_file_location'] - os.system( - f"ssh {user_machine} 'echo \"{color}\" > {orbfile_path}'" - ) + os.system(f"ssh {user_machine} 'echo \"{color}\" > {orbfile_path}'") diff --git a/parallelize.py b/parallelize.py index 698a7ec..98f883c 100644 --- a/parallelize.py +++ b/parallelize.py @@ -15,7 +15,9 @@ class Method(Enum): def parallelize( - _funct: typing.Optional[typing.Callable] = None, *, method: Method = Method.THREAD + _funct: typing.Optional[typing.Callable] = None, + *, + method: Method = Method.THREAD ) -> typing.Callable: """Usage: diff --git a/persistent.py b/persistent.py index 8829d6d..5c2b132 100644 --- a/persistent.py +++ b/persistent.py @@ -20,6 +20,7 @@ class Persistent(ABC): and implement their save() and load() methods. """ + @abstractmethod def save(self) -> bool: """ @@ -65,15 +66,15 @@ def was_file_written_today(filename: str) -> bool: mtime = file_utils.get_file_mtime_as_datetime(filename) now = datetime.datetime.now() return ( - mtime.month == now.month and - mtime.day == now.day and - mtime.year == now.year + mtime.month == now.month + and mtime.day == now.day + and mtime.year == now.year ) def was_file_written_within_n_seconds( - filename: str, - limit_seconds: int, + filename: str, + limit_seconds: int, ) -> bool: """Returns True if filename was written within the pas limit_seconds seconds. @@ -93,9 +94,10 @@ class PersistAtShutdown(enum.Enum): to disk. See details below. """ - NEVER = 0, - IF_NOT_LOADED = 1, - ALWAYS = 2, + + NEVER = (0,) + IF_NOT_LOADED = (1,) + ALWAYS = (2,) class persistent_autoloaded_singleton(object): @@ -118,10 +120,12 @@ class persistent_autoloaded_singleton(object): implementation. """ + def __init__( - self, - *, - persist_at_shutdown: PersistAtShutdown = PersistAtShutdown.IF_NOT_LOADED): + self, + *, + persist_at_shutdown: PersistAtShutdown = PersistAtShutdown.IF_NOT_LOADED, + ): self.persist_at_shutdown = persist_at_shutdown self.instance = None @@ -140,27 +144,33 @@ class persistent_autoloaded_singleton(object): # Otherwise, try to load it from persisted state. was_loaded = False - logger.debug(f'Attempting to load {cls.__name__} from persisted state.') + logger.debug( + f'Attempting to load {cls.__name__} from persisted state.' + ) self.instance = cls.load() if not self.instance: msg = 'Loading from cache failed.' logger.warning(msg) - logger.debug(f'Attempting to instantiate {cls.__name__} directly.') + logger.debug( + f'Attempting to instantiate {cls.__name__} directly.' + ) self.instance = cls(*args, **kwargs) else: - logger.debug(f'Class {cls.__name__} was loaded from persisted state successfully.') + logger.debug( + f'Class {cls.__name__} was loaded from persisted state successfully.' + ) was_loaded = True assert self.instance is not None - if ( - self.persist_at_shutdown is PersistAtShutdown.ALWAYS or - ( - not was_loaded and - self.persist_at_shutdown is PersistAtShutdown.IF_NOT_LOADED - ) + if self.persist_at_shutdown is PersistAtShutdown.ALWAYS or ( + not was_loaded + and self.persist_at_shutdown is PersistAtShutdown.IF_NOT_LOADED ): - logger.debug('Scheduling a deferred called to save at process shutdown time.') + logger.debug( + 'Scheduling a deferred called to save at process shutdown time.' + ) atexit.register(self.instance.save) return self.instance + return _load diff --git a/profanity_filter.py b/profanity_filter.py index 4723a2d..95540fa 100755 --- a/profanity_filter.py +++ b/profanity_filter.py @@ -494,7 +494,9 @@ class ProfanityFilter(object): result = result.replace('3', 'e') for x in string.punctuation: result = result.replace(x, "") - chunks = [self.stemmer.stem(word) for word in nltk.word_tokenize(result)] + chunks = [ + self.stemmer.stem(word) for word in nltk.word_tokenize(result) + ] return ' '.join(chunks) def tokenize(self, text: str): diff --git a/remote_worker.py b/remote_worker.py index 211b213..b58c6ba 100755 --- a/remote_worker.py +++ b/remote_worker.py @@ -31,20 +31,20 @@ cfg.add_argument( type=str, required=True, metavar='FILENAME', - help='The location of the bundle of code to execute.' + help='The location of the bundle of code to execute.', ) cfg.add_argument( '--result_file', type=str, required=True, metavar='FILENAME', - help='The location where we should write the computation results.' + help='The location where we should write the computation results.', ) cfg.add_argument( '--watch_for_cancel', action=argparse_utils.ActionNoYes, default=True, - help='Should we watch for the cancellation of our parent ssh process?' + help='Should we watch for the cancellation of our parent ssh process?', ) @@ -63,7 +63,9 @@ def watch_for_cancel(terminate_event: threading.Event) -> None: saw_sshd = True break if not saw_sshd: - logger.error('Did not see sshd in our ancestors list?! Committing suicide.') + logger.error( + 'Did not see sshd in our ancestors list?! Committing suicide.' + ) os.system('pstree') os.kill(os.getpid(), signal.SIGTERM) time.sleep(5.0) diff --git a/site_config.py b/site_config.py index 1281661..62c2b98 100644 --- a/site_config.py +++ b/site_config.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) args = config.add_commandline_args( f'Global Site Config ({__file__})', - 'Args related to global site-specific configuration' + 'Args related to global site-specific configuration', ) args.add_argument( '--site_config_override_location', @@ -64,6 +64,7 @@ def get_location(): def is_anyone_present_wrapper(location: Location): import base_presence + p = base_presence.PresenceDetection() return p.is_anyone_in_location_now(location) @@ -90,25 +91,29 @@ def get_config(): location = 'CABIN' if location == 'HOUSE': return SiteConfig( - location_name = 'HOUSE', - location = Location.HOUSE, - network = '10.0.0.0/24', - network_netmask = '255.255.255.0', - network_router_ip = '10.0.0.1', - presence_location = Location.HOUSE, - is_anyone_present = lambda x=Location.HOUSE: is_anyone_present_wrapper(x), - arper_minimum_device_count = 50, + location_name='HOUSE', + location=Location.HOUSE, + network='10.0.0.0/24', + network_netmask='255.255.255.0', + network_router_ip='10.0.0.1', + presence_location=Location.HOUSE, + is_anyone_present=lambda x=Location.HOUSE: is_anyone_present_wrapper( + x + ), + arper_minimum_device_count=50, ) elif location == 'CABIN': return SiteConfig( - location_name = 'CABIN', - location = Location.CABIN, - network = '192.168.0.0/24', - network_netmask = '255.255.255.0', - network_router_ip = '192.168.0.1', - presence_location = Location.CABIN, - is_anyone_present = lambda x=Location.CABIN: is_anyone_present_wrapper(x), - arper_minimum_device_count = 15, + location_name='CABIN', + location=Location.CABIN, + network='192.168.0.0/24', + network_netmask='255.255.255.0', + network_router_ip='192.168.0.1', + presence_location=Location.CABIN, + is_anyone_present=lambda x=Location.CABIN: is_anyone_present_wrapper( + x + ), + arper_minimum_device_count=15, ) else: raise Exception(f'Unknown site location: {location}') @@ -116,4 +121,5 @@ def get_config(): if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/smart_future.py b/smart_future.py index c097d53..8f23e77 100644 --- a/smart_future.py +++ b/smart_future.py @@ -37,9 +37,7 @@ def wait_any(futures: List[SmartFuture], *, callback: Callable = None): def wait_all(futures: List[SmartFuture]) -> None: real_futures = [x.wrapped_future for x in futures] (done, not_done) = concurrent.futures.wait( - real_futures, - timeout=None, - return_when=concurrent.futures.ALL_COMPLETED + real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED ) assert len(done) == len(real_futures) assert len(not_done) == 0 diff --git a/state_tracker.py b/state_tracker.py index 9ce61e3..4836e3e 100644 --- a/state_tracker.py +++ b/state_tracker.py @@ -22,6 +22,7 @@ class StateTracker(ABC): provided to the c'tor. """ + def __init__(self, update_ids_to_update_secs: Dict[str, float]) -> None: """The update_ids_to_update_secs dict parameter describes one or more update types (unique update_ids) and the periodicity(ies), in @@ -106,6 +107,7 @@ class AutomaticStateTracker(StateTracker): terminate the updates. """ + @background_thread def pace_maker(self, should_terminate) -> None: """Entry point for a background thread to own calling heartbeat() @@ -128,6 +130,7 @@ class AutomaticStateTracker(StateTracker): override_sleep_delay: Optional[float] = None, ) -> None: import math_utils + super().__init__(update_ids_to_update_secs) if override_sleep_delay is not None: logger.debug(f'Overriding sleep delay to {override_sleep_delay}') @@ -174,15 +177,17 @@ class WaitableAutomaticStateTracker(AutomaticStateTracker): # else before looping up into wait again. """ + def __init__( - self, - update_ids_to_update_secs: Dict[str, float], - *, - override_sleep_delay: Optional[float] = None, + self, + update_ids_to_update_secs: Dict[str, float], + *, + override_sleep_delay: Optional[float] = None, ) -> None: self._something_changed = threading.Event() - super().__init__(update_ids_to_update_secs, - override_sleep_delay=override_sleep_delay) + super().__init__( + update_ids_to_update_secs, override_sleep_delay=override_sleep_delay + ) def something_changed(self): self._something_changed.set() @@ -193,9 +198,5 @@ class WaitableAutomaticStateTracker(AutomaticStateTracker): def reset(self): self._something_changed.clear() - def wait(self, - *, - timeout=None): - return self._something_changed.wait( - timeout=timeout - ) + def wait(self, *, timeout=None): + return self._something_changed.wait(timeout=timeout) diff --git a/text_utils.py b/text_utils.py index bc05dd9..94df3e3 100644 --- a/text_utils.py +++ b/text_utils.py @@ -23,6 +23,7 @@ def get_console_rows_columns() -> RowsColumns: """Returns the number of rows/columns on the current console.""" from exec_utils import cmd_with_timeout + try: rows, columns = cmd_with_timeout( "stty size", @@ -50,16 +51,13 @@ def progress_graph( ret = "\r" if redraw else "\n" bar = bar_graph( percent, - include_text = True, - width = width, - fgcolor = fgcolor, - left_end = left_end, - right_end = right_end) - print( - bar, - end=ret, - flush=True, - file=sys.stderr) + include_text=True, + width=width, + fgcolor=fgcolor, + left_end=left_end, + right_end=right_end, + ) + print(bar, end=ret, flush=True, file=sys.stderr) def bar_graph( @@ -94,13 +92,16 @@ def bar_graph( part_width = math.floor(remainder_width * 8) part_char = [" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉"][part_width] return ( - left_end + - fgcolor + - "█" * whole_width + part_char + - " " * (width - whole_width - 1) + - reset + - right_end + " " + - text) + left_end + + fgcolor + + "█" * whole_width + + part_char + + " " * (width - whole_width - 1) + + reset + + right_end + + " " + + text + ) def distribute_strings( @@ -128,9 +129,9 @@ def distribute_strings( string, width=subwidth, alignment=alignment, padding=padding ) retval += string - while(len(retval) > width): + while len(retval) > width: retval = retval.replace(' ', ' ', 1) - while(len(retval) < width): + while len(retval) < width: retval = retval.replace(' ', ' ', 1) return retval @@ -150,7 +151,13 @@ def justify_string_by_chunk( padding = padding[0] first, *rest, last = string.split() w = width - (len(first) + 1 + len(last) + 1) - ret = first + padding + distribute_strings(rest, width=w, padding=padding) + padding + last + ret = ( + first + + padding + + distribute_strings(rest, width=w, padding=padding) + + padding + + last + ) return ret @@ -177,11 +184,7 @@ def justify_string( elif alignment == "r": string = padding + string elif alignment == "j": - return justify_string_by_chunk( - string, - width=width, - padding=padding - ) + return justify_string_by_chunk(string, width=width, padding=padding) elif alignment == "c": if len(string) % 2 == 0: string += padding @@ -251,11 +254,14 @@ class Indenter(object): with i: i.print('1, 2, 3') """ - def __init__(self, - *, - pad_prefix: Optional[str] = None, - pad_char: str = ' ', - pad_count: int = 4): + + def __init__( + self, + *, + pad_prefix: Optional[str] = None, + pad_char: str = ' ', + pad_count: int = 4, + ): self.level = -1 if pad_prefix is not None: self.pad_prefix = pad_prefix @@ -274,6 +280,7 @@ class Indenter(object): def print(self, *arg, **kwargs): import string_utils + text = string_utils.sprintf(*arg, **kwargs) print(self.pad_prefix + self.padding * self.level + text, end='') @@ -287,7 +294,7 @@ def header(title: str, *, width: int = 80, color: str = ''): """ w = width - w -= (len(title) + 4) + w -= len(title) + 4 if w >= 4: left = 4 * '-' right = (w - 4) * '-' @@ -302,4 +309,5 @@ def header(title: str, *, width: int = 80, color: str = ''): if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/thread_utils.py b/thread_utils.py index 0130cdc..ad1f0bf 100644 --- a/thread_utils.py +++ b/thread_utils.py @@ -27,7 +27,7 @@ def is_current_thread_main_thread() -> bool: def background_thread( - _funct: Optional[Callable] + _funct: Optional[Callable], ) -> Tuple[threading.Thread, threading.Event]: """A function decorator to create a background thread. @@ -58,10 +58,11 @@ def background_thread( periodically check. If the event is set, it means the thread has been requested to terminate ASAP. """ + def wrapper(funct: Callable): @functools.wraps(funct) def inner_wrapper( - *a, **kwa + *a, **kwa ) -> Tuple[threading.Thread, threading.Event]: should_terminate = threading.Event() should_terminate.clear() @@ -72,10 +73,9 @@ def background_thread( kwargs=kwa, ) thread.start() - logger.debug( - f'Started thread {thread.name} tid={thread.ident}' - ) + logger.debug(f'Started thread {thread.name} tid={thread.ident}') return (thread, should_terminate) + return inner_wrapper if _funct is None: @@ -85,8 +85,8 @@ def background_thread( def periodically_invoke( - period_sec: float, - stop_after: Optional[int], + period_sec: float, + stop_after: Optional[int], ): """ Periodically invoke a decorated function. Stop after N invocations @@ -108,6 +108,7 @@ def periodically_invoke( print(f"Hello, {name}") """ + def decorator_repeat(func): def helper_thread(should_terminate, *args, **kwargs) -> None: if stop_after is None: @@ -130,14 +131,12 @@ def periodically_invoke( should_terminate.clear() newargs = (should_terminate, *args) thread = threading.Thread( - target=helper_thread, - args = newargs, - kwargs = kwargs + target=helper_thread, args=newargs, kwargs=kwargs ) thread.start() - logger.debug( - f'Started thread {thread.name} tid={thread.ident}' - ) + logger.debug(f'Started thread {thread.name} tid={thread.ident}') return (thread, should_terminate) + return wrapper_repeat + return decorator_repeat diff --git a/unittest_utils.py b/unittest_utils.py index d63f2b5..4a9669d 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -32,8 +32,8 @@ import sqlalchemy as sa logger = logging.getLogger(__name__) cfg = config.add_commandline_args( - f'Logging ({__file__})', - 'Args related to function decorators') + f'Logging ({__file__})', 'Args related to function decorators' +) cfg.add_argument( '--unittests_ignore_perf', action='store_true', @@ -44,34 +44,34 @@ cfg.add_argument( '--unittests_num_perf_samples', type=int, default=50, - help='The count of perf timing samples we need to see before blocking slow runs on perf grounds' + help='The count of perf timing samples we need to see before blocking slow runs on perf grounds', ) cfg.add_argument( '--unittests_drop_perf_traces', type=str, nargs=1, default=None, - help='The identifier (i.e. file!test_fixture) for which we should drop all perf data' + help='The identifier (i.e. file!test_fixture) for which we should drop all perf data', ) cfg.add_argument( '--unittests_persistance_strategy', choices=['FILE', 'DATABASE'], default='DATABASE', - help='Should we persist perf data in a file or db?' + help='Should we persist perf data in a file or db?', ) cfg.add_argument( '--unittests_perfdb_filename', type=str, metavar='FILENAME', default=f'{os.environ["HOME"]}/.python_unittest_performance_db', - help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)' + help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)', ) cfg.add_argument( '--unittests_perfdb_spec', type=str, metavar='DBSPEC', default='mariadb+pymysql://python_unittest:@db.house:3306/python_unittest_performance', - help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)' + help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)', ) # >>> This is the hacky business, FYI. <<< @@ -87,7 +87,9 @@ class PerfRegressionDataPersister(ABC): pass @abstractmethod - def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): + def save_performance_data( + self, method_id: str, data: Dict[str, List[float]] + ): pass @abstractmethod @@ -104,7 +106,9 @@ class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister): with open(self.filename, 'rb') as f: return pickle.load(f) - def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): + def save_performance_data( + self, method_id: str, data: Dict[str, List[float]] + ): for trace in self.traces_to_delete: if trace in data: data[trace] = [] @@ -134,7 +138,9 @@ class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister): results.close() return ret - def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): + def save_performance_data( + self, method_id: str, data: Dict[str, List[float]] + ): self.delete_performance_data(method_id) for (method_id, perf_data) in data.items(): sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES ' @@ -155,6 +161,7 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: message if it has become too slow. """ + @functools.wraps(func) def wrapper_perf_monitor(*args, **kwargs): if config.config['unittests_persistance_strategy'] == 'FILE': @@ -162,7 +169,9 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: helper = FileBasedPerfRegressionDataPersister(filename) elif config.config['unittests_persistance_strategy'] == 'DATABASE': dbspec = config.config['unittests_perfdb_spec'] - dbspec = dbspec.replace('', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD) + dbspec = dbspec.replace( + '', scott_secrets.MARIADB_UNITTEST_PERF_PASSWORD + ) helper = DatabasePerfRegressionDataPersister(dbspec) else: raise Exception( @@ -198,14 +207,14 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: hist = perfdb.get(func_id, []) if len(hist) < config.config['unittests_num_perf_samples']: hist.append(run_time) - logger.debug( - f'Still establishing a perf baseline for {func_name}' - ) + logger.debug(f'Still establishing a perf baseline for {func_name}') else: stdev = statistics.stdev(hist) logger.debug(f'For {func_name}, performance stdev={stdev}') slowest = hist[-1] - logger.debug(f'For {func_name}, slowest perf on record is {slowest:f}s') + logger.debug( + f'For {func_name}, slowest perf on record is {slowest:f}s' + ) limit = slowest + stdev * 4 logger.debug( f'For {func_name}, max acceptable runtime is {limit:f}s' @@ -213,10 +222,7 @@ def check_method_for_perf_regressions(func: Callable) -> Callable: logger.debug( f'For {func_name}, actual observed runtime was {run_time:f}s' ) - if ( - run_time > limit and - not config.config['unittests_ignore_perf'] - ): + if run_time > limit and not config.config['unittests_ignore_perf']: msg = f'''{func_id} performance has regressed unacceptably. {slowest:f}s is the slowest runtime on record in {len(hist)} perf samples. It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest. @@ -226,8 +232,8 @@ Here is the current, full db perf timing distribution: for x in hist: msg += f'{x:f}\n' logger.error(msg) - slf = args[0] # Peek at the wrapped function's self ref. - slf.fail(msg) # ...to fail the testcase. + slf = args[0] # Peek at the wrapped function's self ref. + slf.fail(msg) # ...to fail the testcase. else: hist.append(run_time) @@ -239,6 +245,7 @@ Here is the current, full db perf timing distribution: perfdb[func_id] = hist helper.save_performance_data(func_id, perfdb) return value + return wrapper_perf_monitor @@ -255,6 +262,7 @@ def check_all_methods_for_perf_regressions(prefix='test_'): ... """ + def decorate_the_testcase(cls): if issubclass(cls, unittest.TestCase): for name, m in inspect.getmembers(cls, inspect.isfunction): @@ -262,12 +270,14 @@ def check_all_methods_for_perf_regressions(prefix='test_'): setattr(cls, name, check_method_for_perf_regressions(m)) logger.debug(f'Wrapping {cls.__name__}:{name}.') return cls + return decorate_the_testcase def breakpoint(): """Hard code a breakpoint somewhere; drop into pdb.""" import pdb + pdb.set_trace() @@ -346,4 +356,5 @@ class RecordMultipleStreams(object): if __name__ == '__main__': import doctest + doctest.testmod() diff --git a/unscrambler.py b/unscrambler.py index d3686d6..3abb6d8 100644 --- a/unscrambler.py +++ b/unscrambler.py @@ -121,9 +121,13 @@ class Unscrambler(object): # 52 bits @staticmethod - def _compute_word_fingerprint(word: str, population: Mapping[str, int]) -> int: + def _compute_word_fingerprint( + word: str, population: Mapping[str, int] + ) -> int: fp = 0 - for pair in sorted(population.items(), key=lambda x: x[1], reverse=True): + for pair in sorted( + population.items(), key=lambda x: x[1], reverse=True + ): letter = pair[0] if letter in fprint_feature_bit: count = pair[1] @@ -142,7 +146,9 @@ class Unscrambler(object): population: Mapping[str, int], ) -> int: sig = 0 - for pair in sorted(population.items(), key=lambda x: x[1], reverse=True): + for pair in sorted( + population.items(), key=lambda x: x[1], reverse=True + ): letter = pair[0] if letter not in letter_sigs: continue @@ -183,7 +189,9 @@ class Unscrambler(object): """ population = list_utils.population_counts(word) fprint = Unscrambler._compute_word_fingerprint(word, population) - letter_sig = Unscrambler._compute_word_letter_sig(letter_sigs, word, population) + letter_sig = Unscrambler._compute_word_letter_sig( + letter_sigs, word, population + ) assert fprint & letter_sig == 0 sig = fprint | letter_sig return sig @@ -230,7 +238,9 @@ class Unscrambler(object): """ sig = Unscrambler.compute_word_sig(word) - return self.lookup_by_sig(sig, include_fuzzy_matches=include_fuzzy_matches) + return self.lookup_by_sig( + sig, include_fuzzy_matches=include_fuzzy_matches + ) def lookup_by_sig( self, sig: int, *, include_fuzzy_matches: bool = False diff --git a/waitable_presence.py b/waitable_presence.py index 9e0a9d0..cd5501d 100644 --- a/waitable_presence.py +++ b/waitable_presence.py @@ -20,7 +20,9 @@ import state_tracker logger = logging.getLogger(__name__) -class WaitablePresenceDetectorWithMemory(state_tracker.WaitableAutomaticStateTracker): +class WaitablePresenceDetectorWithMemory( + state_tracker.WaitableAutomaticStateTracker +): """ This is a waitable class that keeps a PresenceDetector internally and periodically polls it to detect changes in presence in a @@ -38,16 +40,18 @@ class WaitablePresenceDetectorWithMemory(state_tracker.WaitableAutomaticStateTra """ def __init__( - self, - override_update_interval_sec: float = 60.0, - override_location: Location = site_config.get_location(), + self, + override_update_interval_sec: float = 60.0, + override_location: Location = site_config.get_location(), ) -> None: self.last_someone_is_home: Optional[bool] = None self.someone_is_home: Optional[bool] = None self.everyone_gone_since: Optional[datetime.datetime] = None self.someone_home_since: Optional[datetime.datetime] = None self.location = override_location - self.detector: base_presence.PresenceDetection = base_presence.PresenceDetection() + self.detector: base_presence.PresenceDetection = ( + base_presence.PresenceDetection() + ) super().__init__( { 'poll_presence': override_update_interval_sec, @@ -84,7 +88,9 @@ class WaitablePresenceDetectorWithMemory(state_tracker.WaitableAutomaticStateTra def check_detector(self) -> None: if len(self.detector.dark_locations) > 0: - logger.debug('PresenceDetector is incomplete; trying to reinitialize...') + logger.debug( + 'PresenceDetector is incomplete; trying to reinitialize...' + ) self.detector = base_presence.PresenceDetection() def is_someone_home(self) -> Tuple[bool, datetime.datetime]: