More type annotations.
authorScott <[email protected]>
Wed, 2 Feb 2022 17:14:12 +0000 (09:14 -0800)
committerScott <[email protected]>
Wed, 2 Feb 2022 17:14:12 +0000 (09:14 -0800)
arper.py
base_presence.py
bootstrap.py
decorator_utils.py
dict_utils.py
executors.py
file_utils.py
histogram.py
persistent.py
smart_future.py
thread_utils.py

index 7700d5a994286f263dd14c35cf44a6e8a3d6c56b..ca5d1d5fa59c6ad75bb811f0664c90e798a4745d 100644 (file)
--- a/arper.py
+++ b/arper.py
@@ -47,7 +47,7 @@ cfg.add_argument(
 )
 
 
[email protected]_autoloaded_singleton()
[email protected]_autoloaded_singleton()  # type: ignore
 class Arper(persistent.Persistent):
     def __init__(self, cached_state: Optional[BiDict] = None) -> None:
         self.state = BiDict()
index 612193e1df3b84f0ee8fc52b86ab5c94ce7ee38d..3ceddb32dff9ab1bf2a17d2b3aa5547125ded2d9 100755 (executable)
@@ -4,7 +4,7 @@ import datetime
 from collections import defaultdict
 import logging
 import re
-from typing import Dict, List, Set
+from typing import Dict, List, Optional, Set
 import warnings
 
 # Note: this module is fairly early loaded.  Be aware of dependencies.
@@ -75,7 +75,7 @@ class PresenceDetection(object):
         ] = defaultdict(dict)
         self.names_by_mac: Dict[str, str] = {}
         self.dark_locations: Set[Location] = set()
-        self.last_update = None
+        self.last_update: Optional[datetime.datetime] = None
 
     def maybe_update(self) -> None:
         if self.last_update is None:
index e50cb386543aaef66fdd36cb9c149fbf009ef5ae..98da78cf6c1755c5fc3f5e42379e8d900abb10b7 100644 (file)
@@ -3,7 +3,9 @@
 import functools
 import logging
 import os
+import importlib
 from inspect import stack
+from typing import List
 import sys
 
 # This module is commonly used by others in here and should avoid
@@ -99,7 +101,7 @@ def handle_uncaught_exception(exc_type, exc_value, exc_tb):
                 original_hook(exc_type, exc_value, exc_tb)
 
 
-class ImportInterceptor(object):
+class ImportInterceptor(importlib.abc.MetaPathFinder):
     def __init__(self):
         import collect.trie
 
@@ -120,6 +122,11 @@ class ImportInterceptor(object):
     def should_ignore_filename(self, filename: str) -> bool:
         return 'importlib' in filename or 'six.py' in filename
 
+    def find_module(self, fullname, path):
+        raise Exception(
+            "This method has been deprecated since Python 3.4, please upgrade."
+        )
+
     def find_spec(self, loaded_module, path=None, target=None):
         s = stack()
         for x in range(3, len(s)):
@@ -147,6 +154,9 @@ class ImportInterceptor(object):
         logger.debug(msg)
         print(msg)
 
+    def invalidate_caches(self):
+        pass
+
     def find_importer(self, module: str):
         if module in self.tree_node_by_module:
             node = self.tree_node_by_module[module]
@@ -166,7 +176,7 @@ import_interceptor = None
 for arg in sys.argv:
     if arg == '--audit_import_events':
         import_interceptor = ImportInterceptor()
-        sys.meta_path = [import_interceptor] + sys.meta_path
+        sys.meta_path.insert(0, import_interceptor)
 
 
 def dump_all_objects() -> None:
index a956a214e2b9d5c2dad7bfb60ee95d528a49ccac..cd69639448425ce3a47073c7e423ea98d6704b2b 100644 (file)
@@ -223,7 +223,7 @@ def debug_count_calls(func: Callable) -> Callable:
         logger.info(msg)
         return func(*args, **kwargs)
 
-    wrapper_debug_count_calls.num_calls = 0
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
     return wrapper_debug_count_calls
 
 
@@ -366,7 +366,7 @@ def memoized(func: Callable) -> Callable:
             logger.debug(f"Returning memoized value for {func.__name__}")
         return wrapper_memoized.cache[cache_key]
 
-    wrapper_memoized.cache = dict()
+    wrapper_memoized.cache = dict()  # type: ignore
     return wrapper_memoized
 
 
