Update docs.
[python_utils.git] / math_utils.py
index 3216d4a9222f3e9760d2f5276b40503e08cfee8f..ed9c2f450f0fb0a9b121069a19d4a0b1fe7acfba 100644 (file)
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
+"""Mathematical helpers."""
+
+import collections
 import functools
 import math
-from typing import List
-from heapq import heappush, heappop
+from heapq import heappop, heappush
+from typing import Dict, List, Optional, Tuple
 
+import dict_utils
 
-class RunningMedian(object):
-    """A running median computer.
 
-    >>> median = RunningMedian()
-    >>> median.add_number(1)
-    >>> median.add_number(10)
-    >>> median.add_number(3)
-    >>> median.get_median()
+class NumericPopulation(object):
+    """A numeric population with some statistics such as median, mean, pN,
+    stdev, etc...
+
+    >>> pop = NumericPopulation()
+    >>> pop.add_number(1)
+    >>> pop.add_number(10)
+    >>> pop.add_number(3)
+    >>> pop.get_median()
     3
-    >>> median.add_number(7)
-    >>> median.add_number(5)
-    >>> median.get_median()
+    >>> pop.add_number(7)
+    >>> pop.add_number(5)
+    >>> pop.get_median()
     5
+    >>> pop.get_mean()
+    5.2
+    >>> round(pop.get_stdev(), 2)
+    1.75
+    >>> pop.get_percentile(20)
+    3
+    >>> pop.get_percentile(60)
+    7
     """
 
     def __init__(self):
         self.lowers, self.highers = [], []
+        self.aggregate = 0.0
+        self.sorted_copy: Optional[List[float]] = None
+        self.maximum = None
+        self.minimum = None
+
+    def add_number(self, number: float):
+        """Adds a number to the population.  Runtime complexity of this
+        operation is :math:`O(2 log_2 n)`"""
 
-    def add_number(self, number):
         if not self.highers or number > self.highers[0]:
             heappush(self.highers, number)
         else:
             heappush(self.lowers, -number)  # for lowers we need a max heap
-        self.rebalance()
-
-    def rebalance(self):
+        self.aggregate += number
+        self._rebalance()
+        if not self.maximum or number > self.maximum:
+            self.maximum = number
+        if not self.minimum or number < self.minimum:
+            self.minimum = number
+
+    def _rebalance(self):
         if len(self.lowers) - len(self.highers) > 1:
             heappush(self.highers, -heappop(self.lowers))
         elif len(self.highers) - len(self.lowers) > 1:
             heappush(self.lowers, -heappop(self.highers))
 
-    def get_median(self):
+    def get_median(self) -> float:
+        """Returns the approximate median (p50) so far in O(1) time."""
+
         if len(self.lowers) == len(self.highers):
-            return (-self.lowers[0] + self.highers[0]) / 2
+            return -self.lowers[0]
         elif len(self.lowers) > len(self.highers):
             return -self.lowers[0]
         else:
             return self.highers[0]
 
+    def get_mean(self) -> float:
+        """Returns the mean (arithmetic mean) so far in O(1) time."""
+
+        count = len(self.lowers) + len(self.highers)
+        return self.aggregate / count
+
+    def get_mode(self) -> Tuple[float, int]:
+        """Returns the mode (most common member in the population)
+        in O(n) time."""
+
+        count: Dict[float, int] = collections.defaultdict(int)
+        for n in self.lowers:
+            count[-n] += 1
+        for n in self.highers:
+            count[n] += 1
+        return dict_utils.item_with_max_value(count)
+
+    def get_stdev(self) -> float:
+        """Returns the stdev so far in O(n) time."""
+
+        mean = self.get_mean()
+        variance = 0.0
+        for n in self.lowers:
+            n = -n
+            variance += (n - mean) ** 2
+        for n in self.highers:
+            variance += (n - mean) ** 2
+        count = len(self.lowers) + len(self.highers)
+        return math.sqrt(variance) / count
+
+    def _create_sorted_copy_if_needed(self, count: int):
+        if not self.sorted_copy or count != len(self.sorted_copy):
+            self.sorted_copy = []
+            for x in self.lowers:
+                self.sorted_copy.append(-x)
+            for x in self.highers:
+                self.sorted_copy.append(x)
+            self.sorted_copy = sorted(self.sorted_copy)
+
+    def get_percentile(self, n: float) -> float:
+        """Returns the number at approximately pn% (i.e. the nth percentile)
+        of the distribution in O(n log n) time.  Not thread-safe;
+        does caching across multiple calls without an invocation to
+        add_number for perf reasons.
+        """
+        if n == 50:
+            return self.get_median()
+        count = len(self.lowers) + len(self.highers)
+        self._create_sorted_copy_if_needed(count)
+        assert self.sorted_copy
+        index = round(count * (n / 100.0))
+        index = max(0, index)
+        index = min(count - 1, index)
+        return self.sorted_copy[index]
+
 
 def gcd_floats(a: float, b: float) -> float:
+    """Returns the greatest common divisor of a and b."""
     if a < b:
         return gcd_floats(b, a)
 
@@ -57,6 +143,7 @@ def gcd_floats(a: float, b: float) -> float:
 
 
 def gcd_float_sequence(lst: List[float]) -> float:
+    """Returns the greatest common divisor of a list of floats."""
     if len(lst) <= 0:
         raise ValueError("Need at least one number")
     elif len(lst) == 1:
@@ -69,15 +156,14 @@ def gcd_float_sequence(lst: List[float]) -> float:
 
 
 def truncate_float(n: float, decimals: int = 2):
-    """
-    Truncate a float to a particular number of decimals.
+    """Truncate a float to a particular number of decimals.
 
     >>> truncate_float(3.1415927, 3)
     3.141
 
     """
-    assert decimals > 0 and decimals < 10
-    multiplier = 10 ** decimals
+    assert 0 < decimals < 10
+    multiplier = 10**decimals
     return int(n * multiplier) / multiplier
 
 
@@ -91,7 +177,6 @@ def percentage_to_multiplier(percent: float) -> float:
     1.45
     >>> percentage_to_multiplier(-25)
     0.75
-
     """
     multiplier = percent / 100
     multiplier += 1.0
@@ -107,7 +192,6 @@ def multiplier_to_percent(multiplier: float) -> float:
     0.0
     >>> multiplier_to_percent(1.99)
     99.0
-
     """
     percent = multiplier
     if percent > 0.0:
@@ -130,7 +214,6 @@ def is_prime(n: int) -> bool:
     False
     >>> is_prime(51602981)
     True
-
     """
     if not isinstance(n, int):
         raise TypeError("argument passed to is_prime is not of 'int' type")