Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / state_tracker.py
1 #!/usr/bin/env python3
2
3 import datetime
4 import logging
5 import threading
6 import time
7 from abc import ABC, abstractmethod
8 from typing import Dict, Optional
9
10 import pytz
11
12 from thread_utils import background_thread
13
14 logger = logging.getLogger(__name__)
15
16
17 class StateTracker(ABC):
18     """A base class that maintains and updates a global state via an
19     update routine.  Instances of this class should be periodically
20     invoked via the heartbeat() method.  This method, in turn, invokes
21     update() with update_ids according to a schedule / periodicity
22     provided to the c'tor.
23
24     """
25
26     def __init__(self, update_ids_to_update_secs: Dict[str, float]) -> None:
27         """The update_ids_to_update_secs dict parameter describes one or more
28         update types (unique update_ids) and the periodicity(ies), in
29         seconds, at which it/they should be invoked.
30
31         Note that, when more than one update is overdue, they will be
32         invoked in order by their update_ids so care in choosing these
33         identifiers may be in order.
34
35         """
36         self.update_ids_to_update_secs = update_ids_to_update_secs
37         self.last_reminder_ts: Dict[str, Optional[datetime.datetime]] = {}
38         for x in update_ids_to_update_secs.keys():
39             self.last_reminder_ts[x] = None
40
41     @abstractmethod
42     def update(
43         self,
44         update_id: str,
45         now: datetime.datetime,
46         last_invocation: Optional[datetime.datetime],
47     ) -> None:
48         """Put whatever you want here.  The update_id will be the string
49         passed to the c'tor as a key in the Dict.  It will only be
50         tapped on the shoulder, at most, every update_secs seconds.
51         The now param is the approximate current timestamp and the
52         last_invocation param is the last time you were invoked (or
53         None on the first invocation)
54
55         """
56         pass
57
58     def heartbeat(self, *, force_all_updates_to_run: bool = False) -> None:
59         """Invoke this method to cause the StateTracker instance to identify
60         and invoke any overdue updates based on the schedule passed to
61         the c'tor.  In the base StateTracker class, this method must
62         be invoked manually with a thread from external code.
63
64         If more than one type of update (update_id) are overdue,
65         they will be invoked in order based on their update_ids.
66
67         Setting force_all_updates_to_run will invoke all updates
68         (ordered by update_id) immediately ignoring whether or not
69         they are due.
70
71         """
72         self.now = datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
73         for update_id in sorted(self.last_reminder_ts.keys()):
74             if force_all_updates_to_run:
75                 logger.debug('Forcing all updates to run')
76                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
77                 self.last_reminder_ts[update_id] = self.now
78                 return
79
80             refresh_secs = self.update_ids_to_update_secs[update_id]
81             last_run = self.last_reminder_ts[update_id]
82             if last_run is None:  # Never run before
83                 logger.debug(f'id {update_id} has never been run; running it now')
84                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
85                 self.last_reminder_ts[update_id] = self.now
86             else:
87                 delta = self.now - last_run
88                 if delta.total_seconds() >= refresh_secs:  # Is overdue?
89                     logger.debug(f'id {update_id} is overdue; running it now')
90                     self.update(
91                         update_id,
92                         self.now,
93                         self.last_reminder_ts[update_id],
94                     )
95                     self.last_reminder_ts[update_id] = self.now
96
97
98 class AutomaticStateTracker(StateTracker):
99     """Just like HeartbeatCurrentState but you don't need to pump the
100     heartbeat; it runs on a background thread.  Call .shutdown() to
101     terminate the updates.
102
103     """
104
105     @background_thread
106     def pace_maker(self, should_terminate) -> None:
107         """Entry point for a background thread to own calling heartbeat()
108         at regular intervals so that the main thread doesn't need to do
109         so.
110
111         """
112         while True:
113             if should_terminate.is_set():
114                 logger.debug('pace_maker noticed event; shutting down')
115                 return
116             self.heartbeat()
117             logger.debug(f'pace_maker is sleeping for {self.sleep_delay}s')
118             time.sleep(self.sleep_delay)
119
120     def __init__(
121         self,
122         update_ids_to_update_secs: Dict[str, float],
123         *,
124         override_sleep_delay: Optional[float] = None,
125     ) -> None:
126         import math_utils
127
128         super().__init__(update_ids_to_update_secs)
129         if override_sleep_delay is not None:
130             logger.debug(f'Overriding sleep delay to {override_sleep_delay}')
131             self.sleep_delay = override_sleep_delay
132         else:
133             periods_list = list(update_ids_to_update_secs.values())
134             self.sleep_delay = math_utils.gcd_float_sequence(periods_list)
135             logger.info(f'Computed sleep_delay={self.sleep_delay}')
136         (thread, stop_event) = self.pace_maker()
137         self.should_terminate = stop_event
138         self.updater_thread = thread
139
140     def shutdown(self):
141         """Terminates the background thread and waits for it to tear down.
142         This may block for as long as self.sleep_delay.
143
144         """
145         logger.debug('Setting shutdown event and waiting for background thread.')
146         self.should_terminate.set()
147         self.updater_thread.join()
148         logger.debug('Background thread terminated.')
149
150
151 class WaitableAutomaticStateTracker(AutomaticStateTracker):
152     """This is an AutomaticStateTracker that exposes a wait method which
153     will block the calling thread until the state changes with an
154     optional timeout.  The caller should check the return value of
155     wait; it will be true if something changed and false if the wait
156     simply timed out.  If the return value is true, the instance
157     should be reset() before wait is called again.
158
159     Example usage:
160
161         detector = waitable_presence.WaitableAutomaticStateSubclass()
162         while True:
163             changed = detector.wait(timeout=60 * 5)
164             if changed:
165                 detector.reset()
166                 # Figure out what changed and react
167             else:
168                 # Just a timeout; no need to reset.  Maybe do something
169                 # else before looping up into wait again.
170
171     """
172
173     def __init__(
174         self,
175         update_ids_to_update_secs: Dict[str, float],
176         *,
177         override_sleep_delay: Optional[float] = None,
178     ) -> None:
179         self._something_changed = threading.Event()
180         super().__init__(
181             update_ids_to_update_secs, override_sleep_delay=override_sleep_delay
182         )
183
184     def something_changed(self):
185         self._something_changed.set()
186
187     def did_something_change(self) -> bool:
188         return self._something_changed.is_set()
189
190     def reset(self):
191         self._something_changed.clear()
192
193     def wait(self, *, timeout=None):
194         return self._something_changed.wait(timeout=timeout)