index b1464c6bb9967ce4efa48318a288babd0ff322e9..451a87dadf08d8632ac6f593dfb592116a05779b 100644 (file)
@@ -70,7 +70,7 @@ def raise_on_duplicated_keys(key, new_value, old_value):
 def coalesce(
     inputs: Iterator[Dict[Any, Any]],
     *,
-    aggregation_function: Callable[[Any, Any], Any] = coalesce_by_creating_list,
+    aggregation_function: Callable[[Any, Any, Any], Any] = coalesce_by_creating_list,
 ) -> Dict[Any, Any]:
     """Merge N dicts into one dict containing the union of all keys /
     values in the input dicts.  When keys collide, apply the
@@ -223,7 +223,7 @@ def dict_to_key_value_lists(d: Dict[Any, Any]) -> Tuple[List[Any], List[Any]]:
     ['scott', '555-1212', '123 main st.', '12345']
 
     """
-    r = ([], [])
+    r: Tuple[List[Any], List[Any]] = ([], [])
     for (k, v) in d.items():
         r[0].append(k)
         r[1].append(v)
index e95ed716043b4962cd939b6d25885fd87826466a..990df03f19af253773c72e23eac201a9163e2931 100644 (file)
@@ -248,7 +248,7 @@ class BundleDetails:
     end_ts: float
     slower_than_local_p95: bool
     slower_than_global_p95: bool
-    src_bundle: BundleDetails
+    src_bundle: Optional[BundleDetails]
     is_cancelled: threading.Event
     was_cancelled: bool
     backup_bundles: Optional[List[BundleDetails]]
@@ -288,7 +288,7 @@ class RemoteExecutorStatus:
         self.worker_count: int = total_worker_count
         self.known_workers: Set[RemoteWorkerRecord] = set()
         self.start_time: float = time.time()
