Logging + documentation.
[python_utils.git] / logging_utils.py
1 #!/usr/bin/env python3
2
3 """Utilities related to logging."""
4
5 import collections
6 import contextlib
7 import datetime
8 import enum
9 import io
10 import logging
11 from logging.handlers import RotatingFileHandler, SysLogHandler
12 import os
13 import pytz
14 import random
15 import sys
16 from typing import Callable, Iterable, Mapping, Optional
17
18 # This module is commonly used by others in here and should avoid
19 # taking any unnecessary dependencies back on them.
20 import argparse_utils
21 import config
22
23 cfg = config.add_commandline_args(
24     f'Logging ({__file__})',
25     'Args related to logging')
26 cfg.add_argument(
27     '--logging_config_file',
28     type=argparse_utils.valid_filename,
29     default=None,
30     metavar='FILENAME',
31     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
32 )
33 cfg.add_argument(
34     '--logging_level',
35     type=str,
36     default='INFO',
37     choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
38     metavar='LEVEL',
39     help='The level below which to squelch log messages.',
40 )
41 cfg.add_argument(
42     '--logging_format',
43     type=str,
44     default='%(levelname).1s:%(asctime)s: %(message)s',
45     help='The format for lines logged via the logger module.'
46 )
47 cfg.add_argument(
48     '--logging_date_format',
49     type=str,
50     default='%Y/%m/%dT%H:%M:%S.%f%z',
51     metavar='DATEFMT',
52     help='The format of any dates in --logging_format.'
53 )
54 cfg.add_argument(
55     '--logging_console',
56     action=argparse_utils.ActionNoYes,
57     default=True,
58     help='Should we log to the console (stderr)',
59 )
60 cfg.add_argument(
61     '--logging_filename',
62     type=str,
63     default=None,
64     metavar='FILENAME',
65     help='The filename of the logfile to write.'
66 )
67 cfg.add_argument(
68     '--logging_filename_maxsize',
69     type=int,
70     default=(1024*1024),
71     metavar='#BYTES',
72     help='The maximum size (in bytes) to write to the logging_filename.'
73 )
74 cfg.add_argument(
75     '--logging_filename_count',
76     type=int,
77     default=2,
78     metavar='COUNT',
79     help='The number of logging_filename copies to keep before deleting.'
80 )
81 cfg.add_argument(
82     '--logging_syslog',
83     action=argparse_utils.ActionNoYes,
84     default=False,
85     help='Should we log to localhost\'s syslog.'
86 )
87 cfg.add_argument(
88     '--logging_debug_threads',
89     action=argparse_utils.ActionNoYes,
90     default=False,
91     help='Should we prepend pid/tid data to all log messages?'
92 )
93 cfg.add_argument(
94     '--logging_info_is_print',
95     action=argparse_utils.ActionNoYes,
96     default=False,
97     help='logging.info also prints to stdout.'
98 )
99 cfg.add_argument(
100     '--logging_squelch_repeats_enabled',
101     action=argparse_utils.ActionNoYes,
102     default=True,
103     help='Do we allow code to indicate that it wants to squelch repeated logging messages or should we always log?'
104 )
105 cfg.add_argument(
106     '--logging_probabilistically_enabled',
107     action=argparse_utils.ActionNoYes,
108     default=True,
109     help='Do we allow probabilistic logging (for code that wants it) or should we always log?'
110 )
111 # See also: OutputMultiplexer
112 cfg.add_argument(
113     '--logging_captures_prints',
114     action=argparse_utils.ActionNoYes,
115     default=False,
116     help='When calling print, also log.info automatically.'
117 )
118
119 built_in_print = print
120
121
122 def function_identifier(f: Callable) -> str:
123     """
124     Given a callable function, return a string that identifies it.
125     Usually that string is just __module__:__name__ but there's a
126     corner case: when __module__ is __main__ (i.e. the callable is
127     defined in the same module as __main__).  In this case,
128     f.__module__ returns "__main__" instead of the file that it is
129     defined in.  Work around this using pathlib.Path (see below).
130
131     >>> function_identifier(function_identifier)
132     'logging_utils:function_identifier'
133
134     """
135     if f.__module__ == '__main__':
136         from pathlib import Path
137         import __main__
138         module = __main__.__file__
139         module = Path(module).stem
140         return f'{module}:{f.__name__}'
141     else:
142         return f'{f.__module__}:{f.__name__}'
143
144
145 # A map from logging_callsite_id -> count of logged messages.
146 squelched_logging_counts: Mapping[str, int] = {}
147
148
149 def squelch_repeated_log_messages(squelch_after_n_repeats: int) -> Callable:
150     """
151     A decorator that marks a function as interested in having the logging
152     messages that it produces be squelched (ignored) after it logs the
153     same message more than N times.
154
155     Note: this decorator affects *ALL* logging messages produced
156     within the decorated function.  That said, messages must be
157     identical in order to be squelched.  For example, if the same line
158     of code produces different messages (because of, e.g., a format
159     string), the messages are considered to be different.
160
161     """
162     def squelch_logging_wrapper(f: Callable):
163         identifier = function_identifier(f)
164         squelched_logging_counts[identifier] = squelch_after_n_repeats
165         return f
166     return squelch_logging_wrapper
167
168
169 class SquelchRepeatedMessagesFilter(logging.Filter):
170     """
171     A filter that only logs messages from a given site with the same
172     (exact) message at the same logging level N times and ignores
173     subsequent attempts to log.
174
175     This filter only affects logging messages that repeat more than
176     a threshold number of times from functions that are tagged with
177     the @logging_utils.squelched_logging_ok decorator.
178
179     """
180     def __init__(self) -> None:
181         self.counters = collections.Counter()
182         super().__init__()
183
184     def filter(self, record: logging.LogRecord) -> bool:
185         id1 = f'{record.module}:{record.funcName}'
186         if id1 not in squelched_logging_counts:
187             return True
188         threshold = squelched_logging_counts[id1]
189         logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
190         count = self.counters[logsite]
191         self.counters[logsite] += 1
192         return count < threshold
193
194
195 # A map from function_identifier -> probability of logging (0.0%..100.0%)
196 probabilistic_logging_levels: Mapping[str, float] = {}
197
198
199 def logging_is_probabilistic(probability_of_logging: float) -> Callable:
200     """
201     A decorator that indicates that all logging statements within the
202     scope of a particular (marked) function are not deterministic
203     (i.e. they do not always unconditionally log) but rather are
204     probabilistic (i.e. they log N% of the time randomly).
205
206     This affects *ALL* logging statements within the marked function.
207
208     """
209     def probabilistic_logging_wrapper(f: Callable):
210         identifier = function_identifier(f)
211         probabilistic_logging_levels[identifier] = probability_of_logging
212         return f
213     return probabilistic_logging_wrapper
214
215
216 class ProbabilisticFilter(logging.Filter):
217     """
218     A filter that logs messages probabilistically (i.e. randomly at some
219     percent chance).
220
221     This filter only affects logging messages from functions that have
222     been tagged with the @logging_utils.probabilistic_logging decorator.
223
224     """
225     def filter(self, record: logging.LogRecord) -> bool:
226         id1 = f'{record.module}:{record.funcName}'
227         if id1 not in probabilistic_logging_levels:
228             return True
229         threshold = probabilistic_logging_levels[id1]
230         return (random.random() * 100.0) <= threshold
231
232
233 class OnlyInfoFilter(logging.Filter):
234     """
235     A filter that only logs messages produced at the INFO logging
236     level.  This is used by the logging_info_is_print commandline
237     option to select a subset of the logging stream to send to a
238     stdout handler.
239
240     """
241     def filter(self, record):
242         return record.levelno == logging.INFO
243
244
245 class MillisecondAwareFormatter(logging.Formatter):
246     """
247     A formatter for adding milliseconds to log messages.
248
249     """
250     converter = datetime.datetime.fromtimestamp
251
252     def formatTime(self, record, datefmt=None):
253         ct = MillisecondAwareFormatter.converter(
254             record.created, pytz.timezone("US/Pacific")
255         )
256         if datefmt:
257             s = ct.strftime(datefmt)
258         else:
259             t = ct.strftime("%Y-%m-%d %H:%M:%S")
260             s = "%s,%03d" % (t, record.msecs)
261         return s
262
263
264 def initialize_logging(logger=None) -> logging.Logger:
265     assert config.has_been_parsed()
266     if logger is None:
267         logger = logging.getLogger()       # Root logger
268
269     if config.config['logging_config_file'] is not None:
270         logging.config.fileConfig('logging.conf')
271         return logger
272
273     handlers = []
274     numeric_level = getattr(
275         logging,
276         config.config['logging_level'].upper(),
277         None
278     )
279     if not isinstance(numeric_level, int):
280         raise ValueError('Invalid level: %s' % config.config['logging_level'])
281
282     fmt = config.config['logging_format']
283     if config.config['logging_debug_threads']:
284         fmt = f'%(process)d.%(thread)d|{fmt}'
285
286     if config.config['logging_syslog']:
287         if sys.platform not in ('win32', 'cygwin'):
288             handler = SysLogHandler()
289 #            for k, v in encoded_priorities.items():
290 #                handler.encodePriority(k, v)
291             handler.setFormatter(
292                 MillisecondAwareFormatter(
293                     fmt=fmt,
294                     datefmt=config.config['logging_date_format'],
295                 )
296             )
297             handler.setLevel(numeric_level)
298             handlers.append(handler)
299
300     if config.config['logging_filename']:
301         handler = RotatingFileHandler(
302             config.config['logging_filename'],
303             maxBytes = config.config['logging_filename_maxsize'],
304             backupCount = config.config['logging_filename_count'],
305         )
306         handler.setLevel(numeric_level)
307         handler.setFormatter(
308             MillisecondAwareFormatter(
309                 fmt=fmt,
310                 datefmt=config.config['logging_date_format'],
311             )
312         )
313         handlers.append(handler)
314
315     if config.config['logging_console']:
316         handler = logging.StreamHandler(sys.stderr)
317         handler.setLevel(numeric_level)
318         handler.setFormatter(
319             MillisecondAwareFormatter(
320                 fmt=fmt,
321                 datefmt=config.config['logging_date_format'],
322             )
323         )
324         handlers.append(handler)
325
326     if len(handlers) == 0:
327         handlers.append(logging.NullHandler())
328
329     for handler in handlers:
330         logger.addHandler(handler)
331
332     if config.config['logging_info_is_print']:
333         handler = logging.StreamHandler(sys.stdout)
334         handler.addFilter(OnlyInfoFilter())
335         logger.addHandler(handler)
336
337     if config.config['logging_squelch_repeats_enabled']:
338         for handler in handlers:
339             handler.addFilter(SquelchRepeatedMessagesFilter())
340
341     if config.config['logging_probabilistically_enabled']:
342         for handler in handlers:
343             handler.addFilter(ProbabilisticFilter())
344
345     logger.setLevel(numeric_level)
346     logger.propagate = False
347
348     if config.config['logging_captures_prints']:
349         import builtins
350         global built_in_print
351
352         def print_and_also_log(*arg, **kwarg):
353             f = kwarg.get('file', None)
354             if f == sys.stderr:
355                 logger.warning(*arg)
356             else:
357                 logger.info(*arg)
358             built_in_print(*arg, **kwarg)
359         builtins.print = print_and_also_log
360
361     return logger
362
363
364 def get_logger(name: str = ""):
365     logger = logging.getLogger(name)
366     return initialize_logging(logger)
367
368
369 def tprint(*args, **kwargs) -> None:
370     """Legacy function for printing a message augmented with thread id."""
371
372     if config.config['logging_debug_threads']:
373         from thread_utils import current_thread_id
374         print(f'{current_thread_id()}', end="")
375         print(*args, **kwargs)
376     else:
377         pass
378
379
380 def dprint(*args, **kwargs) -> None:
381     """Legacy function used to print to stderr."""
382
383     print(*args, file=sys.stderr, **kwargs)
384
385
386 class OutputMultiplexer(object):
387     """
388     A class that broadcasts printed messages to several sinks (including
389     various logging levels, different files, different file handles,
390     the house log, etc...)
391
392     """
393     class Destination(enum.IntEnum):
394         """Bits in the destination_bitv bitvector.  Used to indicate the
395         output destination."""
396         LOG_DEBUG = 0x01         # -\
397         LOG_INFO = 0x02          #  |
398         LOG_WARNING = 0x04       #   > Should provide logger to the c'tor.
399         LOG_ERROR = 0x08         #  |
400         LOG_CRITICAL = 0x10      # _/
401         FILENAMES = 0x20         # Must provide a filename to the c'tor.
402         FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
403         HLOG = 0x80
404         ALL_LOG_DESTINATIONS = (
405             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
406         )
407         ALL_OUTPUT_DESTINATIONS = 0x8F
408
409     def __init__(self,
410                  destination_bitv: int,
411                  *,
412                  logger=None,
413                  filenames: Optional[Iterable[str]] = None,
414                  handles: Optional[Iterable[io.TextIOWrapper]] = None):
415         if logger is None:
416             logger = logging.getLogger(None)
417         self.logger = logger
418
419         if filenames is not None:
420             self.f = [
421                 open(filename, 'wb', buffering=0) for filename in filenames
422             ]
423         else:
424             if destination_bitv & OutputMultiplexer.FILENAMES:
425                 raise ValueError(
426                     "Filenames argument is required if bitv & FILENAMES"
427                 )
428             self.f = None
429
430         if handles is not None:
431             self.h = [handle for handle in handles]
432         else:
433             if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES:
434                 raise ValueError(
435                     "Handle argument is required if bitv & FILEHANDLES"
436                 )
437             self.h = None
438
439         self.set_destination_bitv(destination_bitv)
440
441     def get_destination_bitv(self):
442         return self.destination_bitv
443
444     def set_destination_bitv(self, destination_bitv: int):
445         if destination_bitv & self.Destination.FILENAMES and self.f is None:
446             raise ValueError(
447                 "Filename argument is required if bitv & FILENAMES"
448             )
449         if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
450             raise ValueError(
451                     "Handle argument is required if bitv & FILEHANDLES"
452                 )
453         self.destination_bitv = destination_bitv
454
455     def print(self, *args, **kwargs):
456         from string_utils import sprintf, strip_escape_sequences
457         end = kwargs.pop("end", None)
458         if end is not None:
459             if not isinstance(end, str):
460                 raise TypeError("end must be None or a string")
461         sep = kwargs.pop("sep", None)
462         if sep is not None:
463             if not isinstance(sep, str):
464                 raise TypeError("sep must be None or a string")
465         if kwargs:
466             raise TypeError("invalid keyword arguments to print()")
467         buf = sprintf(*args, end="", sep=sep)
468         if sep is None:
469             sep = " "
470         if end is None:
471             end = "\n"
472         if end == '\n':
473             buf += '\n'
474         if (
475                 self.destination_bitv & self.Destination.FILENAMES and
476                 self.f is not None
477         ):
478             for _ in self.f:
479                 _.write(buf.encode('utf-8'))
480                 _.flush()
481
482         if (
483                 self.destination_bitv & self.Destination.FILEHANDLES and
484                 self.h is not None
485         ):
486             for _ in self.h:
487                 _.write(buf)
488                 _.flush()
489
490         buf = strip_escape_sequences(buf)
491         if self.logger is not None:
492             if self.destination_bitv & self.Destination.LOG_DEBUG:
493                 self.logger.debug(buf)
494             if self.destination_bitv & self.Destination.LOG_INFO:
495                 self.logger.info(buf)
496             if self.destination_bitv & self.Destination.LOG_WARNING:
497                 self.logger.warning(buf)
498             if self.destination_bitv & self.Destination.LOG_ERROR:
499                 self.logger.error(buf)
500             if self.destination_bitv & self.Destination.LOG_CRITICAL:
501                 self.logger.critical(buf)
502         if self.destination_bitv & self.Destination.HLOG:
503             hlog(buf)
504
505     def close(self):
506         if self.f is not None:
507             for _ in self.f:
508                 _.close()
509
510
511 class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
512     """
513     A context that uses an OutputMultiplexer.  e.g.
514
515         with OutputMultiplexerContext(
516                 OutputMultiplexer.LOG_INFO |
517                 OutputMultiplexer.LOG_DEBUG |
518                 OutputMultiplexer.FILENAMES |
519                 OutputMultiplexer.FILEHANDLES,
520                 filenames = [ '/tmp/foo.log', '/var/log/bar.log' ],
521                 handles = [ f, g ]
522             ) as mplex:
523                 mplex.print("This is a log message!")
524
525     """
526     def __init__(self,
527                  destination_bitv: OutputMultiplexer.Destination,
528                  *,
529                  logger = None,
530                  filenames = None,
531                  handles = None):
532         super().__init__(
533             destination_bitv,
534             logger=logger,
535             filenames=filenames,
536             handles=handles)
537
538     def __enter__(self):
539         return self
540
541     def __exit__(self, etype, value, traceback) -> bool:
542         super().close()
543         if etype is not None:
544             return False
545         return True
546
547
548 def hlog(message: str) -> None:
549     """Write a message to the house log."""
550
551     message = message.replace("'", "'\"'\"'")
552     os.system(f"/usr/bin/logger -p local7.info -- '{message}'")
553
554
555 if __name__ == '__main__':
556     import doctest
557     doctest.testmod()