More cleanup, yey!
[python_utils.git] / state_tracker.py
1 #!/usr/bin/env python3
2
3 """Several helpers to keep track of internal state via periodic
4 polling.  StateTracker expects to be invoked periodically to maintain
5 state whereas the others automatically update themselves and,
6 optionally, expose an event for client code to wait on state changes."""
7
8 import datetime
9 import logging
10 import threading
11 import time
12 from abc import ABC, abstractmethod
13 from typing import Dict, Optional
14
15 import pytz
16
17 from thread_utils import background_thread
18
19 logger = logging.getLogger(__name__)
20
21
22 class StateTracker(ABC):
23     """A base class that maintains and updates a global state via an
24     update routine.  Instances of this class should be periodically
25     invoked via the heartbeat() method.  This method, in turn, invokes
26     update() with update_ids according to a schedule / periodicity
27     provided to the c'tor.
28
29     """
30
31     def __init__(self, update_ids_to_update_secs: Dict[str, float]) -> None:
32         """The update_ids_to_update_secs dict parameter describes one or more
33         update types (unique update_ids) and the periodicity(ies), in
34         seconds, at which it/they should be invoked.
35
36         Note that, when more than one update is overdue, they will be
37         invoked in order by their update_ids so care in choosing these
38         identifiers may be in order.
39
40         """
41         self.update_ids_to_update_secs = update_ids_to_update_secs
42         self.last_reminder_ts: Dict[str, Optional[datetime.datetime]] = {}
43         self.now: Optional[datetime.datetime] = None
44         for x in update_ids_to_update_secs.keys():
45             self.last_reminder_ts[x] = None
46
47     @abstractmethod
48     def update(
49         self,
50         update_id: str,
51         now: datetime.datetime,
52         last_invocation: Optional[datetime.datetime],
53     ) -> None:
54         """Put whatever you want here.  The update_id will be the string
55         passed to the c'tor as a key in the Dict.  It will only be
56         tapped on the shoulder, at most, every update_secs seconds.
57         The now param is the approximate current timestamp and the
58         last_invocation param is the last time you were invoked (or
59         None on the first invocation)
60
61         """
62         pass
63
64     def heartbeat(self, *, force_all_updates_to_run: bool = False) -> None:
65         """Invoke this method to cause the StateTracker instance to identify
66         and invoke any overdue updates based on the schedule passed to
67         the c'tor.  In the base StateTracker class, this method must
68         be invoked manually with a thread from external code.
69
70         If more than one type of update (update_id) are overdue,
71         they will be invoked in order based on their update_ids.
72
73         Setting force_all_updates_to_run will invoke all updates
74         (ordered by update_id) immediately ignoring whether or not
75         they are due.
76
77         """
78         self.now = datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
79         for update_id in sorted(self.last_reminder_ts.keys()):
80             if force_all_updates_to_run:
81                 logger.debug('Forcing all updates to run')
82                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
83                 self.last_reminder_ts[update_id] = self.now
84                 return
85
86             refresh_secs = self.update_ids_to_update_secs[update_id]
87             last_run = self.last_reminder_ts[update_id]
88             if last_run is None:  # Never run before
89                 logger.debug('id %s has never been run; running it now', update_id)
90                 self.update(update_id, self.now, self.last_reminder_ts[update_id])
91                 self.last_reminder_ts[update_id] = self.now
92             else:
93                 delta = self.now - last_run
94                 if delta.total_seconds() >= refresh_secs:  # Is overdue?
95                     logger.debug('id %s is overdue; running it now', update_id)
96                     self.update(
97                         update_id,
98                         self.now,
99                         self.last_reminder_ts[update_id],
100                     )
101                     self.last_reminder_ts[update_id] = self.now
102
103
104 class AutomaticStateTracker(StateTracker):
105     """Just like HeartbeatCurrentState but you don't need to pump the
106     heartbeat; it runs on a background thread.  Call .shutdown() to
107     terminate the updates.
108
109     """
110
111     @background_thread
112     def pace_maker(self, should_terminate) -> None:
113         """Entry point for a background thread to own calling heartbeat()
114         at regular intervals so that the main thread doesn't need to do
115         so.
116
117         """
118         while True:
119             if should_terminate.is_set():
120                 logger.debug('pace_maker noticed event; shutting down')
121                 return
122             self.heartbeat()
123             logger.debug('pace_maker is sleeping for %.1fs', self.sleep_delay)
124             time.sleep(self.sleep_delay)
125
126     def __init__(
127         self,
128         update_ids_to_update_secs: Dict[str, float],
129         *,
130         override_sleep_delay: Optional[float] = None,
131     ) -> None:
132         import math_utils
133
134         super().__init__(update_ids_to_update_secs)
135         if override_sleep_delay is not None:
136             logger.debug('Overriding sleep delay to %.1f', override_sleep_delay)
137             self.sleep_delay = override_sleep_delay
138         else:
139             periods_list = list(update_ids_to_update_secs.values())
140             self.sleep_delay = math_utils.gcd_float_sequence(periods_list)
141             logger.info('Computed sleep_delay=%.1f', self.sleep_delay)
142         (thread, stop_event) = self.pace_maker()
143         self.should_terminate = stop_event
144         self.updater_thread = thread
145
146     def shutdown(self):
147         """Terminates the background thread and waits for it to tear down.
148         This may block for as long as self.sleep_delay.
149
150         """
151         logger.debug('Setting shutdown event and waiting for background thread.')
152         self.should_terminate.set()
153         self.updater_thread.join()
154         logger.debug('Background thread terminated.')
155
156
157 class WaitableAutomaticStateTracker(AutomaticStateTracker):
158     """This is an AutomaticStateTracker that exposes a wait method which
159     will block the calling thread until the state changes with an
160     optional timeout.  The caller should check the return value of
161     wait; it will be true if something changed and false if the wait
162     simply timed out.  If the return value is true, the instance
163     should be reset() before wait is called again.
164
165     Example usage:
166
167         detector = waitable_presence.WaitableAutomaticStateSubclass()
168         while True:
169             changed = detector.wait(timeout=60 * 5)
170             if changed:
171                 detector.reset()
172                 # Figure out what changed and react
173             else:
174                 # Just a timeout; no need to reset.  Maybe do something
175                 # else before looping up into wait again.
176
177     """
178
179     def __init__(
180         self,
181         update_ids_to_update_secs: Dict[str, float],
182         *,
183         override_sleep_delay: Optional[float] = None,
184     ) -> None:
185         self._something_changed = threading.Event()
186         super().__init__(update_ids_to_update_secs, override_sleep_delay=override_sleep_delay)
187
188     def something_changed(self):
189         self._something_changed.set()
190
191     def did_something_change(self) -> bool:
192         return self._something_changed.is_set()
193
194     def reset(self):
195         self._something_changed.clear()
196
197     def wait(self, *, timeout=None):
198         return self._something_changed.wait(timeout=timeout)