8ec47d2e224e98c6dd0749c36dce7f79553d35b4
[pyutils.git] / src / pyutils / math_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """Helper utilities with a mathematical / statictical focus."""
6
7 import collections
8 import functools
9 import math
10 from heapq import heappop, heappush
11 from typing import Dict, List, Optional, Tuple
12
13 from pyutils import dict_utils
14 from pyutils.typez.typing import Numeric
15
16
17 class NumericPopulation(object):
18     """This object *store* a numeric population in a way that enables relatively
19     fast addition of new numbers (:math:`O(2log_2 n)`) and instant access to the
20     median value in the population (:math:`O(1)`).  It also provides other population
21     summary statistics such as the :meth:`get_mode`, :meth:`get_percentile` and
22     :meth:`get_stdev`.
23
24     .. note::
25
26         Because this class stores a copy of all numbers added to it, it shouldn't
27         be used for very large populations.  Consider sampling.
28
29     >>> pop = NumericPopulation()
30     >>> pop.add_number(1)
31     >>> pop.add_number(10)
32     >>> pop.add_number(3)
33     >>> len(pop)
34     3
35     >>> pop.get_median()
36     3
37     >>> pop.add_number(7)
38     >>> pop.add_number(5)
39     >>> pop.get_median()
40     5
41     >>> pop.get_mean()
42     5.2
43     >>> round(pop.get_stdev(), 1)
44     3.1
45     >>> pop.get_percentile(20)
46     3
47     >>> pop.get_percentile(60)
48     7
49     """
50
51     def __init__(self):
52         self.lowers, self.highers = [], []
53         self.aggregate = 0.0
54         self.sorted_copy: Optional[List[Numeric]] = None
55         self.maximum = None
56         self.minimum = None
57
58     def add_number(self, number: Numeric):
59         """Adds a number to the population.  Runtime complexity of this
60         operation is :math:`O(2 log_2 n)`
61
62         Args:
63             number: the number to add_number to the population
64         """
65
66         if not self.highers or number > self.highers[0]:
67             heappush(self.highers, number)
68         else:
69             heappush(self.lowers, -number)  # for lowers we need a max heap
70         self.aggregate += number
71         self._rebalance()
72         if not self.maximum or number > self.maximum:
73             self.maximum = number
74         if not self.minimum or number < self.minimum:
75             self.minimum = number
76
77     def __len__(self):
78         """
79         Returns:
80             the population's current size.
81         """
82         n = 0
83         if self.highers:
84             n += len(self.highers)
85         if self.lowers:
86             n += len(self.lowers)
87         return n
88
89     def _rebalance(self):
90         """Internal helper for rebalancing the `lowers` and `highers` heaps"""
91         if len(self.lowers) - len(self.highers) > 1:
92             heappush(self.highers, -heappop(self.lowers))
93         elif len(self.highers) - len(self.lowers) > 1:
94             heappush(self.lowers, -heappop(self.highers))
95
96     def get_median(self) -> Numeric:
97         """
98         Returns:
99             The median (p50) of the current population in :math:`O(1)` time.
100         """
101         if len(self.lowers) == len(self.highers):
102             return -self.lowers[0]
103         elif len(self.lowers) > len(self.highers):
104             return -self.lowers[0]
105         else:
106             return self.highers[0]
107
108     def get_mean(self) -> float:
109         """
110         Returns:
111             The mean (arithmetic mean) so far in :math:`O(1)` time.
112         """
113         count = len(self)
114         return self.aggregate / count
115
116     def get_mode(self) -> Tuple[Numeric, int]:
117         """
118         Returns:
119             The population mode (most common member in the population)
120             in :math:`O(n)` time.
121         """
122         count: Dict[Numeric, int] = collections.defaultdict(int)
123         for n in self.lowers:
124             count[-n] += 1
125         for n in self.highers:
126             count[n] += 1
127         return dict_utils.item_with_max_value(count)  # type: ignore
128
129     def get_stdev(self) -> float:
130         """
131         Returns:
132             The stdev of the current population in :math:`O(n)` time.
133         """
134         mean = self.get_mean()
135         variance = 0.0
136         for n in self.lowers:
137             n = -n
138             variance += (n - mean) ** 2
139         for n in self.highers:
140             variance += (n - mean) ** 2
141         count = len(self.lowers) + len(self.highers)
142         return math.sqrt(variance / count)
143
144     def _create_sorted_copy_if_needed(self, count: int):
145         """Internal helper."""
146         if not self.sorted_copy or count != len(self.sorted_copy):
147             self.sorted_copy = []
148             for x in self.lowers:
149                 self.sorted_copy.append(-x)
150             for x in self.highers:
151                 self.sorted_copy.append(x)
152             self.sorted_copy = sorted(self.sorted_copy)
153
154     def get_percentile(self, n: float) -> Numeric:
155         """
156         Returns: the number at approximately pn% in the population
157         (i.e. the nth percentile) in :math:`O(n log_2 n)` time (it
158         performs a full sort).  This is not the most efficient
159         algorithm.
160
161         Not thread-safe; does caching across multiple calls without
162         an invocation to :meth:`add_number` for perf reasons.
163
164         Args:
165             n: the percentile to compute
166         """
167         if n == 50:
168             return self.get_median()
169         count = len(self)
170         self._create_sorted_copy_if_needed(count)
171         assert self.sorted_copy
172         index = round(count * (n / 100.0))
173         index = max(0, index)
174         index = min(count - 1, index)
175         return self.sorted_copy[index]
176
177
178 def gcd_floats(a: float, b: float) -> float:
179     """
180     Returns:
181         The greatest common divisor of a and b.
182
183     Args:
184         a: first operand
185         b: second operatnd
186     """
187     if a < b:
188         return gcd_floats(b, a)
189
190     # base case
191     if abs(b) < 0.001:
192         return a
193     return gcd_floats(b, a - math.floor(a / b) * b)
194
195
196 def gcd_float_sequence(lst: List[float]) -> float:
197     """
198     Returns:
199         The greatest common divisor of a list of floats.
200
201     Args:
202         lst: a list of operands
203     """
204     if len(lst) <= 0:
205         raise ValueError("Need at least one number")
206     if len(lst) == 1:
207         return lst[0]
208     assert len(lst) >= 2
209     gcd = gcd_floats(lst[0], lst[1])
210     for i in range(2, len(lst)):
211         gcd = gcd_floats(gcd, lst[i])
212     return gcd
213
214
215 def truncate_float(n: float, decimals: int = 2):
216     """
217     Returns:
218         A truncated float to a particular number of decimals.
219
220     Args:
221         n: the float to truncate
222         decimals: how many decimal places are desired?
223
224     >>> truncate_float(3.1415927, 3)
225     3.141
226     """
227     assert 0 < decimals < 10
228     multiplier = 10**decimals
229     return int(n * multiplier) / multiplier
230
231
232 def percentage_to_multiplier(percent: float) -> float:
233     """Given a percentage that represents a return or percent change
234     (e.g. 155%), determine the factor (i.e.  multiplier) needed to
235     scale a number by that percentage (e.g. 2.55x)
236
237     Args:
238         percent: the return percent to scale by
239
240     >>> percentage_to_multiplier(155)
241     2.55
242     >>> percentage_to_multiplier(45)
243     1.45
244     >>> percentage_to_multiplier(-25)
245     0.75
246
247     """
248     multiplier = percent / 100
249     multiplier += 1.0
250     return multiplier
251
252
253 def multiplier_to_percent(multiplier: float) -> float:
254     """Convert a multiplicative factor into a percent change or return
255     percentage.
256
257     Args:
258         multiplier: the multiplier for which to compute the percent change
259
260     >>> multiplier_to_percent(0.75)
261     -25.0
262     >>> multiplier_to_percent(1.0)
263     0.0
264     >>> multiplier_to_percent(1.99)
265     99.0
266     """
267     percent = multiplier
268     if percent > 0.0:
269         percent -= 1.0
270     else:
271         percent = 1.0 - percent
272     percent *= 100.0
273     return percent
274
275
276 @functools.lru_cache(maxsize=1024, typed=True)
277 def is_prime(n: int) -> bool:
278     """
279     Args:
280         n: the number for which primeness is to be determined.
281
282     Returns:
283         True if n is prime and False otherwise.
284
285     .. note::
286
287          Obviously(?) very slow for very large input numbers until
288          we get quantum computers.
289
290     >>> is_prime(13)
291     True
292     >>> is_prime(22)
293     False
294     >>> is_prime(51602981)
295     True
296     """
297     if not isinstance(n, int):
298         raise TypeError("argument passed to is_prime is not of 'int' type")
299
300     # Corner cases
301     if n <= 1:
302         return False
303     if n <= 3:
304         return True
305
306     # This is checked so that we can skip middle five numbers in below
307     # loop
308     if n % 2 == 0 or n % 3 == 0:
309         return False
310
311     i = 5
312     while i * i <= n:
313         if n % i == 0 or n % (i + 2) == 0:
314             return False
315         i = i + 6
316     return True
317
318
319 if __name__ == "__main__":
320     import doctest
321
322     doctest.testmod()