Make smart futures avoid polling.
[python_utils.git] / logging_utils.py
1 #!/usr/bin/env python3
2
3 """Utilities related to logging."""
4
5 import collections
6 import contextlib
7 import datetime
8 import enum
9 import io
10 import logging
11 from logging.handlers import RotatingFileHandler, SysLogHandler
12 import os
13 import pytz
14 import random
15 import sys
16 from typing import Callable, Iterable, Mapping, Optional
17
18 from overrides import overrides
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import argparse_utils
23 import config
24
25 cfg = config.add_commandline_args(
26     f'Logging ({__file__})',
27     'Args related to logging')
28 cfg.add_argument(
29     '--logging_config_file',
30     type=argparse_utils.valid_filename,
31     default=None,
32     metavar='FILENAME',
33     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
34 )
35 cfg.add_argument(
36     '--logging_level',
37     type=str,
38     default='INFO',
39     choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
40     metavar='LEVEL',
41     help='The global default level below which to squelch log messages; see also --lmodule',
42 )
43 cfg.add_argument(
44     '--logging_format',
45     type=str,
46     default='%(levelname).1s:%(asctime)s: %(message)s',
47     help='The format for lines logged via the logger module.'
48 )
49 cfg.add_argument(
50     '--logging_date_format',
51     type=str,
52     default='%Y/%m/%dT%H:%M:%S.%f%z',
53     metavar='DATEFMT',
54     help='The format of any dates in --logging_format.'
55 )
56 cfg.add_argument(
57     '--logging_console',
58     action=argparse_utils.ActionNoYes,
59     default=True,
60     help='Should we log to the console (stderr)',
61 )
62 cfg.add_argument(
63     '--logging_filename',
64     type=str,
65     default=None,
66     metavar='FILENAME',
67     help='The filename of the logfile to write.'
68 )
69 cfg.add_argument(
70     '--logging_filename_maxsize',
71     type=int,
72     default=(1024*1024),
73     metavar='#BYTES',
74     help='The maximum size (in bytes) to write to the logging_filename.'
75 )
76 cfg.add_argument(
77     '--logging_filename_count',
78     type=int,
79     default=7,
80     metavar='COUNT',
81     help='The number of logging_filename copies to keep before deleting.'
82 )
83 cfg.add_argument(
84     '--logging_syslog',
85     action=argparse_utils.ActionNoYes,
86     default=False,
87     help='Should we log to localhost\'s syslog.'
88 )
89 cfg.add_argument(
90     '--logging_debug_threads',
91     action=argparse_utils.ActionNoYes,
92     default=False,
93     help='Should we prepend pid/tid data to all log messages?'
94 )
95 cfg.add_argument(
96     '--logging_debug_modules',
97     action=argparse_utils.ActionNoYes,
98     default=False,
99     help='Should we prepend module/function data to all log messages?'
100 )
101 cfg.add_argument(
102     '--logging_info_is_print',
103     action=argparse_utils.ActionNoYes,
104     default=False,
105     help='logging.info also prints to stdout.'
106 )
107 cfg.add_argument(
108     '--logging_squelch_repeats_enabled',
109     action=argparse_utils.ActionNoYes,
110     default=True,
111     help='Do we allow code to indicate that it wants to squelch repeated logging messages or should we always log?'
112 )
113 cfg.add_argument(
114     '--logging_probabilistically_enabled',
115     action=argparse_utils.ActionNoYes,
116     default=True,
117     help='Do we allow probabilistic logging (for code that wants it) or should we always log?'
118 )
119 # See also: OutputMultiplexer
120 cfg.add_argument(
121     '--logging_captures_prints',
122     action=argparse_utils.ActionNoYes,
123     default=False,
124     help='When calling print, also log.info automatically.'
125 )
126 cfg.add_argument(
127     '--lmodule',
128     type=str,
129     metavar='<SCOPE>=<LEVEL>[,<SCOPE>=<LEVEL>...]',
130     help=(
131         'Allows per-scope logging levels which override the global level set with --logging-level.' +
132         'Pass a space separated list of <scope>=<level> where <scope> is one of: module, ' +
133         'module:function, or :function and <level> is a logging level (e.g. INFO, DEBUG...)'
134     )
135 )
136
137
138 built_in_print = print
139
140
141 def function_identifier(f: Callable) -> str:
142     """
143     Given a callable function, return a string that identifies it.
144     Usually that string is just __module__:__name__ but there's a
145     corner case: when __module__ is __main__ (i.e. the callable is
146     defined in the same module as __main__).  In this case,
147     f.__module__ returns "__main__" instead of the file that it is
148     defined in.  Work around this using pathlib.Path (see below).
149
150     >>> function_identifier(function_identifier)
151     'logging_utils:function_identifier'
152
153     """
154     if f.__module__ == '__main__':
155         from pathlib import Path
156         import __main__
157         module = __main__.__file__
158         module = Path(module).stem
159         return f'{module}:{f.__name__}'
160     else:
161         return f'{f.__module__}:{f.__name__}'
162
163
164 # A map from logging_callsite_id -> count of logged messages.
165 squelched_logging_counts: Mapping[str, int] = {}
166
167
168 def squelch_repeated_log_messages(squelch_after_n_repeats: int) -> Callable:
169     """
170     A decorator that marks a function as interested in having the logging
171     messages that it produces be squelched (ignored) after it logs the
172     same message more than N times.
173
174     Note: this decorator affects *ALL* logging messages produced
175     within the decorated function.  That said, messages must be
176     identical in order to be squelched.  For example, if the same line
177     of code produces different messages (because of, e.g., a format
178     string), the messages are considered to be different.
179
180     """
181     def squelch_logging_wrapper(f: Callable):
182         identifier = function_identifier(f)
183         squelched_logging_counts[identifier] = squelch_after_n_repeats
184         return f
185     return squelch_logging_wrapper
186
187
188 class SquelchRepeatedMessagesFilter(logging.Filter):
189     """
190     A filter that only logs messages from a given site with the same
191     (exact) message at the same logging level N times and ignores
192     subsequent attempts to log.
193
194     This filter only affects logging messages that repeat more than
195     a threshold number of times from functions that are tagged with
196     the @logging_utils.squelched_logging_ok decorator.
197
198     """
199     def __init__(self) -> None:
200         self.counters = collections.Counter()
201         super().__init__()
202
203     @overrides
204     def filter(self, record: logging.LogRecord) -> bool:
205         id1 = f'{record.module}:{record.funcName}'
206         if id1 not in squelched_logging_counts:
207             return True
208         threshold = squelched_logging_counts[id1]
209         logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
210         count = self.counters[logsite]
211         self.counters[logsite] += 1
212         return count < threshold
213
214
215 class DynamicPerScopeLoggingLevelFilter(logging.Filter):
216     """Only interested in seeing logging messages from an allow list of
217     module names or module:function names.  Block others.
218
219     """
220     @staticmethod
221     def level_name_to_level(name: str) -> int:
222         numeric_level = getattr(
223             logging,
224             name,
225             None
226         )
227         if not isinstance(numeric_level, int):
228             raise ValueError('Invalid level: {name}')
229         return numeric_level
230
231     def __init__(
232             self,
233             default_logging_level: int,
234             per_scope_logging_levels: str,
235     ) -> None:
236         super().__init__()
237         self.valid_levels = set(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
238         self.default_logging_level = default_logging_level
239         self.level_by_scope = {}
240         if per_scope_logging_levels is not None:
241             for chunk in per_scope_logging_levels.split(','):
242                 if '=' not in chunk:
243                     print(
244                         f'Malformed lmodule directive: "{chunk}", missing "=".  Ignored.',
245                         file=sys.stderr
246                     )
247                     continue
248                 try:
249                     (scope, level) = chunk.split('=')
250                 except ValueError:
251                     print(
252                         f'Malformed lmodule directive: "{chunk}".  Ignored.',
253                         file=sys.stderr
254                     )
255                     continue
256                 scope = scope.strip()
257                 level = level.strip().upper()
258                 if level not in self.valid_levels:
259                     print(
260                         f'Malformed lmodule directive: "{chunk}", bad level.  Ignored.',
261                         file=sys.stderr
262                     )
263                     continue
264                 self.level_by_scope[scope] = (
265                     DynamicPerScopeLoggingLevelFilter.level_name_to_level(
266                         level
267                     )
268                 )
269
270     @overrides
271     def filter(self, record: logging.LogRecord) -> bool:
272         # First try to find a logging level by scope (--lmodule)
273         if len(self.level_by_scope) > 0:
274             min_level = None
275             for scope in (
276                     record.module,
277                     f'{record.module}:{record.funcName}',
278                     f':{record.funcName}'
279             ):
280                 level = self.level_by_scope.get(scope, None)
281                 if level is not None:
282                     if min_level is None or level < min_level:
283                         min_level = level
284
285             # If we found one, use it instead of the global default level.
286             if min_level is not None:
287                 return record.levelno >= min_level
288
289         # Otherwise, use the global logging level (--logging_level)
290         return record.levelno >= self.default_logging_level
291
292
293 # A map from function_identifier -> probability of logging (0.0%..100.0%)
294 probabilistic_logging_levels: Mapping[str, float] = {}
295
296
297 def logging_is_probabilistic(probability_of_logging: float) -> Callable:
298     """
299     A decorator that indicates that all logging statements within the
300     scope of a particular (marked) function are not deterministic
301     (i.e. they do not always unconditionally log) but rather are
302     probabilistic (i.e. they log N% of the time randomly).
303
304     This affects *ALL* logging statements within the marked function.
305
306     """
307     def probabilistic_logging_wrapper(f: Callable):
308         identifier = function_identifier(f)
309         probabilistic_logging_levels[identifier] = probability_of_logging
310         return f
311     return probabilistic_logging_wrapper
312
313
314 class ProbabilisticFilter(logging.Filter):
315     """
316     A filter that logs messages probabilistically (i.e. randomly at some
317     percent chance).
318
319     This filter only affects logging messages from functions that have
320     been tagged with the @logging_utils.probabilistic_logging decorator.
321
322     """
323     @overrides
324     def filter(self, record: logging.LogRecord) -> bool:
325         id1 = f'{record.module}:{record.funcName}'
326         if id1 not in probabilistic_logging_levels:
327             return True
328         threshold = probabilistic_logging_levels[id1]
329         return (random.random() * 100.0) <= threshold
330
331
332 class OnlyInfoFilter(logging.Filter):
333     """
334     A filter that only logs messages produced at the INFO logging
335     level.  This is used by the logging_info_is_print commandline
336     option to select a subset of the logging stream to send to a
337     stdout handler.
338
339     """
340     @overrides
341     def filter(self, record: logging.LogRecord):
342         return record.levelno == logging.INFO
343
344
345 class MillisecondAwareFormatter(logging.Formatter):
346     """
347     A formatter for adding milliseconds to log messages.
348
349     """
350     converter = datetime.datetime.fromtimestamp
351
352     @overrides
353     def formatTime(self, record, datefmt=None):
354         ct = MillisecondAwareFormatter.converter(
355             record.created, pytz.timezone("US/Pacific")
356         )
357         if datefmt:
358             s = ct.strftime(datefmt)
359         else:
360             t = ct.strftime("%Y-%m-%d %H:%M:%S")
361             s = "%s,%03d" % (t, record.msecs)
362         return s
363
364
365 def initialize_logging(logger=None) -> logging.Logger:
366     assert config.has_been_parsed()
367     if logger is None:
368         logger = logging.getLogger()       # Root logger
369
370     if config.config['logging_config_file'] is not None:
371         logging.config.fileConfig('logging.conf')
372         return logger
373
374     handlers = []
375
376     # Global default logging level (--logging_level)
377     default_logging_level = getattr(
378         logging,
379         config.config['logging_level'].upper(),
380         None
381     )
382     if not isinstance(default_logging_level, int):
383         raise ValueError('Invalid level: %s' % config.config['logging_level'])
384
385     fmt = config.config['logging_format']
386     if config.config['logging_debug_threads']:
387         fmt = f'%(process)d.%(thread)d|{fmt}'
388     if config.config['logging_debug_modules']:
389         fmt = f'%(filename)s:%(funcName)s:%(lineno)s|{fmt}'
390
391     if config.config['logging_syslog']:
392         if sys.platform not in ('win32', 'cygwin'):
393             handler = SysLogHandler()
394             handler.setFormatter(
395                 MillisecondAwareFormatter(
396                     fmt=fmt,
397                     datefmt=config.config['logging_date_format'],
398                 )
399             )
400             handlers.append(handler)
401
402     if config.config['logging_filename']:
403         handler = RotatingFileHandler(
404             config.config['logging_filename'],
405             maxBytes = config.config['logging_filename_maxsize'],
406             backupCount = config.config['logging_filename_count'],
407         )
408         handler.setFormatter(
409             MillisecondAwareFormatter(
410                 fmt=fmt,
411                 datefmt=config.config['logging_date_format'],
412             )
413         )
414         handlers.append(handler)
415
416     if config.config['logging_console']:
417         handler = logging.StreamHandler(sys.stderr)
418         handler.setFormatter(
419             MillisecondAwareFormatter(
420                 fmt=fmt,
421                 datefmt=config.config['logging_date_format'],
422             )
423         )
424         handlers.append(handler)
425
426     if len(handlers) == 0:
427         handlers.append(logging.NullHandler())
428
429     for handler in handlers:
430         logger.addHandler(handler)
431
432     if config.config['logging_info_is_print']:
433         handler = logging.StreamHandler(sys.stdout)
434         handler.addFilter(OnlyInfoFilter())
435         logger.addHandler(handler)
436
437     if config.config['logging_squelch_repeats_enabled']:
438         for handler in handlers:
439             handler.addFilter(SquelchRepeatedMessagesFilter())
440
441     if config.config['logging_probabilistically_enabled']:
442         for handler in handlers:
443             handler.addFilter(ProbabilisticFilter())
444
445     for handler in handlers:
446         handler.addFilter(
447             DynamicPerScopeLoggingLevelFilter(
448                 default_logging_level,
449                 config.config['lmodule'],
450             )
451         )
452     logger.setLevel(0)
453     logger.propagate = False
454
455     if config.config['logging_captures_prints']:
456         import builtins
457         global built_in_print
458
459         def print_and_also_log(*arg, **kwarg):
460             f = kwarg.get('file', None)
461             if f == sys.stderr:
462                 logger.warning(*arg)
463             else:
464                 logger.info(*arg)
465             built_in_print(*arg, **kwarg)
466         builtins.print = print_and_also_log
467
468     return logger
469
470
471 def get_logger(name: str = ""):
472     logger = logging.getLogger(name)
473     return initialize_logging(logger)
474
475
476 def tprint(*args, **kwargs) -> None:
477     """Legacy function for printing a message augmented with thread id."""
478
479     if config.config['logging_debug_threads']:
480         from thread_utils import current_thread_id
481         print(f'{current_thread_id()}', end="")
482         print(*args, **kwargs)
483     else:
484         pass
485
486
487 def dprint(*args, **kwargs) -> None:
488     """Legacy function used to print to stderr."""
489
490     print(*args, file=sys.stderr, **kwargs)
491
492
493 class OutputMultiplexer(object):
494     """
495     A class that broadcasts printed messages to several sinks (including
496     various logging levels, different files, different file handles,
497     the house log, etc...)
498
499     """
500     class Destination(enum.IntEnum):
501         """Bits in the destination_bitv bitvector.  Used to indicate the
502         output destination."""
503         LOG_DEBUG = 0x01         #  ⎫
504         LOG_INFO = 0x02          #  ⎪
505         LOG_WARNING = 0x04       #  ⎬ Must provide logger to the c'tor.
506         LOG_ERROR = 0x08         #  ⎪
507         LOG_CRITICAL = 0x10      #  ⎭
508         FILENAMES = 0x20         # Must provide a filename to the c'tor.
509         FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
510         HLOG = 0x80
511         ALL_LOG_DESTINATIONS = (
512             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
513         )
514         ALL_OUTPUT_DESTINATIONS = 0x8F
515
516     def __init__(self,
517                  destination_bitv: int,
518                  *,
519                  logger=None,
520                  filenames: Optional[Iterable[str]] = None,
521                  handles: Optional[Iterable[io.TextIOWrapper]] = None):
522         if logger is None:
523             logger = logging.getLogger(None)
524         self.logger = logger
525
526         if filenames is not None:
527             self.f = [
528                 open(filename, 'wb', buffering=0) for filename in filenames
529             ]
530         else:
531             if destination_bitv & OutputMultiplexer.FILENAMES:
532                 raise ValueError(
533                     "Filenames argument is required if bitv & FILENAMES"
534                 )
535             self.f = None
536
537         if handles is not None:
538             self.h = [handle for handle in handles]
539         else:
540             if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES:
541                 raise ValueError(
542                     "Handle argument is required if bitv & FILEHANDLES"
543                 )
544             self.h = None
545
546         self.set_destination_bitv(destination_bitv)
547
548     def get_destination_bitv(self):
549         return self.destination_bitv
550
551     def set_destination_bitv(self, destination_bitv: int):
552         if destination_bitv & self.Destination.FILENAMES and self.f is None:
553             raise ValueError(
554                 "Filename argument is required if bitv & FILENAMES"
555             )
556         if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
557             raise ValueError(
558                     "Handle argument is required if bitv & FILEHANDLES"
559                 )
560         self.destination_bitv = destination_bitv
561
562     def print(self, *args, **kwargs):
563         from string_utils import sprintf, strip_escape_sequences
564         end = kwargs.pop("end", None)
565         if end is not None:
566             if not isinstance(end, str):
567                 raise TypeError("end must be None or a string")
568         sep = kwargs.pop("sep", None)
569         if sep is not None:
570             if not isinstance(sep, str):
571                 raise TypeError("sep must be None or a string")
572         if kwargs:
573             raise TypeError("invalid keyword arguments to print()")
574         buf = sprintf(*args, end="", sep=sep)
575         if sep is None:
576             sep = " "
577         if end is None:
578             end = "\n"
579         if end == '\n':
580             buf += '\n'
581         if (
582                 self.destination_bitv & self.Destination.FILENAMES and
583                 self.f is not None
584         ):
585             for _ in self.f:
586                 _.write(buf.encode('utf-8'))
587                 _.flush()
588
589         if (
590                 self.destination_bitv & self.Destination.FILEHANDLES and
591                 self.h is not None
592         ):
593             for _ in self.h:
594                 _.write(buf)
595                 _.flush()
596
597         buf = strip_escape_sequences(buf)
598         if self.logger is not None:
599             if self.destination_bitv & self.Destination.LOG_DEBUG:
600                 self.logger.debug(buf)
601             if self.destination_bitv & self.Destination.LOG_INFO:
602                 self.logger.info(buf)
603             if self.destination_bitv & self.Destination.LOG_WARNING:
604                 self.logger.warning(buf)
605             if self.destination_bitv & self.Destination.LOG_ERROR:
606                 self.logger.error(buf)
607             if self.destination_bitv & self.Destination.LOG_CRITICAL:
608                 self.logger.critical(buf)
609         if self.destination_bitv & self.Destination.HLOG:
610             hlog(buf)
611
612     def close(self):
613         if self.f is not None:
614             for _ in self.f:
615                 _.close()
616
617
618 class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
619     """
620     A context that uses an OutputMultiplexer.  e.g.
621
622         with OutputMultiplexerContext(
623                 OutputMultiplexer.LOG_INFO |
624                 OutputMultiplexer.LOG_DEBUG |
625                 OutputMultiplexer.FILENAMES |
626                 OutputMultiplexer.FILEHANDLES,
627                 filenames = [ '/tmp/foo.log', '/var/log/bar.log' ],
628                 handles = [ f, g ]
629             ) as mplex:
630                 mplex.print("This is a log message!")
631
632     """
633     def __init__(self,
634                  destination_bitv: OutputMultiplexer.Destination,
635                  *,
636                  logger = None,
637                  filenames = None,
638                  handles = None):
639         super().__init__(
640             destination_bitv,
641             logger=logger,
642             filenames=filenames,
643             handles=handles)
644
645     def __enter__(self):
646         return self
647
648     def __exit__(self, etype, value, traceback) -> bool:
649         super().close()
650         if etype is not None:
651             return False
652         return True
653
654
655 def hlog(message: str) -> None:
656     """Write a message to the house log."""
657
658     message = message.replace("'", "'\"'\"'")
659     os.system(f"/usr/bin/logger -p local7.info -- '{message}'")
660
661
662 if __name__ == '__main__':
663     import doctest
664     doctest.testmod()