mypy clean!
authorScott <[email protected]>
Wed, 2 Feb 2022 19:09:18 +0000 (11:09 -0800)
committerScott <[email protected]>
Wed, 2 Feb 2022 19:09:18 +0000 (11:09 -0800)
conversion_utils.py
directory_filter.py
list_utils.py
lockfile.py
remote_worker.py
stopwatch.py
unittest_utils.py
unscrambler.py
waitable_presence.py

index 684edc0a9116aa827db2c7c83a52de10dbbba73b..8eaecd5bd7b8227b6cf8baeee4723e6bfa0450ed 100644 (file)
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 
-from numbers import Number
-from typing import Callable
+from typing import Callable, SupportsFloat
 
 import constants
 
@@ -38,10 +37,10 @@ class Converter(object):
         self.from_canonical_f = from_canonical
         self.unit = unit
 
-    def to_canonical(self, n: Number) -> Number:
+    def to_canonical(self, n: SupportsFloat) -> SupportsFloat:
         return self.to_canonical_f(n)
 
-    def from_canonical(self, n: Number) -> Number:
+    def from_canonical(self, n: SupportsFloat) -> SupportsFloat:
         return self.from_canonical_f(n)
 
     def unit_suffix(self) -> str:
@@ -97,7 +96,7 @@ conversion_catalog = {
 }
 
 
-def convert(magnitude: Number, from_thing: str, to_thing: str) -> float:
+def convert(magnitude: SupportsFloat, from_thing: str, to_thing: str) -> float:
     src = conversion_catalog.get(from_thing, None)
     dst = conversion_catalog.get(to_thing, None)
     if src is None or dst is None:
@@ -107,7 +106,9 @@ def convert(magnitude: Number, from_thing: str, to_thing: str) -> float:
     return _convert(magnitude, src, dst)
 
 
-def _convert(magnitude: Number, from_unit: Converter, to_unit: Converter) -> float:
+def _convert(
+    magnitude: SupportsFloat, from_unit: Converter, to_unit: Converter
+) -> float:
     canonical = from_unit.to_canonical(magnitude)
     converted = to_unit.from_canonical(canonical)
     return float(converted)
index 508baf3bc888cf49e832ca72feddeda1797890a0..b076badf25dff7e34e2358b3e45d52c49692a424 100644 (file)
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 import os
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Set
 
 logger = logging.getLogger(__name__)
 
@@ -38,8 +38,8 @@ class DirectoryFileFilter(object):
         if not file_utils.does_directory_exist(directory):
             raise ValueError(directory)
         self.directory = directory
-        self.md5_by_filename = {}
-        self.mtime_by_filename = {}
+        self.md5_by_filename: Dict[str, str] = {}
+        self.mtime_by_filename: Dict[str, float] = {}
         self._update()
 
     def _update(self):
@@ -55,6 +55,7 @@ class DirectoryFileFilter(object):
         assert file_utils.does_file_exist(filename)
         if mtime is None:
             mtime = file_utils.get_file_raw_mtime(filename)
+        assert mtime
         if self.mtime_by_filename.get(filename, 0) != mtime:
             md5 = file_utils.get_file_md5(filename)
             logger.debug(f'Computed/stored {filename}\'s MD5 at ts={mtime} ({md5})')