-        self.start_per_bundle: Dict[str, float] = defaultdict(float)
+        self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float)
         self.end_per_bundle: Dict[str, float] = defaultdict(float)
         self.finished_bundle_timings_per_worker: Dict[
             RemoteWorkerRecord, List[float]
@@ -345,7 +345,9 @@ class RemoteExecutorStatus:
         self.end_per_bundle[uuid] = ts
         self.in_flight_bundles_by_worker[worker].remove(uuid)
         if not was_cancelled:
-            bundle_latency = ts - self.start_per_bundle[uuid]
+            start = self.start_per_bundle[uuid]
+            assert start
+            bundle_latency = ts - start
             x = self.finished_bundle_timings_per_worker.get(worker, list())
             x.append(bundle_latency)
             self.finished_bundle_timings_per_worker[worker] = x
@@ -836,9 +838,10 @@ class RemoteExecutor(BaseExecutor):
         return self.wait_for_process(p, bundle, 0)
 
     def wait_for_process(
-        self, p: subprocess.Popen, bundle: BundleDetails, depth: int
+        self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
     ) -> Any:
         machine = bundle.machine
+        assert p
         pid = p.pid
         if depth > 3:
             logger.error(
@@ -981,10 +984,12 @@ class RemoteExecutor(BaseExecutor):
 
             # Tell the original to stop if we finished first.
             if not was_cancelled:
+                orig_bundle = bundle.src_bundle
+                assert orig_bundle
                 logger.debug(
-                    f'{bundle}: Notifying original {bundle.src_bundle.uuid} we beat them to it.'
+                    f'{bundle}: Notifying original {orig_bundle.uuid} we beat them to it.'
                 )
-                bundle.src_bundle.is_cancelled.set()
+                orig_bundle.is_cancelled.set()
         self.release_worker(bundle, was_cancelled=was_cancelled)
         return result
 
@@ -1068,7 +1073,9 @@ class RemoteExecutor(BaseExecutor):
         # they will move the result_file to this machine and let
         # the original pick them up and unpickle them.
 
-    def emergency_retry_nasty_bundle(self, bundle: BundleDetails) -> fut.Future:
+    def emergency_retry_nasty_bundle(
+        self, bundle: BundleDetails
+    ) -> Optional[fut.Future]:
         is_original = bundle.src_bundle is None
         bundle.worker = None
         avoid_last_machine = bundle.machine
index cd37f3069c70efd5c0f835e3362adbdf18d52e24..f273ea4f9a5fe08da8c1f15e3108c65b7c33d0ea 100644 (file)
@@ -14,7 +14,7 @@ import time
 from typing import Optional
 import glob
 from os.path import isfile, join, exists
-from typing import List
+from typing import List, TextIO
 from uuid import uuid4
 
 
@@ -332,11 +332,13 @@ def get_file_md5(filename: str) -> str:
 
 def set_file_raw_atime(filename: str, atime: float):
     mtime = get_file_raw_mtime(filename)
+    assert mtime
     os.utime(filename, (atime, mtime))
 
 
 def set_file_raw_mtime(filename: str, mtime: float):
     atime = get_file_raw_atime(filename)
+    assert atime
     os.utime(filename, (atime, mtime))
 
 
@@ -434,8 +436,8 @@ def describe_file_mtime(filename: str, *, brief=False) -> Optional[str]:
     return describe_file_timestamp(filename, lambda x: x.st_mtime, brief=brief)
 
 
-def touch_file(filename: str, *, mode: Optional[int] = 0o666) -> bool:
-    return pathlib.Path(filename, mode=mode).touch()
+def touch_file(filename: str, *, mode: Optional[int] = 0o666):
+    pathlib.Path(filename, mode=mode).touch()
 
 
 def expand_globs(in_filename: str):
@@ -470,14 +472,14 @@ class FileWriter(object):
         self.filename = filename
         uuid = uuid4()
         self.tempfile = f'{filename}-{uuid}.tmp'
-        self.handle = None
+        self.handle: Optional[TextIO] = None
 
-    def __enter__(self) -> io.TextIOWrapper:
+    def __enter__(self) -> TextIO:
         assert not does_path_exist(self.tempfile)
         self.handle = open(self.tempfile, mode="w")
         return self.handle
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+    def __exit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]:
         if self.handle is not None:
             self.handle.close()
             cmd = f'/bin/mv -f {self.tempfile} {self.filename}'
index 993b5036d737cdcd705ef411e5cbcf70ba0da911..d45e93f328185f869c2b3af4cc832db280d14428 100644 (file)
@@ -85,8 +85,8 @@ class SimpleHistogram(Generic[T]):
             return txt
 
         max_label_width: Optional[int] = None
-        lowest_start: int = None
-        highest_end: int = None
+        lowest_start: Optional[int] = None
+        highest_end: Optional[int] = None
         for bucket in sorted(self.buckets, key=lambda x: x[0]):
             start = bucket[0]
             if lowest_start is None:
index d62dd6754eeffc78c1c09adea7c82e778f8450b3..7136559492ed0615366db854412c02c9024c7022 100644 (file)
@@ -64,6 +64,7 @@ def was_file_written_today(filename: str) -> bool:
         return False
 
     mtime = file_utils.get_file_mtime_as_datetime(filename)
+    assert mtime
     now = datetime.datetime.now()
     return mtime.month == now.month and mtime.day == now.day and mtime.year == now.year
 
@@ -80,6 +81,7 @@ def was_file_written_within_n_seconds(
         return False
 
     mtime = file_utils.get_file_mtime_as_datetime(filename)
+    assert mtime
     now = datetime.datetime.now()
     return (now - mtime).total_seconds() <= limit_seconds
 
@@ -126,7 +128,6 @@ class persistent_autoloaded_singleton(object):
         self.instance = None
 
     def __call__(self, cls: Persistent):
-        @functools.wraps(cls)
         def _load(*args, **kwargs):
 
             # If class has already been loaded, act like a singleton
index 604c149520464bcd9d8c5a55cf8905acd5ec34d4..1f6e6f0aedcf05966e536ec8f10f570c2175a3e4 100644 (file)
@@ -5,7 +5,7 @@ import concurrent
 import concurrent.futures as fut
 import logging
 import traceback
-from typing import Callable, List, TypeVar
+from typing import Callable, List, Set, TypeVar
 
 from overrides import overrides
 
@@ -27,11 +27,11 @@ def wait_any(
 ):
     real_futures = []
     smart_future_by_real_future = {}
-    completed_futures = set()
-    for f in futures:
-        assert type(f) == SmartFuture
-        real_futures.append(f.wrapped_future)
-        smart_future_by_real_future[f.wrapped_future] = f
+    completed_futures: Set[fut.Future] = set()
+    for x in futures:
+        assert type(x) == SmartFuture
+        real_futures.append(x.wrapped_future)
+        smart_future_by_real_future[x.wrapped_future] = x
 
     while len(completed_futures) != len(real_futures):
         newly_completed_futures = concurrent.futures.as_completed(real_futures)
@@ -59,9 +59,9 @@ def wait_all(
     log_exceptions: bool = True,
 ) -> None:
     real_futures = []
-    for f in futures:
-        assert type(f) == SmartFuture
-        real_futures.append(f.wrapped_future)
+    for x in futures:
+        assert type(x) == SmartFuture
+        real_futures.append(x.wrapped_future)
 
     (done, not_done) = concurrent.futures.wait(
         real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED
index 6035b09ef2df448a7dba5ad46c90af2e594c5dbf..22161275605d76a1199df8f18d536fd04e2fe17b 100644 (file)
@@ -61,7 +61,7 @@ def is_current_thread_main_thread() -> bool:
 
 def background_thread(
     _funct: Optional[Callable],
-) -> Tuple[threading.Thread, threading.Event]:
+) -> Callable[..., Tuple[threading.Thread, threading.Event]]:
     """A function decorator to create a background thread.
 
     *** Please note: the decorated function must take an shutdown ***