Improve documentation.
[pyutils.git] / src / pyutils / math_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Mathematical helpers."""
6
7 import collections
8 import functools
9 import math
10 from heapq import heappop, heappush
11 from typing import Dict, List, Optional, Tuple
12
13 from pyutils import dict_utils
14
15
16 class NumericPopulation(object):
17     """A numeric population with some statistics such as median, mean, pN,
18     stdev, etc...
19
20     >>> pop = NumericPopulation()
21     >>> pop.add_number(1)
22     >>> pop.add_number(10)
23     >>> pop.add_number(3)
24     >>> len(pop)
25     3
26     >>> pop.get_median()
27     3
28     >>> pop.add_number(7)
29     >>> pop.add_number(5)
30     >>> pop.get_median()
31     5
32     >>> pop.get_mean()
33     5.2
34     >>> round(pop.get_stdev(), 1)
35     1.4
36     >>> pop.get_percentile(20)
37     3
38     >>> pop.get_percentile(60)
39     7
40     """
41
42     def __init__(self):
43         self.lowers, self.highers = [], []
44         self.aggregate = 0.0
45         self.sorted_copy: Optional[List[float]] = None
46         self.maximum = None
47         self.minimum = None
48
49     def add_number(self, number: float):
50         """Adds a number to the population.  Runtime complexity of this
51         operation is :math:`O(2 log_2 n)`"""
52
53         if not self.highers or number > self.highers[0]:
54             heappush(self.highers, number)
55         else:
56             heappush(self.lowers, -number)  # for lowers we need a max heap
57         self.aggregate += number
58         self._rebalance()
59         if not self.maximum or number > self.maximum:
60             self.maximum = number
61         if not self.minimum or number < self.minimum:
62             self.minimum = number
63
64     def __len__(self):
65         """Return the population size."""
66         n = 0
67         if self.highers:
68             n += len(self.highers)
69         if self.lowers:
70             n += len(self.lowers)
71         return n
72
73     def _rebalance(self):
74         if len(self.lowers) - len(self.highers) > 1:
75             heappush(self.highers, -heappop(self.lowers))
76         elif len(self.highers) - len(self.lowers) > 1:
77             heappush(self.lowers, -heappop(self.highers))
78
79     def get_median(self) -> float:
80         """Returns the approximate median (p50) so far in :math:`O(1)` time."""
81
82         if len(self.lowers) == len(self.highers):
83             return -self.lowers[0]
84         elif len(self.lowers) > len(self.highers):
85             return -self.lowers[0]
86         else:
87             return self.highers[0]
88
89     def get_mean(self) -> float:
90         """Returns the mean (arithmetic mean) so far in :math:`O(1)` time."""
91
92         count = len(self)
93         return self.aggregate / count
94
95     def get_mode(self) -> Tuple[float, int]:
96         """Returns the mode (most common member in the population)
97         in :math:`O(n)` time."""
98
99         count: Dict[float, int] = collections.defaultdict(int)
100         for n in self.lowers:
101             count[-n] += 1
102         for n in self.highers:
103             count[n] += 1
104         return dict_utils.item_with_max_value(count)
105
106     def get_stdev(self) -> float:
107         """Returns the stdev so far in :math:`O(n)` time."""
108
109         mean = self.get_mean()
110         variance = 0.0
111         for n in self.lowers:
112             n = -n
113             variance += (n - mean) ** 2
114         for n in self.highers:
115             variance += (n - mean) ** 2
116         count = len(self.lowers) + len(self.highers)
117         return math.sqrt(variance) / count
118
119     def _create_sorted_copy_if_needed(self, count: int):
120         if not self.sorted_copy or count != len(self.sorted_copy):
121             self.sorted_copy = []
122             for x in self.lowers:
123                 self.sorted_copy.append(-x)
124             for x in self.highers:
125                 self.sorted_copy.append(x)
126             self.sorted_copy = sorted(self.sorted_copy)
127
128     def get_percentile(self, n: float) -> float:
129         """Returns the number at approximately pn% (i.e. the nth percentile)
130         of the distribution in :math:`O(n log_2 n)` time.  Not thread-safe;
131         does caching across multiple calls without an invocation to
132         add_number for perf reasons.
133         """
134         if n == 50:
135             return self.get_median()
136         count = len(self)
137         self._create_sorted_copy_if_needed(count)
138         assert self.sorted_copy
139         index = round(count * (n / 100.0))
140         index = max(0, index)
141         index = min(count - 1, index)
142         return self.sorted_copy[index]
143
144
145 def gcd_floats(a: float, b: float) -> float:
146     """Returns the greatest common divisor of a and b."""
147     if a < b:
148         return gcd_floats(b, a)
149
150     # base case
151     if abs(b) < 0.001:
152         return a
153     return gcd_floats(b, a - math.floor(a / b) * b)
154
155
156 def gcd_float_sequence(lst: List[float]) -> float:
157     """Returns the greatest common divisor of a list of floats."""
158     if len(lst) <= 0:
159         raise ValueError("Need at least one number")
160     elif len(lst) == 1:
161         return lst[0]
162     assert len(lst) >= 2
163     gcd = gcd_floats(lst[0], lst[1])
164     for i in range(2, len(lst)):
165         gcd = gcd_floats(gcd, lst[i])
166     return gcd
167
168
169 def truncate_float(n: float, decimals: int = 2):
170     """Truncate a float to a particular number of decimals.
171
172     >>> truncate_float(3.1415927, 3)
173     3.141
174
175     """
176     assert 0 < decimals < 10
177     multiplier = 10**decimals
178     return int(n * multiplier) / multiplier
179
180
181 def percentage_to_multiplier(percent: float) -> float:
182     """Given a percentage (e.g. 155%), return a factor needed to scale a
183     number by that percentage.
184
185     >>> percentage_to_multiplier(155)
186     2.55
187     >>> percentage_to_multiplier(45)
188     1.45
189     >>> percentage_to_multiplier(-25)
190     0.75
191     """
192     multiplier = percent / 100
193     multiplier += 1.0
194     return multiplier
195
196
197 def multiplier_to_percent(multiplier: float) -> float:
198     """Convert a multiplicative factor into a percent change.
199
200     >>> multiplier_to_percent(0.75)
201     -25.0
202     >>> multiplier_to_percent(1.0)
203     0.0
204     >>> multiplier_to_percent(1.99)
205     99.0
206     """
207     percent = multiplier
208     if percent > 0.0:
209         percent -= 1.0
210     else:
211         percent = 1.0 - percent
212     percent *= 100.0
213     return percent
214
215
216 @functools.lru_cache(maxsize=1024, typed=True)
217 def is_prime(n: int) -> bool:
218     """
219     Returns True if n is prime and False otherwise.  Obviously(?) very slow for
220     very large input numbers.
221
222     >>> is_prime(13)
223     True
224     >>> is_prime(22)
225     False
226     >>> is_prime(51602981)
227     True
228     """
229     if not isinstance(n, int):
230         raise TypeError("argument passed to is_prime is not of 'int' type")
231
232     # Corner cases
233     if n <= 1:
234         return False
235     if n <= 3:
236         return True
237
238     # This is checked so that we can skip middle five numbers in below
239     # loop
240     if n % 2 == 0 or n % 3 == 0:
241         return False
242
243     i = 5
244     while i * i <= n:
245         if n % i == 0 or n % (i + 2) == 0:
246             return False
247         i = i + 6
248     return True
249
250
251 if __name__ == '__main__':
252     import doctest
253
254     doctest.testmod()