3 # © Copyright 2021-2023, Scott Gasch
5 """Helpers for unittests.
9 When you import this we automatically wrap the standard Python
10 `unittest.main` with a call to :meth:`pyutils.bootstrap.initialize`
11 so that we get logger config, commandline args, logging control,
12 etc.  This works fine but may be unexpected behavior.
27 from abc import ABC, abstractmethod
28 from typing import Any, Callable, Dict, List, Literal, Optional
30 from pyutils import bootstrap, config, function_utils
# Module-level logger for this helper module.
32 logger = logging.getLogger(__name__)
# Register this module's commandline flags; the decorators below read them
# back via config.config["unittests_..."].
# NOTE(review): the group title says "Logging" while the description says
# "function decorators"; both look copy-pasted -- confirm intended wording.
33 cfg = config.add_commandline_args(
34 f"Logging ({__file__})", "Args related to function decorators"
37 "--unittests_ignore_perf",
40 help="Ignore unittest perf regression in @check_method_for_perf_regressions",
43 "--unittests_num_perf_samples",
46 help="The count of perf timing samples we need to see before blocking slow runs on perf grounds",
# NOTE(review): this flag's type/nargs are not visible here; if its value
# arrives wrapped in a list, consumers must unwrap it before using it as a
# perf-db key -- verify.
49 "--unittests_drop_perf_traces",
53 help="The identifier (i.e. file!test_fixture) for which we should drop all perf data",
56 "--unittests_persistance_strategy",
57 choices=["FILE", "DATABASE"],
59 help="Should we persist perf data in a file or db?",
62 "--unittests_perfdb_filename",
# NOTE(review): os.environ["HOME"] raises KeyError at import time when HOME
# is unset; os.path.expanduser("~") may be more robust -- confirm.
65 default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
66 help="File in which to store perf data (iff --unittests_persistance_strategy is FILE)",
69 "--unittests_perfdb_spec",
72 default="mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance",
# NOTE(review): help text names --unittest_persistance_strategy (missing
# the "s"); the real flag is --unittests_persistance_strategy.
73 help="Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)",
# Wrap unittest.main so bootstrap.initialize runs first (config, commandline
# args, logging setup) whenever a test module calls unittest.main().
76 unittest.main = bootstrap.initialize(unittest.main)
class PerfRegressionDataPersister(ABC):
    """Abstract interface for storing the historical perf data that
    unit tests are compared against.

    Perf data is a mapping from a method identifier to the list of
    runtimes recorded for that method.
    """

    @abstractmethod
    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
        """Fetch the historical perf data for a method.

        Args:
            method_id: identifier of the method whose history is wanted.

        Returns:
            A mapping from method identifier to recorded runtimes.
        """
        ...

    @abstractmethod
    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
        """Store the historical perf data of a method.

        Args:
            method_id: identifier of the method whose history is stored.
            data: mapping from method identifier to recorded runtimes.
        """
        ...

    @abstractmethod
    def delete_performance_data(self, method_id: str):
        """Forget the historical perf data of a method.

        Args:
            method_id: identifier of the method whose history is erased.
        """
        ...
class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
    """A perf regression data persister that uses files.

    Every method's perf history lives together in one pickle file holding a
    ``Dict[str, List[float]]`` (method identifier -> recorded runtimes), so
    loads and saves always read/write the whole mapping.
    """

    def __init__(self, filename: str):
        """
        Args:
            filename: the filename to save/load historical performance data
        """
        super().__init__()
        self.filename = filename
        # Deletions are deferred until the next save because the file holds
        # every method's history at once.
        self.traces_to_delete: List[str] = []

    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
        """Load the historical perf data stored in the file.

        Args:
            method_id: unused; the whole mapping (every method's history)
                is returned.

        Returns:
            The stored mapping, or an empty dict when the file does not
            exist yet (e.g. on the very first run).

        Raises:
            Errors from opening or unpickling an existing file propagate
            to the caller.
        """
        # NOTE: pickle.load can execute arbitrary code; only point
        # --unittests_perfdb_filename at a file you trust.
        try:
            with open(self.filename, "rb") as f:
                return pickle.load(f)
        except FileNotFoundError:
            # Nothing recorded yet; that's not an error.
            return {}

    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
        """Write the full perf mapping to the file, first dropping any traces
        queued via :meth:`delete_performance_data`.

        Args:
            method_id: unused; the whole mapping is written.
            data: method identifier -> runtimes.  Modified in place: queued
                traces are removed before writing.
        """
        for trace in self.traces_to_delete:
            data.pop(trace, None)

        with open(self.filename, "wb") as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

    def delete_performance_data(self, method_id: str):
        """Queue a method's history for removal on the next save.

        Args:
            method_id: the method whose data should be erased.  A list or
                tuple of ids is also accepted, since a commandline flag value
                may arrive wrapped in a list -- and a list cannot be used as a
                key into the perf data mapping.
        """
        if isinstance(method_id, (list, tuple)):
            self.traces_to_delete.extend(method_id)
        else:
            self.traces_to_delete.append(method_id)
141 # class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
142 # """A perf regression data persister that uses a database backend."""
144 # def __init__(self, dbspec: str):
146 # self.dbspec = dbspec
147 # self.engine = sa.create_engine(self.dbspec)
148 # self.conn = self.engine.connect()
150 # def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
151 # results = self.conn.execute(
152 # sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
154 # ret: Dict[str, List[float]] = {method_id: []}
155 # for result in results.all():
156 # ret[method_id].append(result['runtime'])
160 # def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
161 # self.delete_performance_data(method_id)
162 # for (mid, perf_data) in data.items():
163 # sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
164 # for perf in perf_data:
165 # self.conn.execute(sql + f'("{mid}", {perf});')
167 # def delete_performance_data(self, method_id: str):
168 # sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
169 # self.conn.execute(sql)
def _make_perf_data_persister() -> "PerfRegressionDataPersister":
    """Build the persister selected by --unittests_persistance_strategy."""
    strategy = config.config["unittests_persistance_strategy"]
    if strategy == "FILE":
        filename = config.config["unittests_perfdb_filename"]
        return FileBasedPerfRegressionDataPersister(filename)
    if strategy == "DATABASE":
        raise NotImplementedError(
            "Persisting to a database is not implemented in this version"
        )
    # ValueError is an Exception subclass, so existing handlers still match.
    raise ValueError("Unknown/unexpected --unittests_persistance_strategy value")


def _format_perf_regression_message(
    func_id: str, slowest: float, run_time: float, hist: List[float]
) -> str:
    """Build the failure message shown when a test has become too slow."""
    msg = f"""{func_id} performance has regressed unacceptably.
{slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
Here is the current, full db perf timing distribution:

"""
    msg += "".join(f"{sample:f}\n" for sample in hist)
    return msg


def _fail_perf_regression(args: tuple, msg: str) -> None:
    """Fail the running testcase with msg.  Always raises."""
    slf = args[0] if args else None  # Peek at the wrapped function's self ref.
    fail = getattr(slf, "fail", None)
    if callable(fail):
        fail(msg)  # ...to fail the testcase (TestCase.fail raises).
    # Not a TestCase method (or fail() didn't raise): fail directly.
    raise AssertionError(msg)


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This decorator is meant to be used on a method in a class that
    subclasses :class:`unittest.TestCase`.  When decorated, method
    execution timing (i.e. performance) will be measured and compared
    with a database of historical performance for the same method.
    The wrapper will then fail the test with a perf-related message if
    it has become too slow (more than 4 stdevs slower than the slowest
    runtime on record).

    See also :meth:`check_all_methods_for_perf_regressions`.

    Example usage::

        class TestMyClass(unittest.TestCase):

            @check_method_for_perf_regressions
            def test_some_part_of_my_class(self):
                ...

    Args:
        func: the test method to monitor.

    Returns:
        A wrapper that runs func, returns its value, and records/checks
        its runtime (unless --unittests_ignore_perf is set).

    Raises:
        NotImplementedError: if the DATABASE persistence strategy is chosen.
        ValueError: on an unknown persistence strategy.
        AssertionError (via TestCase.fail): on a perf regression.
    """

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        if config.config["unittests_ignore_perf"]:
            return func(*args, **kwargs)

        helper = _make_perf_data_persister()

        func_id = function_utils.function_identifier(func)
        func_name = func.__name__
        logger.debug("Watching %s's performance...", func_name)
        logger.debug('Canonical function identifier = "%s"', func_id)

        try:
            perfdb = helper.load_performance_data(func_id)
        except Exception:
            # Best effort: an unreadable perfdb must not break the test run;
            # start over with an empty history instead.
            msg = "Unable to load perfdb; skipping it..."
            logger.exception(msg)
            perfdb = {}

        # cmdline arg to forget perf traces for function(s).  The value may
        # be a single id or a list of ids.
        drop_ids = config.config["unittests_drop_perf_traces"]
        if drop_ids is not None:
            if isinstance(drop_ids, str):
                drop_ids = [drop_ids]
            for drop_id in drop_ids:
                helper.delete_performance_data(drop_id)
                # Forget it locally too so this run isn't judged against the
                # very history we were asked to drop.
                perfdb.pop(drop_id, None)

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        run_time = time.perf_counter() - start_time

        # statistics.stdev() needs at least two data points, so never try to
        # judge a run against fewer than that (regardless of the flag).
        min_samples = max(2, config.config["unittests_num_perf_samples"])

        # See if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < min_samples:
            hist.append(run_time)
            logger.debug("Still establishing a perf baseline for %s", func_name)
        else:
            stdev = statistics.stdev(hist)
            logger.debug("For %s, performance stdev=%.2f", func_name, stdev)
            # max() doesn't depend on how the stored history is ordered.
            slowest = max(hist)
            logger.debug("For %s, slowest perf on record is %.2fs", func_name, slowest)
            limit = slowest + stdev * 4
            logger.debug("For %s, max acceptable runtime is %.2fs", func_name, limit)
            logger.debug(
                "For %s, actual observed runtime was %.2fs", func_name, run_time
            )
            if run_time > limit:
                msg = _format_perf_regression_message(func_id, slowest, run_time, hist)
                logger.error(msg)
                _fail_perf_regression(args, msg)
            else:
                hist.append(run_time)

        # Don't spam the database with samples; just pick a random
        # sample from what we have and store that back.
        n = min(min_samples, len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        helper.save_performance_data(func_id, perfdb)
        return value

    return wrapper_perf_monitor
def check_all_methods_for_perf_regressions(prefix: str = "test_"):
    """This decorator is meant to apply to classes that subclass from
    :class:`unittest.TestCase` and, when applied, has the effect of
    decorating each method that matches the `prefix` given with the
    :meth:`check_method_for_perf_regressions` wrapper (see above).
    This wrapper causes us to measure perf and fail tests that regress
    too far.

    Args:
        prefix: the prefix of method names to check for regressions.

    Returns:
        A class decorator.  If used bare (without parentheses), Python
        passes the class itself as ``prefix``; in that case the class is
        decorated directly using the default ``"test_"`` prefix.

    See also :meth:`check_method_for_perf_regressions` to check only
    a single method.

    Example usage.  By decorating the class, all methods with names
    that begin with `test_` will be perf monitored::

        import pyutils.unittest_utils as uu

        @uu.check_all_methods_for_perf_regressions()
        class TestMyClass(unittest.TestCase):

            def test_some_part_of_my_class(self):
                ...

            def test_some_other_part_of_my_class(self):
                ...
    """
    used_bare = isinstance(prefix, type)
    method_prefix = "test_" if used_bare else prefix

    def decorate_the_testcase(cls):
        if not (isinstance(cls, type) and issubclass(cls, unittest.TestCase)):
            logger.warning(
                "%s is not a unittest.TestCase; not wrapping its methods.", cls
            )
            return cls
        for name, m in inspect.getmembers(cls, inspect.isfunction):
            if name.startswith(method_prefix):
                setattr(cls, name, check_method_for_perf_regressions(m))
                logger.debug("Wrapping %s:%s.", cls.__name__, name)
        return cls

    if used_bare:
        return decorate_the_testcase(prefix)
    return decorate_the_testcase
class RecordStdout(contextlib.AbstractContextManager):
    """
    Records what is emitted to stdout into a buffer instead.

    While active, ``sys.stdout`` is redirected into a spooled temporary
    file.  Entering yields a zero-argument callable returning that file,
    which is rewound to its start on exit so it can be read back.

    >>> with RecordStdout() as record:
    ...     print("This is a test!")
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        super().__init__()
        self.destination = tempfile.SpooledTemporaryFile(mode="r+")
        # Created in __enter__ and undone in __exit__.
        self.recorder: Optional[contextlib.redirect_stdout] = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        redirector = contextlib.redirect_stdout(self.destination)
        self.recorder = redirector
        redirector.__enter__()

        def get_destination() -> tempfile.SpooledTemporaryFile:
            return self.destination

        return get_destination

    def __exit__(self, *args) -> Literal[False]:
        redirector = self.recorder
        assert redirector is not None
        redirector.__exit__(*args)
        self.destination.seek(0)
        # Never suppress an exception raised inside the with-block.
        return False
class RecordStderr(contextlib.AbstractContextManager):
    """
    Record what is emitted to stderr.

    While active, ``sys.stderr`` is redirected into a spooled temporary
    file.  Entering yields a zero-argument callable returning that file,
    which is rewound to its start on exit so it can be read back.

    >>> import sys
    >>> with RecordStderr() as record:
    ...     print("This is a test!", file=sys.stderr)
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        super().__init__()
        self.destination = tempfile.SpooledTemporaryFile(mode="r+")
        # This holds a redirect_stderr (not redirect_stdout); annotating it
        # correctly removes the need for a "type: ignore" in __enter__.
        self.recorder: Optional[contextlib.redirect_stderr] = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> Literal[False]:
        assert self.recorder is not None
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        # Never suppress an exception raised inside the with-block.
        return False
class RecordMultipleStreams(contextlib.AbstractContextManager):
    """
    Record the output to more than one stream.

    Each given file's ``write`` is temporarily replaced so that everything
    written to any of them lands in one shared spooled temporary file.
    Entering yields a zero-argument callable returning that file, which is
    rewound to its start on exit.

    Example usage::

        with RecordMultipleStreams(sys.stdout, sys.stderr) as r:
            print("This is a test!", file=sys.stderr)
            print("This is too", file=sys.stdout)

        print(r().readlines())
    """

    def __init__(self, *files) -> None:
        super().__init__()
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode="r+")
        # Stack of original write methods, pushed in self.files order.
        self.saved_writes: List[Callable[..., Any]] = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> Literal[False]:
        # saved_writes is a stack, so walk the files in reverse: each file
        # then gets back exactly the write it had (iterating forward would
        # swap the write methods of e.g. stdout and stderr).  This also
        # restores correctly if the same file was passed more than once.
        for f in reversed(self.files):
            f.write = self.saved_writes.pop()
        self.destination.seek(0)
        # Never suppress an exception raised inside the with-block.
        return False
406 if __name__ == "__main__":