#!/usr/bin/env python3

# © Copyright 2021-2022, Scott Gasch

"""Helpers for unittests.

.. note::

    When you import this module we automatically wrap the standard
    Python `unittest.main` with a call to
    :meth:`pyutils.bootstrap.initialize` so that we get logger config,
    commandline args, logging control, etc.  This works fine but may
    be unexpected behavior.
"""
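
# A minimal sketch of the intended usage pattern (the test module, class,
# and method names below are hypothetical):
#
#     from pyutils import unittest_utils  # side effect: wraps unittest.main
#
#     class MyTest(unittest.TestCase):
#         def test_something(self):
#             self.assertTrue(True)
#
#     if __name__ == '__main__':
#         unittest.main()  # now runs under bootstrap.initialize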

import contextlib
import functools
import inspect
import logging
import os
import pickle
import random
import statistics
import tempfile
import time
import unittest
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional

from pyutils import bootstrap, config, function_utils

logger = logging.getLogger(__name__)
cfg = config.add_commandline_args(
    f'Logging ({__file__})', 'Args related to function decorators'
)
cfg.add_argument(
    '--unittests_ignore_perf', action='store_true', default=False,
    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
)
cfg.add_argument(
    '--unittests_num_perf_samples', type=int, default=50,
    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds',
)
cfg.add_argument(
    '--unittests_drop_perf_traces', type=str, default=None,
    help='The identifier (i.e. file!test_fixture) for which we should drop all perf data',
)
cfg.add_argument(
    '--unittests_persistance_strategy', choices=['FILE', 'DATABASE'], default='FILE',
    help='Should we persist perf data in a file or db?',
)
cfg.add_argument(
    '--unittests_perfdb_filename', type=str, metavar='FILENAME',
    default=f'{os.environ["HOME"]}/.python_unittest_performance_db',
    help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)',
)
cfg.add_argument(
    '--unittests_perfdb_spec', type=str, metavar='DBSPEC',
    default='mariadb+pymysql://python_unittest:<PASSWORD>@db.house:3306/python_unittest_performance',
    help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)',
)

unittest.main = bootstrap.initialize(unittest.main)
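
# Because unittest.main is wrapped above, any test program that imports
# this module picks up the flags defined here, e.g. (hypothetical
# invocations):
#
#     python my_tests.py --unittests_ignore_perf
#     python my_tests.py --unittests_num_perf_samples 100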


class PerfRegressionDataPersister(ABC):
    """A base class that defines an interface for dealing with
    persisting perf regression data."""

    @abstractmethod
    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
        """Load the historical performance data for the supplied method.

        Args:
            method_id: the method for which we want historical perf data.
        """

    @abstractmethod
    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
        """Save the historical performance data of the supplied method.

        Args:
            method_id: the method whose historical perf data we're saving.
            data: the historical performance data being persisted.
        """

    @abstractmethod
    def delete_performance_data(self, method_id: str):
        """Delete the historical performance data of the supplied method.

        Args:
            method_id: the method whose data should be erased.
        """


class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister):
    """A perf regression data persister that uses files."""

    def __init__(self, filename: str):
        """
        Args:
            filename: the filename to save/load historical performance data
        """
        super().__init__()
        self.filename = filename
        self.traces_to_delete: List[str] = []

    def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
        with open(self.filename, 'rb') as f:
            return pickle.load(f)

    def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
        # Drop any traces that were marked for deletion before writing.
        for trace in self.traces_to_delete:
            if trace in data:
                data[trace] = []

        with open(self.filename, 'wb') as f:
            pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

    def delete_performance_data(self, method_id: str):
        self.traces_to_delete.append(method_id)
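
# A minimal sketch of driving the file-based persister by hand, outside
# of the decorator below (the path and method identifier are made up):
#
#     helper = FileBasedPerfRegressionDataPersister('/tmp/test_perfdb.pickle')
#     perfdb = {'my_tests!test_foo': [0.101, 0.098]}
#     helper.save_performance_data('my_tests!test_foo', perfdb)
#     assert helper.load_performance_data('my_tests!test_foo') == perfdb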


# class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister):
#     """A perf regression data persister that uses a database backend."""
#
#     def __init__(self, dbspec: str):
#         super().__init__()
#         self.dbspec = dbspec
#         self.engine = sa.create_engine(self.dbspec)
#         self.conn = self.engine.connect()
#
#     def load_performance_data(self, method_id: str) -> Dict[str, List[float]]:
#         results = self.conn.execute(
#             sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";')
#         )
#         ret: Dict[str, List[float]] = {method_id: []}
#         for result in results.all():
#             ret[method_id].append(result['runtime'])
#         return ret
#
#     def save_performance_data(self, method_id: str, data: Dict[str, List[float]]):
#         self.delete_performance_data(method_id)
#         for (mid, perf_data) in data.items():
#             sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES '
#             for perf in perf_data:
#                 self.conn.execute(sql + f'("{mid}", {perf});')
#
#     def delete_performance_data(self, method_id: str):
#         sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"'
#         self.conn.execute(sql)
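
# Note: the commented-out persister above presumes `import sqlalchemy as sa`
# and a preexisting runtimes_by_function table; it is retained for reference
# only and would add a sqlalchemy dependency if enabled.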


def check_method_for_perf_regressions(func: Callable) -> Callable:
    """This decorator is meant to be used on a method in a class that
    subclasses :class:`unittest.TestCase`.  When decorated, method
    execution timing (i.e. performance) will be measured and compared
    with a database of historical performance for the same method.
    The wrapper will then fail the test with a perf-related message if
    it has become too slow.

    See also :meth:`check_all_methods_for_perf_regressions`.

    Example usage::

        class TestMyClass(unittest.TestCase):

            @check_method_for_perf_regressions
            def test_some_part_of_my_class(self):
                ...
    """

    @functools.wraps(func)
    def wrapper_perf_monitor(*args, **kwargs):
        if config.config['unittests_ignore_perf']:
            return func(*args, **kwargs)

        if config.config['unittests_persistance_strategy'] == 'FILE':
            filename = config.config['unittests_perfdb_filename']
            helper = FileBasedPerfRegressionDataPersister(filename)
        elif config.config['unittests_persistance_strategy'] == 'DATABASE':
            raise NotImplementedError(
                'Persisting to a database is not implemented in this version'
            )
        else:
            raise Exception('Unknown/unexpected --unittests_persistance_strategy value')

        func_id = function_utils.function_identifier(func)
        func_name = func.__name__
        logger.debug('Watching %s\'s performance...', func_name)
        logger.debug('Canonical function identifier = "%s"', func_id)

        try:
            perfdb = helper.load_performance_data(func_id)
        except Exception as e:
            logger.exception(e)
            msg = 'Unable to load perfdb; skipping it...'
            logger.warning(msg)
            warnings.warn(msg)
            perfdb = {}

        # cmdline arg to forget perf traces for function
        drop_id = config.config['unittests_drop_perf_traces']
        if drop_id is not None:
            helper.delete_performance_data(drop_id)

        # Run the wrapped test paying attention to latency.
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time

        # See if it was unexpectedly slow.
        hist = perfdb.get(func_id, [])
        if len(hist) < config.config['unittests_num_perf_samples']:
            hist.append(run_time)
            logger.debug('Still establishing a perf baseline for %s', func_name)
        else:
            stdev = statistics.stdev(hist)
            logger.debug('For %s, performance stdev=%.2f', func_name, stdev)
            slowest = hist[-1]
            logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest)
            limit = slowest + stdev * 4
            logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit)
            logger.debug(
                'For %s, actual observed runtime was %.2fs', func_name, run_time
            )
            if run_time > limit:
                msg = f'''{func_id} performance has regressed unacceptably.
{slowest:f}s is the slowest runtime on record in {len(hist)} perf samples.
It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest.
Here is the current, full db perf timing distribution:

'''
                for x in hist:
                    msg += f'{x:f}\n'
                logger.error(msg)
                slf = args[0]  # Peek at the wrapped function's self ref.
                slf.fail(msg)  # ...to fail the testcase.
            else:
                hist.append(run_time)

        # Don't spam the database with samples; just pick a random
        # sample from what we have and store that back.
        n = min(config.config['unittests_num_perf_samples'], len(hist))
        hist = random.sample(hist, n)
        hist.sort()
        perfdb[func_id] = hist
        helper.save_performance_data(func_id, perfdb)
        return value

    return wrapper_perf_monitor
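
# Worked example of the threshold logic above (numbers are illustrative):
# with history [0.90, 0.95, 1.00, 1.10] the slowest sample on record is
# 1.10s and statistics.stdev(...) ~= 0.085s, so a new run fails only if
# it exceeds 1.10 + 4 * 0.085 ~= 1.44s.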


def check_all_methods_for_perf_regressions(prefix='test_'):
    """This decorator is meant to apply to classes that subclass from
    :class:`unittest.TestCase` and, when applied, has the effect of
    decorating each method that matches the `prefix` given with the
    :meth:`check_method_for_perf_regressions` wrapper (see above).
    This wrapper causes us to measure perf and fail tests that regress
    too much from their historical performance.

    Args:
        prefix: the prefix of method names to check for regressions

    See also :meth:`check_method_for_perf_regressions` to check only
    a single method.

    Example usage.  By decorating the class, all methods with names
    that begin with `test_` will be perf monitored::

        import pyutils.unittest_utils as uu

        @uu.check_all_methods_for_perf_regressions()
        class TestMyClass(unittest.TestCase):

            def test_some_part_of_my_class(self):
                ...

            def test_some_other_part_of_my_class(self):
                ...
    """

    def decorate_the_testcase(cls):
        if issubclass(cls, unittest.TestCase):
            for name, m in inspect.getmembers(cls, inspect.isfunction):
                if name.startswith(prefix):
                    setattr(cls, name, check_method_for_perf_regressions(m))
                    logger.debug('Wrapping %s:%s.', cls.__name__, name)
        return cls

    return decorate_the_testcase


class RecordStdout(contextlib.AbstractContextManager):
    """
    Records what is emitted to stdout into a buffer instead.

    >>> with RecordStdout() as record:
    ...     print("This is a test!")
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        super().__init__()
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder: Optional[contextlib.redirect_stdout] = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stdout(self.destination)
        assert self.recorder is not None
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> Literal[False]:
        assert self.recorder is not None
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False


class RecordStderr(contextlib.AbstractContextManager):
    """
    Record what is emitted to stderr.

    >>> import sys
    >>> with RecordStderr() as record:
    ...     print("This is a test!", file=sys.stderr)
    >>> print({record().readline()})
    {'This is a test!\\n'}
    """

    def __init__(self) -> None:
        super().__init__()
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.recorder: Optional[contextlib.redirect_stdout[Any]] = None

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        self.recorder = contextlib.redirect_stderr(self.destination)  # type: ignore
        assert self.recorder is not None
        self.recorder.__enter__()
        return lambda: self.destination

    def __exit__(self, *args) -> Literal[False]:
        assert self.recorder is not None
        self.recorder.__exit__(*args)
        self.destination.seek(0)
        return False


class RecordMultipleStreams(contextlib.AbstractContextManager):
    """
    Record the output to more than one stream.

    Example usage::

        with RecordMultipleStreams(sys.stdout, sys.stderr) as r:
            print("This is a test!", file=sys.stderr)
            print("This is too", file=sys.stdout)

        print(r().readlines())
    """

    def __init__(self, *files) -> None:
        super().__init__()
        self.files = [*files]
        self.destination = tempfile.SpooledTemporaryFile(mode='r+')
        self.saved_writes: List[Callable[..., Any]] = []

    def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]:
        # Swap each stream's write method for one that writes into our
        # shared buffer, remembering the originals so we can restore them.
        for f in self.files:
            self.saved_writes.append(f.write)
            f.write = self.destination.write
        return lambda: self.destination

    def __exit__(self, *args) -> Literal[False]:
        # Restore in reverse order so each stream gets back its own
        # original write method.
        for f in reversed(self.files):
            f.write = self.saved_writes.pop()
        self.destination.seek(0)
        return False
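
# Design note: unlike RecordStdout/RecordStderr, which use contextlib's
# redirect_* helpers, RecordMultipleStreams swaps out each stream's write
# method directly, so it can capture any file-like object, not just
# sys.stdout or sys.stderr.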


if __name__ == '__main__':
    import doctest

    doctest.testmod()