From: Scott Date: Wed, 2 Feb 2022 19:09:18 +0000 (-0800) Subject: mypy clean! X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=5317c50ce7a96a37acfab3800c0935580766dbbf;p=python_utils.git mypy clean! --- diff --git a/conversion_utils.py b/conversion_utils.py index 684edc0..8eaecd5 100644 --- a/conversion_utils.py +++ b/conversion_utils.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 -from numbers import Number -from typing import Callable +from typing import Callable, SupportsFloat import constants @@ -38,10 +37,10 @@ class Converter(object): self.from_canonical_f = from_canonical self.unit = unit - def to_canonical(self, n: Number) -> Number: + def to_canonical(self, n: SupportsFloat) -> SupportsFloat: return self.to_canonical_f(n) - def from_canonical(self, n: Number) -> Number: + def from_canonical(self, n: SupportsFloat) -> SupportsFloat: return self.from_canonical_f(n) def unit_suffix(self) -> str: @@ -97,7 +96,7 @@ conversion_catalog = { } -def convert(magnitude: Number, from_thing: str, to_thing: str) -> float: +def convert(magnitude: SupportsFloat, from_thing: str, to_thing: str) -> float: src = conversion_catalog.get(from_thing, None) dst = conversion_catalog.get(to_thing, None) if src is None or dst is None: @@ -107,7 +106,9 @@ def convert(magnitude: Number, from_thing: str, to_thing: str) -> float: return _convert(magnitude, src, dst) -def _convert(magnitude: Number, from_unit: Converter, to_unit: Converter) -> float: +def _convert( + magnitude: SupportsFloat, from_unit: Converter, to_unit: Converter +) -> float: canonical = from_unit.to_canonical(magnitude) converted = to_unit.from_canonical(canonical) return float(converted) diff --git a/directory_filter.py b/directory_filter.py index 508baf3..b076bad 100644 --- a/directory_filter.py +++ b/directory_filter.py @@ -3,7 +3,7 @@ import hashlib import logging import os -from typing import Any, Optional +from typing import Any, Dict, Optional, Set logger = logging.getLogger(__name__) @@ -38,8 +38,8 @@ class DirectoryFileFilter(object): if not file_utils.does_directory_exist(directory): raise ValueError(directory) self.directory = directory - self.md5_by_filename = {} - self.mtime_by_filename = {} + self.md5_by_filename: Dict[str, str] = {} + self.mtime_by_filename: Dict[str, float] = {} self._update() def _update(self): @@ -55,6 +55,7 @@ class DirectoryFileFilter(object): assert file_utils.does_file_exist(filename) if mtime is None: mtime = file_utils.get_file_raw_mtime(filename) + assert mtime if self.mtime_by_filename.get(filename, 0) != mtime: md5 = file_utils.get_file_md5(filename) logger.debug(f'Computed/stored {filename}\'s MD5 at ts={mtime} ({md5})') @@ -102,7 +103,7 @@ class DirectoryAllFilesFilter(DirectoryFileFilter): """ def __init__(self, directory: str): - self.all_md5s = set() + self.all_md5s: Set[str] = set() super().__init__(directory) def _update_file(self, filename: str, mtime: Optional[float] = None): @@ -111,13 +112,15 @@ class DirectoryAllFilesFilter(DirectoryFileFilter): assert file_utils.does_file_exist(filename) if mtime is None: mtime = file_utils.get_file_raw_mtime(filename) + assert mtime if self.mtime_by_filename.get(filename, 0) != mtime: md5 = file_utils.get_file_md5(filename) self.mtime_by_filename[filename] = mtime self.md5_by_filename[filename] = md5 self.all_md5s.add(md5) - def apply(self, item: Any) -> bool: + def apply(self, item: Any, ignored_filename: str = None) -> bool: + assert not ignored_filename self._update() mem_hash = hashlib.md5() mem_hash.update(item) diff --git a/list_utils.py b/list_utils.py index 71630dc..d70159a 100644 --- a/list_utils.py +++ b/list_utils.py @@ -65,7 +65,7 @@ def remove_list_if_one_element(lst: List[Any]) -> Any: return lst -def population_counts(lst: List[Any]) -> Counter: +def population_counts(lst: Sequence[Any]) -> Counter: """ Return a population count mapping for the list (i.e. the keys are list items and the values are the number of occurrances of that diff --git a/lockfile.py b/lockfile.py index 4b6aade..03fbb9e 100644 --- a/lockfile.py +++ b/lockfile.py @@ -34,7 +34,7 @@ class LockFileException(Exception): class LockFileContents: pid: int commandline: str - expiration_timestamp: float + expiration_timestamp: Optional[float] class LockFile(object): @@ -181,7 +181,7 @@ class LockFile(object): # Has the lock expiration expired? if contents.expiration_timestamp is not None: now = datetime.datetime.now().timestamp() - if now > contents.expiration_datetime: + if now > contents.expiration_timestamp: msg = f'Lockfile {self.lockfile} expiration time has passed; force acquiring' logger.warning(msg) self.release() diff --git a/remote_worker.py b/remote_worker.py index b58c6ba..12a5028 100755 --- a/remote_worker.py +++ b/remote_worker.py @@ -10,6 +10,7 @@ import signal import threading import sys import time +from typing import Optional import cloudpickle # type: ignore import psutil # type: ignore @@ -17,6 +18,7 @@ import psutil # type: ignore import argparse_utils import bootstrap import config +from stopwatch import Timer from thread_utils import background_thread @@ -76,11 +78,24 @@ def watch_for_cancel(terminate_event: threading.Event) -> None: time.sleep(1.0) +def cleanup_and_exit( + thread: Optional[threading.Thread], + stop_thread: Optional[threading.Event], + exit_code: int, +) -> None: + if stop_thread is not None: + stop_thread.set() + assert thread is not None + thread.join() + sys.exit(exit_code) + + @bootstrap.initialize def main() -> None: in_file = config.config['code_file'] out_file = config.config['result_file'] + thread = None stop_thread = None if config.config['watch_for_cancel']: (thread, stop_thread) = watch_for_cancel() @@ -92,8 +107,7 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Problem reading {in_file}. Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 1) logger.debug(f'Deserializing {in_file}.') try: @@ -101,14 +115,12 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Problem deserializing {in_file}. Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 2) logger.debug('Invoking user code...') - start = time.time() - ret = fun(*args, **kwargs) - end = time.time() - logger.debug(f'User code took {end - start:.1f}s') + with Timer() as t: + ret = fun(*args, **kwargs) + logger.debug(f'User code took {t():.1f}s') logger.debug('Serializing results') try: @@ -116,8 +128,7 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Could not serialize result ({type(ret)}). Aborting.') - stop_thread.set() - sys.exit(-1) + cleanup_and_exit(thread, stop_thread, 3) logger.debug(f'Writing {out_file}.') try: @@ -126,12 +137,8 @@ def main() -> None: except Exception as e: logger.exception(e) logger.critical(f'Error writing {out_file}. Aborting.') - stop_thread.set() - sys.exit(-1) - - if stop_thread is not None: - stop_thread.set() - thread.join() + cleanup_and_exit(thread, stop_thread, 4) + cleanup_and_exit(thread, stop_thread, 0) if __name__ == '__main__': diff --git a/stopwatch.py b/stopwatch.py index 516138c..c6c154c 100644 --- a/stopwatch.py +++ b/stopwatch.py @@ -10,7 +10,7 @@ class Timer(object): e.g. - with timer.Timer() as t: + with stopwatch.Timer() as t: do_the_thing() walltime = t() diff --git a/unittest_utils.py b/unittest_utils.py index b9746a8..81b339a 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -18,7 +18,7 @@ import random import statistics import time import tempfile -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional import unittest import warnings @@ -83,7 +83,7 @@ class PerfRegressionDataPersister(ABC): pass @abstractmethod - def load_performance_data(self) -> Dict[str, List[float]]: + def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: pass @abstractmethod @@ -98,7 +98,7 @@ class PerfRegressionDataPersister(ABC): class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister): def __init__(self, filename: str): self.filename = filename - self.traces_to_delete = [] + self.traces_to_delete: List[str] = [] def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: with open(self.filename, 'rb') as f: @@ -128,7 +128,7 @@ class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister): f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";' ) ) - ret = {method_id: []} + ret: Dict[str, List[float]] = {method_id: []} for result in results.all(): ret[method_id].append(result['runtime']) results.close() @@ -283,14 +283,16 @@ class RecordStdout(object): def __init__(self) -> None: self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.recorder = None + self.recorder: Optional[contextlib.redirect_stdout] = None def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: self.recorder = contextlib.redirect_stdout(self.destination) + assert self.recorder self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Optional[bool]: + assert self.recorder self.recorder.__exit__(*args) self.destination.seek(0) return None @@ -310,14 +312,16 @@ class RecordStderr(object): def __init__(self) -> None: self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.recorder = None + self.recorder: Optional[contextlib.redirect_stdout[Any]] = None def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: - self.recorder = contextlib.redirect_stderr(self.destination) + self.recorder = contextlib.redirect_stderr(self.destination) # type: ignore + assert self.recorder self.recorder.__enter__() return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Optional[bool]: + assert self.recorder self.recorder.__exit__(*args) self.destination.seek(0) return None @@ -331,7 +335,7 @@ class RecordMultipleStreams(object): def __init__(self, *files) -> None: self.files = [*files] self.destination = tempfile.SpooledTemporaryFile(mode='r+') - self.saved_writes = [] + self.saved_writes: List[Callable[..., Any]] = [] def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: for f in self.files: @@ -339,10 +343,11 @@ class RecordMultipleStreams(object): f.write = self.destination.write return lambda: self.destination - def __exit__(self, *args) -> bool: + def __exit__(self, *args) -> Optional[bool]: for f in self.files: f.write = self.saved_writes.pop() self.destination.seek(0) + return None if __name__ == '__main__': diff --git a/unscrambler.py b/unscrambler.py index 8fcd65d..d9e4253 100644 --- a/unscrambler.py +++ b/unscrambler.py @@ -198,7 +198,7 @@ class Unscrambler(object): unless you want to populate the same exact files. """ - words_by_sigs = {} + words_by_sigs: Dict[int, str] = {} seen = set() with open(dictfile, "r") as f: for word in f: diff --git a/waitable_presence.py b/waitable_presence.py index d54511f..77b6e81 100644 --- a/waitable_presence.py +++ b/waitable_presence.py @@ -98,6 +98,8 @@ class WaitablePresenceDetectorWithMemory(state_tracker.WaitableAutomaticStateTra if self.someone_is_home is None: raise Exception("Too Soon!") if self.someone_is_home: + assert self.someone_home_since return (True, self.someone_home_since) else: + assert self.everyone_gone_since return (False, self.everyone_gone_since)