Oops, update type hints in geocode.
[python_utils.git] / state_tracker.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Several helpers to keep track of internal state via periodic
6 polling.  StateTracker expects to be invoked periodically to maintain
7 state whereas the others automatically update themselves and,
8 optionally, expose an event for client code to wait on state changes.
9
10 """
11
12 import datetime
13 import logging
14 import threading
15 import time
16 from abc import ABC, abstractmethod
17 from typing import Dict, Optional
18
19 import pytz
20
21 from thread_utils import background_thread
22
23 logger = logging.getLogger(__name__)
24
25
26 class StateTracker(ABC):
27     """A base class that maintains and updates a global state via an
28     update routine.  Instances of this class should be periodically
29     invoked via the heartbeat() method.  This method, in turn, invokes
30     update() with update_ids according to a schedule / periodicity
31     provided to the c'tor.
32
33     """
34
35     def __init__(self, update_ids_to_update_secs: Dict[str, float]) -> None:
36         """The update_ids_to_update_secs dict parameter describes one or more
37         update types (unique update_ids) and the periodicity(ies), in
38         seconds, at which it/they should be invoked.
39
40         Note that, when more than one update is overdue, they will be
41         invoked in order by their update_ids so care in choosing these
42         identifiers may be in order.
43
44         """
45         self.update_ids_to_update_secs = update_ids_to_update_secs
46         self.last_reminder_ts: Dict[str, Optional[datetime.datetime]] = {}
47         self.now: Optional[datetime.datetime] = None
48         for x in update_ids_to_update_secs.keys():
49             self.last_reminder_ts[x] = None
50
51     @abstractmethod
52     def update(
53         self,
54         update_id: str,
55         now: datetime.datetime,
56         last_invocation: Optional[datetime.datetime],
57     ) -> None:
58         """Put whatever you want here.  The update_id will be the string
59         passed to the c'tor as a key in the Dict.  It will only be
60         tapped on the shoulder, at most, every update_secs seconds.
61         The now param is the approximate current timestamp and the
62         last_invocation param is the last time you were invoked (or
63         None on the first invocation)
64
65         """
66         pass
67
68     def heartbeat(self, *, force_all_updates_to_run: bool = False) -> None:
69         """Invoke this method to cause the StateTracker instance to identify
70         and invoke any overdue updates based on the schedule passed to
71         the c'tor.  In the base StateTracker class, this method must
72         be invoked manually with a thread from external code.
73
74         If more than one type of update (update_id) are overdue,
75         they will be invoked in order based on their update_ids.
76
77         Setting force_all_updates_to_run will invoke all updates
78         (ordered by update_id) immediately ignoring whether or not
79         they are due.
80
81         """
82         self.now = datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
83         for update_id in sorted(self.last_reminder_ts.keys()):
84             if force_all_updates_to_run:
85                 logger.debug('Forcing all updates to run')
86                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
87                 self.last_reminder_ts[update_id] = self.now
88                 return
89
90             refresh_secs = self.update_ids_to_update_secs[update_id]
91             last_run = self.last_reminder_ts[update_id]
92             if last_run is None:  # Never run before
93                 logger.debug('id %s has never been run; running it now', update_id)
94                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
95                 self.last_reminder_ts[update_id] = self.now
96             else:
97                 delta = self.now - last_run
98                 if delta.total_seconds() >= refresh_secs:  # Is overdue?
99                     logger.debug('id %s is overdue; running it now', update_id)
100                     self.update(
101                         update_id,
102                         self.now,
103                         self.last_reminder_ts[update_id],
104                     )
105                     self.last_reminder_ts[update_id] = self.now
106
107
108 class AutomaticStateTracker(StateTracker):
109     """Just like HeartbeatCurrentState but you don't need to pump the
110     heartbeat; it runs on a background thread.  Call .shutdown() to
111     terminate the updates.
112
113     """
114
115     @background_thread
116     def pace_maker(self, should_terminate: threading.Event) -> None:
117         """Entry point for a background thread to own calling heartbeat()
118         at regular intervals so that the main thread doesn't need to do
119         so.
120
121         """
122         while True:
123             if should_terminate.is_set():
124                 logger.debug('pace_maker noticed event; shutting down')
125                 return
126             self.heartbeat()
127             logger.debug('pace_maker is sleeping for %.1fs', self.sleep_delay)
128             time.sleep(self.sleep_delay)
129
130     def __init__(
131         self,
132         update_ids_to_update_secs: Dict[str, float],
133         *,
134         override_sleep_delay: Optional[float] = None,
135     ) -> None:
136         import math_utils
137
138         super().__init__(update_ids_to_update_secs)
139         if override_sleep_delay is not None:
140             logger.debug('Overriding sleep delay to %.1f', override_sleep_delay)
141             self.sleep_delay = override_sleep_delay
142         else:
143             periods_list = list(update_ids_to_update_secs.values())
144             self.sleep_delay = math_utils.gcd_float_sequence(periods_list)
145             logger.info('Computed sleep_delay=%.1f', self.sleep_delay)
146         (thread, stop_event) = self.pace_maker()
147         self.should_terminate = stop_event
148         self.updater_thread = thread
149
150     def shutdown(self):
151         """Terminates the background thread and waits for it to tear down.
152         This may block for as long as self.sleep_delay.
153
154         """
155         logger.debug('Setting shutdown event and waiting for background thread.')
156         self.should_terminate.set()
157         self.updater_thread.join()
158         logger.debug('Background thread terminated.')
159
160
161 class WaitableAutomaticStateTracker(AutomaticStateTracker):
162     """This is an AutomaticStateTracker that exposes a wait method which
163     will block the calling thread until the state changes with an
164     optional timeout.  The caller should check the return value of
165     wait; it will be true if something changed and false if the wait
166     simply timed out.  If the return value is true, the instance
167     should be reset() before wait is called again.
168
169     Example usage:
170
171         detector = waitable_presence.WaitableAutomaticStateSubclass()
172         while True:
173             changed = detector.wait(timeout=60 * 5)
174             if changed:
175                 detector.reset()
176                 # Figure out what changed and react
177             else:
178                 # Just a timeout; no need to reset.  Maybe do something
179                 # else before looping up into wait again.
180
181     """
182
183     def __init__(
184         self,
185         update_ids_to_update_secs: Dict[str, float],
186         *,
187         override_sleep_delay: Optional[float] = None,
188     ) -> None:
189         self._something_changed = threading.Event()
190         super().__init__(update_ids_to_update_secs, override_sleep_delay=override_sleep_delay)
191
192     def something_changed(self):
193         self._something_changed.set()
194
195     def did_something_change(self) -> bool:
196         return self._something_changed.is_set()
197
198     def reset(self):
199         self._something_changed.clear()
200
201     def wait(self, *, timeout=None):
202         return self._something_changed.wait(timeout=timeout)