Add some useful stats to histogram.
[python_utils.git] / histogram.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3
4 """A text-based simple histogram helper class."""
5
6 import math
7 from dataclasses import dataclass
8 from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar
9
10 T = TypeVar("T", int, float)
11 Bound = int
12 Count = int
13
14
15 @dataclass
16 class BucketDetails:
17     """A collection of details about the internal histogram buckets."""
18
19     num_populated_buckets: int = 0
20     max_population: Optional[int] = None
21     last_bucket_start: Optional[int] = None
22     lowest_start: Optional[int] = None
23     highest_end: Optional[int] = None
24     max_label_width: Optional[int] = None
25
26
27 class SimpleHistogram(Generic[T]):
28     """A simple histogram."""
29
30     # Useful in defining wide open bottom/top bucket bounds:
31     POSITIVE_INFINITY = math.inf
32     NEGATIVE_INFINITY = -math.inf
33
34     def __init__(self, buckets: List[Tuple[Bound, Bound]]):
35         from math_utils import RunningMedian
36
37         self.buckets: Dict[Tuple[Bound, Bound], Count] = {}
38         for start_end in buckets:
39             if self._get_bucket(start_end[0]) is not None:
40                 raise Exception("Buckets overlap?!")
41             self.buckets[start_end] = 0
42         self.sigma: float = 0.0
43         self.stats: RunningMedian = RunningMedian()
44         self.maximum: Optional[T] = None
45         self.minimum: Optional[T] = None
46         self.count: Count = 0
47
48     @staticmethod
49     def n_evenly_spaced_buckets(
50         min_bound: T,
51         max_bound: T,
52         n: int,
53     ) -> List[Tuple[int, int]]:
54         ret: List[Tuple[int, int]] = []
55         stride = int((max_bound - min_bound) / n)
56         if stride <= 0:
57             raise Exception("Min must be < Max")
58         imax = math.ceil(max_bound)
59         imin = math.floor(min_bound)
60         for bucket_start in range(imin, imax, stride):
61             ret.append((bucket_start, bucket_start + stride))
62         return ret
63
64     def _get_bucket(self, item: T) -> Optional[Tuple[int, int]]:
65         for start_end in self.buckets:
66             if start_end[0] <= item < start_end[1]:
67                 return start_end
68         return None
69
70     def add_item(self, item: T) -> bool:
71         bucket = self._get_bucket(item)
72         if bucket is None:
73             return False
74         self.count += 1
75         self.buckets[bucket] += 1
76         self.sigma += item
77         self.stats.add_number(item)
78         if self.maximum is None or item > self.maximum:
79             self.maximum = item
80         if self.minimum is None or item < self.minimum:
81             self.minimum = item
82         return True
83
84     def add_items(self, lst: Iterable[T]) -> bool:
85         all_true = True
86         for item in lst:
87             all_true = all_true and self.add_item(item)
88         return all_true
89
90     def get_bucket_details(self, label_formatter: str) -> BucketDetails:
91         details = BucketDetails()
92         for (start, end), pop in sorted(self.buckets.items(), key=lambda x: x[0]):
93             if pop > 0:
94                 details.num_populated_buckets += 1
95                 details.last_bucket_start = start
96                 if details.max_population is None or pop > details.max_population:
97                     details.max_population = pop
98                 if details.lowest_start is None or start < details.lowest_start:
99                     details.lowest_start = start
100                 if details.highest_end is None or end > details.highest_end:
101                     details.highest_end = end
102                 label = f'[{label_formatter}..{label_formatter}): ' % (start, end)
103                 label_width = len(label)
104                 if details.max_label_width is None or label_width > details.max_label_width:
105                     details.max_label_width = label_width
106         return details
107
108     def __repr__(self, *, width: int = 80, label_formatter: str = '%d') -> str:
109         from text_utils import bar_graph
110
111         details = self.get_bucket_details(label_formatter)
112         txt = ""
113         if details.num_populated_buckets == 0:
114             return txt
115         assert details.max_label_width is not None
116         assert details.lowest_start is not None
117         assert details.highest_end is not None
118         assert details.max_population is not None
119         sigma_label = f'[{label_formatter}..{label_formatter}): ' % (
120             details.lowest_start,
121             details.highest_end,
122         )
123         if len(sigma_label) > details.max_label_width:
124             details.max_label_width = len(sigma_label)
125         bar_width = width - (details.max_label_width + 16)
126
127         for (start, end), pop in sorted(self.buckets.items(), key=lambda x: x[0]):
128             label = f'[{label_formatter}..{label_formatter}): ' % (start, end)
129             bar = bar_graph(
130                 (pop / details.max_population),
131                 include_text=False,
132                 width=bar_width,
133                 left_end="",
134                 right_end="",
135             )
136             txt += label.rjust(details.max_label_width)
137             txt += bar
138             txt += f"({pop/self.count*100.0:5.2f}% n={pop})\n"
139             if start == details.last_bucket_start:
140                 break
141         txt += '-' * width + '\n'
142         txt += sigma_label.rjust(details.max_label_width)
143         txt += ' ' * (bar_width - 2)
144         txt += f'Σ=(100.00% n={self.count})\n'
145         txt += ' ' * (bar_width + details.max_label_width - 2)
146         txt += f'mean(μ)={self.stats.get_mean():.3f}\n'
147         txt += ' ' * (bar_width + details.max_label_width - 2)
148         txt += f'p50(η)={self.stats.get_median():.3f}\n'
149         txt += ' ' * (bar_width + details.max_label_width - 2)
150         txt += f'stdev(σ)={self.stats.get_stdev():.3f}\n'
151         txt += '\n'
152         return txt