@@ -102,7 +103,7 @@ class DirectoryAllFilesFilter(DirectoryFileFilter):
     """
 
     def __init__(self, directory: str):
-        self.all_md5s = set()
+        self.all_md5s: Set[str] = set()
         super().__init__(directory)
 
     def _update_file(self, filename: str, mtime: Optional[float] = None):
@@ -111,13 +112,15 @@ class DirectoryAllFilesFilter(DirectoryFileFilter):
         assert file_utils.does_file_exist(filename)
         if mtime is None:
             mtime = file_utils.get_file_raw_mtime(filename)
+        assert mtime
         if self.mtime_by_filename.get(filename, 0) != mtime:
             md5 = file_utils.get_file_md5(filename)
             self.mtime_by_filename[filename] = mtime
             self.md5_by_filename[filename] = md5
             self.all_md5s.add(md5)
 
-    def apply(self, item: Any) -> bool:
+    def apply(self, item: Any, ignored_filename: str = None) -> bool:
+        assert not ignored_filename
         self._update()
         mem_hash = hashlib.md5()
         mem_hash.update(item)
index 71630dc9a6fdb4d968f890b2ac4d3b8cb78360d2..d70159a1b2dadb61640eae20f029608cabd2f46e 100644 (file)
@@ -65,7 +65,7 @@ def remove_list_if_one_element(lst: List[Any]) -> Any:
         return lst
 
 
-def population_counts(lst: List[Any]) -> Counter:
+def population_counts(lst: Sequence[Any]) -> Counter:
     """
     Return a population count mapping for the list (i.e. the keys are
     list items and the values are the number of occurrances of that
index 4b6aadeffde8ec7bf255025873d720e2b3afda93..03fbb9ef73a1449ffd8779a85ffe4ba885d193f6 100644 (file)
@@ -34,7 +34,7 @@ class LockFileException(Exception):
 class LockFileContents:
     pid: int
     commandline: str
-    expiration_timestamp: float
+    expiration_timestamp: Optional[float]
 
 
 class LockFile(object):
@@ -181,7 +181,7 @@ class LockFile(object):
                     # Has the lock expiration expired?
                     if contents.expiration_timestamp is not None:
                         now = datetime.datetime.now().timestamp()
-                        if now > contents.expiration_datetime:
+                        if now > contents.expiration_timestamp:
                             msg = f'Lockfile {self.lockfile} expiration time has passed; force acquiring'
                             logger.warning(msg)
                             self.release()
index b58c6ba0a66f8d32b2b81af72a66d23493c9b2e5..12a5028c30e2bf95093542296bb1cf5a9866f879 100755 (executable)
@@ -10,6 +10,7 @@ import signal
 import threading
 import sys
 import time
+from typing import Optional
 
 import cloudpickle  # type: ignore
 import psutil  # type: ignore
@@ -17,6 +18,7 @@ import psutil  # type: ignore
 import argparse_utils
 import bootstrap
 import config
+from stopwatch import Timer
 from thread_utils import background_thread
 
 
@@ -76,11 +78,24 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         time.sleep(1.0)
 
 
+def cleanup_and_exit(
+    thread: Optional[threading.Thread],
+    stop_thread: Optional[threading.Event],
+    exit_code: int,
+) -> None:
+    if stop_thread is not None:
+        stop_thread.set()
+        assert thread is not None
+        thread.join()
+    sys.exit(exit_code)
+
+
 @bootstrap.initialize
 def main() -> None:
     in_file = config.config['code_file']
     out_file = config.config['result_file']
 
+    thread = None
     stop_thread = None
     if config.config['watch_for_cancel']:
         (thread, stop_thread) = watch_for_cancel()
@@ -92,8 +107,7 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Problem reading {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 1)
 
     logger.debug(f'Deserializing {in_file}.')
     try:
@@ -101,14 +115,12 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Problem deserializing {in_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 2)
 
     logger.debug('Invoking user code...')
-    start = time.time()
-    ret = fun(*args, **kwargs)
-    end = time.time()
-    logger.debug(f'User code took {end - start:.1f}s')
+    with Timer() as t:
+        ret = fun(*args, **kwargs)
+    logger.debug(f'User code took {t():.1f}s')
 
     logger.debug('Serializing results')
     try:
@@ -116,8 +128,7 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Could not serialize result ({type(ret)}).  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
+        cleanup_and_exit(thread, stop_thread, 3)
 
     logger.debug(f'Writing {out_file}.')
     try:
@@ -126,12 +137,8 @@ def main() -> None:
     except Exception as e:
         logger.exception(e)
         logger.critical(f'Error writing {out_file}.  Aborting.')
-        stop_thread.set()
-        sys.exit(-1)
-
-    if stop_thread is not None:
-        stop_thread.set()
-        thread.join()
+        cleanup_and_exit(thread, stop_thread, 4)
+    cleanup_and_exit(thread, stop_thread, 0)
 
 
 if __name__ == '__main__':
index 516138c50fadb3fc71f29ffe8bf9e29169ddff3c..c6c154c124482c9bf9c9e20950a9cca0fb22cb1d 100644 (file)
@@ -10,7 +10,7 @@ class Timer(object):
 
     e.g.
 
-        with timer.Timer() as t:
+        with stopwatch.Timer() as t:
             do_the_thing()
 
         walltime = t()
index b9746a80307ad512cee993aca29449e365e998b6..81b339ae3485d05e6f76248aa8fbeb70f52ef2a3 100644 (file)
@@ -18,7 +18,7 @@ import random
 import statistics
 import time
 import tempfile
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 import unittest
 import warnings
 
@@ -83,7 +83,7 @@ class PerfRegressionDataPersister(ABC):
         pass
 
     @abstractmethod
-    def load_performance_data(self) -> Dict[str, List[float]]:
+    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
         pass
 
     @abstractmethod
@@ -98,7 +98,7 @@ class PerfRegressionDataPersister(ABC):
 class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
     def __init__(self, filename: str):
         self.filename = filename
-        self.traces_to_delete = []
+        self.traces_to_delete: List[str] = []
 
     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
         with open(self.filename, 'rb') as f:
@@ -128,7 +128,7 @@ class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
                 f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";'
             )
         )
