034f90c0ee3ab932ffbf92740017fed369884b55
[python_utils.git] / logging_utils.py
1 #!/usr/bin/env python3
2
3 """Utilities related to logging."""
4
5 import collections
6 import contextlib
7 import datetime
8 import enum
9 import io
10 import logging
11 from logging.handlers import RotatingFileHandler, SysLogHandler
12 import os
13 import pytz
14 import sys
15 from typing import Iterable, Optional
16
17 # This module is commonly used by others in here and should avoid
18 # taking any unnecessary dependencies back on them.
19 import argparse_utils
20 import config
21
22 cfg = config.add_commandline_args(
23     f'Logging ({__file__})',
24     'Args related to logging')
25 cfg.add_argument(
26     '--logging_config_file',
27     type=argparse_utils.valid_filename,
28     default=None,
29     metavar='FILENAME',
30     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
31 )
32 cfg.add_argument(
33     '--logging_level',
34     type=str,
35     default='INFO',
36     choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
37     metavar='LEVEL',
38     help='The level below which to squelch log messages.',
39 )
40 cfg.add_argument(
41     '--logging_format',
42     type=str,
43     default='%(levelname).1s:%(asctime)s: %(message)s',
44     help='The format for lines logged via the logger module.'
45 )
46 cfg.add_argument(
47     '--logging_date_format',
48     type=str,
49     default='%Y/%m/%dT%H:%M:%S.%f%z',
50     metavar='DATEFMT',
51     help='The format of any dates in --logging_format.'
52 )
53 cfg.add_argument(
54     '--logging_console',
55     action=argparse_utils.ActionNoYes,
56     default=True,
57     help='Should we log to the console (stderr)',
58 )
59 cfg.add_argument(
60     '--logging_filename',
61     type=str,
62     default=None,
63     metavar='FILENAME',
64     help='The filename of the logfile to write.'
65 )
66 cfg.add_argument(
67     '--logging_filename_maxsize',
68     type=int,
69     default=(1024*1024),
70     metavar='#BYTES',
71     help='The maximum size (in bytes) to write to the logging_filename.'
72 )
73 cfg.add_argument(
74     '--logging_filename_count',
75     type=int,
76     default=2,
77     metavar='COUNT',
78     help='The number of logging_filename copies to keep before deleting.'
79 )
80 cfg.add_argument(
81     '--logging_syslog',
82     action=argparse_utils.ActionNoYes,
83     default=False,
84     help='Should we log to localhost\'s syslog.'
85 )
86 cfg.add_argument(
87     '--logging_debug_threads',
88     action=argparse_utils.ActionNoYes,
89     default=False,
90     help='Should we prepend pid/tid data to all log messages?'
91 )
92 cfg.add_argument(
93     '--logging_info_is_print',
94     action=argparse_utils.ActionNoYes,
95     default=False,
96     help='logging.info also prints to stdout.'
97 )
98 cfg.add_argument(
99     '--logging_max_n_times_per_message',
100     type=int,
101     default=0,
102     help='When set, ignore logged messages from the same site after N.'
103 )
104
105 # See also: OutputMultiplexer
106 cfg.add_argument(
107     '--logging_captures_prints',
108     action=argparse_utils.ActionNoYes,
109     default=False,
110     help='When calling print also log.info too'
111 )
112
113 built_in_print = print
114
115
116 class OnlyInfoFilter(logging.Filter):
117     """
118     A filter that only logs messages produced at the INFO logging level.
119     """
120     def filter(self, record):
121         return record.levelno == logging.INFO
122
123
124 class OnlyNTimesFilter(logging.Filter):
125     """
126     A filter that only logs messages from a given site with the same
127     message at the same logging level N times and ignores subsequent
128     attempts to log.
129
130     """
131     def __init__(self, maximum: int) -> None:
132         self.maximum = maximum
133         self.counters = collections.Counter()
134         super().__init__()
135
136     def filter(self, record: logging.LogRecord) -> bool:
137         source = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
138         count = self.counters[source]
139         self.counters[source] += 1
140         return count < self.maximum
141
142
143 class MillisecondAwareFormatter(logging.Formatter):
144     """
145     A formatter for adding milliseconds to log messages.
146
147     """
148     converter = datetime.datetime.fromtimestamp
149
150     def formatTime(self, record, datefmt=None):
151         ct = MillisecondAwareFormatter.converter(
152             record.created, pytz.timezone("US/Pacific")
153         )
154         if datefmt:
155             s = ct.strftime(datefmt)
156         else:
157             t = ct.strftime("%Y-%m-%d %H:%M:%S")
158             s = "%s,%03d" % (t, record.msecs)
159         return s
160
161
162 def initialize_logging(logger=None) -> logging.Logger:
163     assert config.has_been_parsed()
164     if logger is None:
165         logger = logging.getLogger()       # Root logger
166
167     if config.config['logging_config_file'] is not None:
168         logging.config.fileConfig('logging.conf')
169         return logger
170
171     handlers = []
172     numeric_level = getattr(
173         logging,
174         config.config['logging_level'].upper(),
175         None
176     )
177     if not isinstance(numeric_level, int):
178         raise ValueError('Invalid level: %s' % config.config['logging_level'])
179
180     fmt = config.config['logging_format']
181     if config.config['logging_debug_threads']:
182         fmt = f'%(process)d.%(thread)d|{fmt}'
183
184     if config.config['logging_syslog']:
185         if sys.platform not in ('win32', 'cygwin'):
186             handler = SysLogHandler()
187 #            for k, v in encoded_priorities.items():
188 #                handler.encodePriority(k, v)
189             handler.setFormatter(
190                 MillisecondAwareFormatter(
191                     fmt=fmt,
192                     datefmt=config.config['logging_date_format'],
193                 )
194             )
195             handler.setLevel(numeric_level)
196             handlers.append(handler)
197
198     if config.config['logging_filename']:
199         handler = RotatingFileHandler(
200             config.config['logging_filename'],
201             maxBytes = config.config['logging_filename_maxsize'],
202             backupCount = config.config['logging_filename_count'],
203         )
204         handler.setLevel(numeric_level)
205         handler.setFormatter(
206             MillisecondAwareFormatter(
207                 fmt=fmt,
208                 datefmt=config.config['logging_date_format'],
209             )
210         )
211         handlers.append(handler)
212
213     if config.config['logging_console']:
214         handler = logging.StreamHandler(sys.stderr)
215         handler.setLevel(numeric_level)
216         handler.setFormatter(
217             MillisecondAwareFormatter(
218                 fmt=fmt,
219                 datefmt=config.config['logging_date_format'],
220             )
221         )
222         handlers.append(handler)
223
224     if len(handlers) == 0:
225         handlers.append(logging.NullHandler())
226
227     for handler in handlers:
228         logger.addHandler(handler)
229
230     if config.config['logging_info_is_print']:
231         handler = logging.StreamHandler(sys.stdout)
232         handler.addFilter(OnlyInfoFilter())
233         logger.addHandler(handler)
234
235     maximum = config.config['logging_max_n_times_per_message']
236     if maximum > 0:
237         for handler in handlers:
238             handler.addFilter(OnlyNTimesFilter(maximum))
239
240     logger.setLevel(numeric_level)
241     logger.propagate = False
242
243     if config.config['logging_captures_prints']:
244         import builtins
245         global built_in_print
246
247         def print_and_also_log(*arg, **kwarg):
248             f = kwarg.get('file', None)
249             if f == sys.stderr:
250                 logger.warning(*arg)
251             else:
252                 logger.info(*arg)
253             built_in_print(*arg, **kwarg)
254         builtins.print = print_and_also_log
255
256     return logger
257
258
259 def get_logger(name: str = ""):
260     logger = logging.getLogger(name)
261     return initialize_logging(logger)
262
263
264 def tprint(*args, **kwargs) -> None:
265     if config.config['logging_debug_threads']:
266         from thread_utils import current_thread_id
267         print(f'{current_thread_id()}', end="")
268         print(*args, **kwargs)
269     else:
270         pass
271
272
273 def dprint(*args, **kwargs) -> None:
274     print(*args, file=sys.stderr, **kwargs)
275
276
277 class OutputMultiplexer(object):
278
279     class Destination(enum.IntEnum):
280         """Bits in the destination_bitv bitvector.  Used to indicate the
281         output destination."""
282         LOG_DEBUG = 0x01         # -\
283         LOG_INFO = 0x02          #  |
284         LOG_WARNING = 0x04       #   > Should provide logger to the c'tor.
285         LOG_ERROR = 0x08         #  |
286         LOG_CRITICAL = 0x10      # _/
287         FILENAMES = 0x20         # Must provide a filename to the c'tor.
288         FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
289         HLOG = 0x80
290         ALL_LOG_DESTINATIONS = (
291             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
292         )
293         ALL_OUTPUT_DESTINATIONS = 0x8F
294
295     def __init__(self,
296                  destination_bitv: int,
297                  *,
298                  logger=None,
299                  filenames: Optional[Iterable[str]] = None,
300                  handles: Optional[Iterable[io.TextIOWrapper]] = None):
301         if logger is None:
302             logger = logging.getLogger(None)
303         self.logger = logger
304
305         if filenames is not None:
306             self.f = [
307                 open(filename, 'wb', buffering=0) for filename in filenames
308             ]
309         else:
310             if destination_bitv & OutputMultiplexer.FILENAMES:
311                 raise ValueError(
312                     "Filenames argument is required if bitv & FILENAMES"
313                 )
314             self.f = None
315
316         if handles is not None:
317             self.h = [handle for handle in handles]
318         else:
319             if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES:
320                 raise ValueError(
321                     "Handle argument is required if bitv & FILEHANDLES"
322                 )
323             self.h = None
324
325         self.set_destination_bitv(destination_bitv)
326
327     def get_destination_bitv(self):
328         return self.destination_bitv
329
330     def set_destination_bitv(self, destination_bitv: int):
331         if destination_bitv & self.Destination.FILENAMES and self.f is None:
332             raise ValueError(
333                 "Filename argument is required if bitv & FILENAMES"
334             )
335         if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
336             raise ValueError(
337                     "Handle argument is required if bitv & FILEHANDLES"
338                 )
339         self.destination_bitv = destination_bitv
340
341     def print(self, *args, **kwargs):
342         from string_utils import sprintf, strip_escape_sequences
343         end = kwargs.pop("end", None)
344         if end is not None:
345             if not isinstance(end, str):
346                 raise TypeError("end must be None or a string")
347         sep = kwargs.pop("sep", None)
348         if sep is not None:
349             if not isinstance(sep, str):
350                 raise TypeError("sep must be None or a string")
351         if kwargs:
352             raise TypeError("invalid keyword arguments to print()")
353         buf = sprintf(*args, end="", sep=sep)
354         if sep is None:
355             sep = " "
356         if end is None:
357             end = "\n"
358         if end == '\n':
359             buf += '\n'
360         if (
361                 self.destination_bitv & self.Destination.FILENAMES and
362                 self.f is not None
363         ):
364             for _ in self.f:
365                 _.write(buf.encode('utf-8'))
366                 _.flush()
367
368         if (
369                 self.destination_bitv & self.Destination.FILEHANDLES and
370                 self.h is not None
371         ):
372             for _ in self.h:
373                 _.write(buf)
374                 _.flush()
375
376         buf = strip_escape_sequences(buf)
377         if self.logger is not None:
378             if self.destination_bitv & self.Destination.LOG_DEBUG:
379                 self.logger.debug(buf)
380             if self.destination_bitv & self.Destination.LOG_INFO:
381                 self.logger.info(buf)
382             if self.destination_bitv & self.Destination.LOG_WARNING:
383                 self.logger.warning(buf)
384             if self.destination_bitv & self.Destination.LOG_ERROR:
385                 self.logger.error(buf)
386             if self.destination_bitv & self.Destination.LOG_CRITICAL:
387                 self.logger.critical(buf)
388         if self.destination_bitv & self.Destination.HLOG:
389             hlog(buf)
390
391     def close(self):
392         if self.f is not None:
393             for _ in self.f:
394                 _.close()
395
396
397 class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
398     def __init__(self,
399                  destination_bitv: OutputMultiplexer.Destination,
400                  *,
401                  logger = None,
402                  filenames = None,
403                  handles = None):
404         super().__init__(
405             destination_bitv,
406             logger=logger,
407             filenames=filenames,
408             handles=handles)
409
410     def __enter__(self):
411         return self
412
413     def __exit__(self, etype, value, traceback) -> bool:
414         super().close()
415         if etype is not None:
416             return False
417         return True
418
419
420 def hlog(message: str) -> None:
421     message = message.replace("'", "'\"'\"'")
422     os.system(f"/usr/bin/logger -p local7.info -- '{message}'")