Update docs.
[python_utils.git] / math_utils.py
index 188d3234986c5f6dd5b0474512402c9adc98cbf1..ed9c2f450f0fb0a9b121069a19d4a0b1fe7acfba 100644 (file)
@@ -1,15 +1,21 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
 """Mathematical helpers."""
 
+import collections
 import functools
 import math
 from heapq import heappop, heappush
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple
+
+import dict_utils
 
 
 class NumericPopulation(object):
-    """A running median computer.
+    """A numeric population with some statistics such as median, mean, pN,
+    stdev, etc...
 
     >>> pop = NumericPopulation()
     >>> pop.add_number(1)
@@ -24,7 +30,7 @@ class NumericPopulation(object):
     >>> pop.get_mean()
     5.2
     >>> round(pop.get_stdev(), 2)
-    6.99
+    1.75
     >>> pop.get_percentile(20)
     3
     >>> pop.get_percentile(60)
@@ -35,9 +41,12 @@ class NumericPopulation(object):
         self.lowers, self.highers = [], []
         self.aggregate = 0.0
         self.sorted_copy: Optional[List[float]] = None
+        self.maximum = None
+        self.minimum = None
 
     def add_number(self, number: float):
-        """O(2 log2 n)"""
+        """Adds a number to the population.  Runtime complexity of this
+        operation is :math:`O(2 log_2 n)`"""
 
         if not self.highers or number > self.highers[0]:
             heappush(self.highers, number)
@@ -45,6 +54,10 @@ class NumericPopulation(object):
             heappush(self.lowers, -number)  # for lowers we need a max heap
         self.aggregate += number
         self._rebalance()
+        if not self.maximum or number > self.maximum:
+            self.maximum = number
+        if not self.minimum or number < self.minimum:
+            self.minimum = number
 
     def _rebalance(self):
         if len(self.lowers) - len(self.highers) > 1:
@@ -68,6 +81,17 @@ class NumericPopulation(object):
         count = len(self.lowers) + len(self.highers)
         return self.aggregate / count
 
+    def get_mode(self) -> Tuple[float, int]:
+        """Returns the mode (most common member in the population)
+        in O(n) time."""
+
+        count: Dict[float, int] = collections.defaultdict(int)
+        for n in self.lowers:
+            count[-n] += 1
+        for n in self.highers:
+            count[n] += 1
+        return dict_utils.item_with_max_value(count)
+
     def get_stdev(self) -> float:
         """Returns the stdev so far in O(n) time."""
 
@@ -78,33 +102,37 @@ class NumericPopulation(object):
             variance += (n - mean) ** 2
         for n in self.highers:
             variance += (n - mean) ** 2
-        return math.sqrt(variance)
+        count = len(self.lowers) + len(self.highers)
+        return math.sqrt(variance) / count
+
+    def _create_sorted_copy_if_needed(self, count: int):
+        if not self.sorted_copy or count != len(self.sorted_copy):
+            self.sorted_copy = []
+            for x in self.lowers:
+                self.sorted_copy.append(-x)
+            for x in self.highers:
+                self.sorted_copy.append(x)
+            self.sorted_copy = sorted(self.sorted_copy)
 
     def get_percentile(self, n: float) -> float:
         """Returns the number at approximately pn% (i.e. the nth percentile)
-        of the distribution in O(n log n) time (expensive, requires a
-        complete sort).  Not thread safe.  Caching does across
-        multiple calls without an invocation to add_number.
-
+        of the distribution in O(n log n) time.  Not thread-safe;
+        does caching across multiple calls without an invocation to
+        add_number for perf reasons.
         """
         if n == 50:
             return self.get_median()
         count = len(self.lowers) + len(self.highers)
-        if self.sorted_copy is not None:
-            if count == len(self.sorted_copy):
-                index = round(count * (n / 100.0))
-                assert 0 <= index < count
-                return self.sorted_copy[index]
-        self.sorted_copy = [-x for x in self.lowers]
-        for x in self.highers:
-            self.sorted_copy.append(x)
-        self.sorted_copy = sorted(self.sorted_copy)
+        self._create_sorted_copy_if_needed(count)
+        assert self.sorted_copy
         index = round(count * (n / 100.0))
-        assert 0 <= index < count
+        index = max(0, index)
+        index = min(count - 1, index)
         return self.sorted_copy[index]
 
 
 def gcd_floats(a: float, b: float) -> float:
+    """Returns the greatest common divisor of a and b."""
     if a < b:
         return gcd_floats(b, a)
 
@@ -115,6 +143,7 @@ def gcd_floats(a: float, b: float) -> float:
 
 
 def gcd_float_sequence(lst: List[float]) -> float:
+    """Returns the greatest common divisor of a list of floats."""
     if len(lst) <= 0:
         raise ValueError("Need at least one number")
     elif len(lst) == 1:
@@ -127,8 +156,7 @@ def gcd_float_sequence(lst: List[float]) -> float:
 
 
 def truncate_float(n: float, decimals: int = 2):
-    """
-    Truncate a float to a particular number of decimals.
+    """Truncate a float to a particular number of decimals.
 
     >>> truncate_float(3.1415927, 3)
     3.141
@@ -149,7 +177,6 @@ def percentage_to_multiplier(percent: float) -> float:
     1.45
     >>> percentage_to_multiplier(-25)
     0.75
-
     """
     multiplier = percent / 100
     multiplier += 1.0
@@ -165,7 +192,6 @@ def multiplier_to_percent(multiplier: float) -> float:
     0.0
     >>> multiplier_to_percent(1.99)
     99.0
-
     """
     percent = multiplier
     if percent > 0.0:
@@ -188,7 +214,6 @@ def is_prime(n: int) -> bool:
     False
     >>> is_prime(51602981)
     True
-
     """
     if not isinstance(n, int):
         raise TypeError("argument passed to is_prime is not of 'int' type")