-        ret = {method_id: []}
+        ret: Dict[str, List[float]] = {method_id: []}
         for result in results.all():
             ret[method_id].append(result['runtime'])
         results.close()
@@ -283,14 +283,16 @@ class RecordStdout(object):
 
     def __init__(self) -> None:
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.recorder = None
+        self.recorder: Optional[contextlib.redirect_stdout] = None
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
         self.recorder = contextlib.redirect_stdout(self.destination)
+        assert self.recorder
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
+        assert self.recorder
         self.recorder.__exit__(*args)
         self.destination.seek(0)
         return None
@@ -310,14 +312,16 @@ class RecordStderr(object):
 
     def __init__(self) -> None:
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.recorder = None
+        self.recorder: Optional[contextlib.redirect_stdout[Any]] = None
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
-        self.recorder = contextlib.redirect_stderr(self.destination)
+        self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
+        assert self.recorder
         self.recorder.__enter__()
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
+        assert self.recorder
         self.recorder.__exit__(*args)
         self.destination.seek(0)
         return None
@@ -331,7 +335,7 @@ class RecordMultipleStreams(object):
     def __init__(self, *files) -> None:
         self.files = [*files]
         self.destination = tempfile.SpooledTemporaryFile(mode='r+')
-        self.saved_writes = []
+        self.saved_writes: List[Callable[..., Any]] = []
 
     def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
         for f in self.files:
@@ -339,10 +343,11 @@ class RecordMultipleStreams(object):
             f.write = self.destination.write
         return lambda: self.destination
 
-    def __exit__(self, *args) -> bool:
+    def __exit__(self, *args) -> Optional[bool]:
         for f in self.files:
             f.write = self.saved_writes.pop()
         self.destination.seek(0)
+        return None
 
 
 if __name__ == '__main__':
index 8fcd65d4e3ef732bc529cd15f3d3de1fa13dede2..d9e4253e4cd77165fae6c4962d5957d74b619d9c 100644 (file)
@@ -198,7 +198,7 @@ class Unscrambler(object):
         unless you want to populate the same exact files.
 
         """
-        words_by_sigs = {}
+        words_by_sigs: Dict[int, str] = {}
         seen = set()
         with open(dictfile, "r") as f:
             for word in f:
index d54511ff362bc45ceaf5c95e15174523f0327be9..77b6e817198f47e23f87b2af71f353c1eec0488a 100644 (file)
@@ -98,6 +98,8 @@ class WaitablePresenceDetectorWithMemory(state_tracker.WaitableAutomaticStateTra
         if self.someone_is_home is None:
             raise Exception("Too Soon!")
         if self.someone_is_home:
+            assert self.someone_home_since
             return (True, self.someone_home_since)
         else:
+            assert self.everyone_gone_since
             return (False, self.everyone_gone_since)