Money, Rate, CentCount and a bunch of bugfixes.
authorScott Gasch <[email protected]>
Thu, 29 Jul 2021 05:13:43 +0000 (22:13 -0700)
committerScott Gasch <[email protected]>
Thu, 29 Jul 2021 05:13:43 +0000 (22:13 -0700)
18 files changed:
ansi.py
bootstrap.py
file_utils.py
google_assistant.py
histogram.py
list_utils.py
math_utils.py
stopwatch.py
string_utils.py
tests/ansi_test.py [new file with mode: 0755]
tests/centcount_test.py [new file with mode: 0755]
tests/money_test.py [new file with mode: 0755]
tests/rate_test.py [new file with mode: 0755]
tests/string_utils_test.py
type/centcount.py [new file with mode: 0644]
type/money.py [new file with mode: 0644]
type/rate.py [new file with mode: 0644]
unittest_utils.py

diff --git a/ansi.py b/ansi.py
index 4c580c0a0262ad133a71f6def6ca63a7d003dd8f..769b29c46b21c705f8c8ad09906f9dd7239cfaa2 100755 (executable)
--- a/ansi.py
+++ b/ansi.py
@@ -1,9 +1,12 @@
 #!/usr/bin/env python3
 
+from abc import abstractmethod
 import difflib
+import io
 import logging
+import re
 import sys
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 
 logger = logging.getLogger(__name__)
 
@@ -1846,6 +1849,37 @@ def bg(name: Optional[str] = "",
     return bg_24bit(red, green, blue)
 
 
+class StdoutInterceptor(io.TextIOBase):
+    def __init__(self):
+        self.saved_stdout: Optional[io.TextIOBase] = None
+        self.buf = ''
+
+    @abstractmethod
+    def write(self, s):
+        pass
+
+    def __enter__(self) -> None:
+        self.saved_stdout = sys.stdout
+        sys.stdout = self
+        return None
+
+    def __exit__(self, *args) -> bool:
+        sys.stdout = self.saved_stdout
+        print(self.buf)
+        return None
+
+
+class ProgrammableColorizer(StdoutInterceptor):
+    def __init__(self, patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]]):
+        super().__init__()
+        self.patterns = [_ for _ in patterns]
+
+    def write(self, s: str):
+        for pattern in self.patterns:
+            s = pattern[0].sub(pattern[1], s)
+        self.buf += s
+
+
 if __name__ == '__main__':
     def main() -> None:
         name = " ".join(sys.argv[1:])
index 3b03b3a691b5bb823a3f0548a616358eff93ee96..3489b8a7583ea8551fb1d78d1aa270250a1d7b32 100644 (file)
@@ -3,9 +3,7 @@
 import functools
 import logging
 import os
-import pdb
 import sys
-import traceback
 
 # This module is commonly used by others in here and should avoid
 # taking any unnecessary dependencies back on them.
@@ -25,21 +23,49 @@ args.add_argument(
     default=False,
     help='Break into pdb on top level unhandled exceptions.'
 )
+args.add_argument(
+    '--show_random_seed',
+    action=ActionNoYes,
+    default=False,
+    help='Should we display (and log.debug) the global random seed?'
+)
+args.add_argument(
+    '--set_random_seed',
+    type=int,
+    nargs=1,
+    default=None,
+    metavar='SEED_INT',
+    help='Override the global random seed with a particular number.'
+)
+
+original_hook = sys.excepthook
 
 
-def handle_uncaught_exception(
-        exc_type,
-        exc_value,
-        exc_traceback):
+def handle_uncaught_exception(exc_type, exc_value, exc_tb):
+    global original_hook
+    msg = f'Unhandled top level exception {exc_type}'
+    logger.exception(msg)
+    print(msg, file=sys.stderr)
     if issubclass(exc_type, KeyboardInterrupt):
-        sys.__excepthook__(exc_type, exc_value, exc_traceback)
+        sys.__excepthook__(exc_type, exc_value, exc_tb)
         return
-    logger.exception(f'Unhandled top level {exc_type}',
-                     exc_info=(exc_type, exc_value, exc_traceback))
-    traceback.print_exception(exc_type, exc_value, exc_traceback)
-    if config.config['debug_unhandled_exceptions']:
-        logger.info("Invoking the debugger...")
-        pdb.pm()
+    else:
+        if (
+                not sys.stderr.isatty() or
+                not sys.stdin.isatty()
+        ):
+            # stdin or stderr is redirected, just do the normal thing
+            original_hook(exc_type, exc_value, exc_tb)
+        else:
+            # a terminal is attached and stderr is not redirected, debug.
+            if config.config['debug_unhandled_exceptions']:
+                import traceback
+                import pdb
+                traceback.print_exception(exc_type, exc_value, exc_tb)
+                logger.info("Invoking the debugger...")
+                pdb.pm()
+            else:
+                original_hook(exc_type, exc_value, exc_tb)
 
 
 def initialize(entry_point):
@@ -47,7 +73,8 @@ def initialize(entry_point):
     """Remember to initialize config and logging before running main."""
     @functools.wraps(entry_point)
     def initialize_wrapper(*args, **kwargs):
-        sys.excepthook = handle_uncaught_exception
+        if sys.excepthook == sys.__excepthook__:
+            sys.excepthook = handle_uncaught_exception
         if (
                 '__globals__' in entry_point.__dict__ and
                 '__file__' in entry_point.__globals__
@@ -60,6 +87,21 @@ def initialize(entry_point):
 
         config.late_logging()
 
+        # Allow programs that don't bother to override the random seed
+        # to be replayed via the commandline.
+        import random
+        random_seed = config.config['set_random_seed']
+        if random_seed is not None:
+            random_seed = random_seed[0]
+        else:
+            random_seed = int.from_bytes(os.urandom(4), 'little')
+
+        if config.config['show_random_seed']:
+            msg = f'Global random seed is: {random_seed}'
+            print(msg)
+            logger.debug(msg)
+        random.seed(random_seed)
+
         logger.debug(f'Starting {entry_point.__name__} (program entry point)')
 
         ret = None
index 464b0e76cfba0ef4e80ba5343c24bf433584b9b5..525a1afb0e262e93082f91dc8860a932575ed27a 100644 (file)
@@ -7,11 +7,14 @@ import errno
 import hashlib
 import logging
 import os
+import io
 import pathlib
 import time
 from typing import Optional
 import glob
 from os.path import isfile, join, exists
+from uuid import uuid4
+
 
 logger = logging.getLogger(__name__)
 
@@ -249,3 +252,25 @@ def get_files_recursive(directory: str):
     for subdir in get_directories(directory):
         for file_or_directory in get_files_recursive(subdir):
             yield file_or_directory
+
+
+class FileWriter(object):
+    def __init__(self, filename: str) -> None:
+        self.filename = filename
+        uuid = uuid4()
+        self.tempfile = f'{filename}-{uuid}.tmp'
+        self.handle = None
+
+    def __enter__(self) -> io.TextIOWrapper:
+        assert not does_path_exist(self.tempfile)
+        self.handle = open(self.tempfile, mode="w")
+        return self.handle
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+        if self.handle is not None:
+            self.handle.close()
+            cmd = f'/bin/mv -f {self.tempfile} {self.filename}'
+            ret = os.system(cmd)
+            if (ret >> 8) != 0:
+                raise Exception(f'{cmd} failed, exit value {ret>>8}')
+        return None
index 71301e4779c2736a4a908df84b7ec7e67ba02b3f..a50003c7eb2a41e8326714ad24e4eccd2ec6cc34 100644 (file)
@@ -79,10 +79,15 @@ def ask_google(cmd: str, *, recognize_speech=True) -> GoogleResponse:
                     sample_rate=24000,
                     sample_width=2,
                 )
-                audio_transcription = recognizer.recognize_google(
-                    speech,
-                )
-                logger.debug(f"Transcription: '{audio_transcription}'")
+                try:
+                    audio_transcription = recognizer.recognize_google(
+                        speech,
+                    )
+                    logger.debug(f"Transcription: '{audio_transcription}'")
+                except sr.UnknownValueError as e:
+                    logger.exception(e)
+                    logger.warning('Unable to parse Google assistant\'s response.')
+                    audio_transcription = None
     else:
         logger.error(
             f'HTTP request to {url} with {payload} failed; code {r.status_code}'
index b98e8489c030de0b816c6815766a69c677599028..0368376434c3579014cf97974e3e6e381f494872 100644 (file)
@@ -66,7 +66,8 @@ class SimpleHistogram(Generic[T]):
             all_true = all_true and self.add_item(item)
         return all_true
 
-    def __repr__(self) -> str:
+    def __repr__(self,
+                 label_formatter='%10s') -> str:
         from text_utils import bar_graph
         max_population: Optional[int] = None
         for bucket in self.buckets:
@@ -86,17 +87,11 @@ class SimpleHistogram(Generic[T]):
             bar = bar_graph(
                 (pop / max_population),
                 include_text = False,
-                width = 70,
+                width = 58,
                 left_end = "",
                 right_end = "")
-            label = f'{start}..{end}'
-            txt += f'{label:12}: ' + bar + f"({pop}) ({len(bar)})\n"
+            label = f'{label_formatter}..{label_formatter}' % (start, end)
+            txt += f'{label:20}: ' + bar + f"({pop/self.count*100.0:5.2f}% n={pop})\n"
             if start == last_bucket_start:
                 break
-
-        txt = txt + f'''{self.count} item(s)
-{self.maximum} max
-{self.minimum} min
-{self.sigma/self.count:.3f} mean
-{self.median.get_median()} median'''
         return txt
index 74f1cf3078457d371194deb33ddf5ad6410ed599..7d3355cc85a72a047aacaa0c3f06430a9e8e8dd7 100644 (file)
@@ -21,3 +21,9 @@ def flatten(lst: List[Any]) -> List[Any]:
     if isinstance(lst[0], list):
         return flatten(lst[0]) + flatten(lst[1:])
     return lst[:1] + flatten(lst[1:])
+
+
+def prepend(item: Any, lst: List[Any]) -> List[Any]:
+    """Prepend an item to a list."""
+    lst = list.insert(0, item)
+    return lst
index 56fb7072366ab97621e032e9aed11d13d7740b5e..62771231bb67925483bcbf714fe2a8373b591058 100644 (file)
@@ -61,6 +61,22 @@ def truncate_float(n: float, decimals: int = 2):
     return int(n * multiplier) / multiplier
 
 
+def percentage_to_multiplier(percent: float) -> float:
+    multiplier = percent / 100
+    multiplier += 1.0
+    return multiplier
+
+
+def multiplier_to_percent(multiplier: float) -> float:
+    percent = multiplier
+    if percent > 0.0:
+        percent -= 1.0
+    else:
+        percent = 1.0 - percent
+    percent *= 100.0
+    return percent
+
+
 @functools.lru_cache(maxsize=1024, typed=True)
 def is_prime(n: int) -> bool:
     """Returns True if n is prime and False otherwise"""
index d54af8792d28eef4025c150cb1920fbd2113efee..1326cb1fec8ffb41db461f99aac59372ded655f1 100644 (file)
@@ -24,4 +24,4 @@ class Timer(object):
 
     def __exit__(self, *args) -> bool:
         self.end = time.perf_counter()
-        return True
+        return None  # don't suppress exceptions
index 911008d4c93bc50d6d78bb7d09d9d4aaaffdbcd5..6fc257de52c48f34e207e79e8b2227e914ad2b8c 100644 (file)
@@ -1,13 +1,15 @@
 #!/usr/bin/env python3
 
+import contextlib
 import datetime
+import io
 from itertools import zip_longest
 import json
 import logging
 import random
 import re
 import string
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional
 import unicodedata
 from uuid import uuid4
 
@@ -921,6 +923,22 @@ def sprintf(*args, **kwargs) -> str:
     return ret
 
 
+class SprintfStdout(object):
+    def __init__(self) -> None:
+        self.destination = io.StringIO()
+        self.recorder = None
+
+    def __enter__(self) -> Callable[[], str]:
+        self.recorder = contextlib.redirect_stdout(self.destination)
+        self.recorder.__enter__()
+        return lambda: self.destination.getvalue()
+
+    def __exit__(self, *args) -> None:
+        self.recorder.__exit__(*args)
+        self.destination.seek(0)
+        return None  # don't suppress exceptions
+
+
 def is_are(n: int) -> str:
     if n == 1:
         return "is"
diff --git a/tests/ansi_test.py b/tests/ansi_test.py
new file mode 100755 (executable)
index 0000000..4c1f449
--- /dev/null
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+
+import unittest
+
+import ansi
+import unittest_utils as uu
+
+
+class TestAnsi(unittest.TestCase):
+
+    def test_colorizer(self):
+        with ansi.Colorizer() as c:
+            print("testing...")
+            print("Section:")
+            print("  This is some detail.")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/centcount_test.py b/tests/centcount_test.py
new file mode 100755 (executable)
index 0000000..3122b98
--- /dev/null
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from type.centcount import CentCount
+import unittest_utils as uu
+
+
+class TestCentCount(unittest.TestCase):
+
+    def test_basic_utility(self):
+        amount = CentCount(1.45)
+        another = CentCount.parse("USD 1.45")
+        self.assertEqual(amount, another)
+
+    def test_negation(self):
+        amount = CentCount(1.45)
+        amount = -amount
+        self.assertEqual(CentCount(-1.45), amount)
+
+    def test_addition_and_subtraction(self):
+        amount = CentCount(1.00)
+        another = CentCount(2.00)
+        total = amount + another
+        self.assertEqual(CentCount(3.00), total)
+        delta = another - amount
+        self.assertEqual(CentCount(1.00), delta)
+        neg = amount - another
+        self.assertEqual(CentCount(-1.00), neg)
+        neg += another
+        self.assertEqual(CentCount(1.00), neg)
+        neg += 1.00
+        self.assertEqual(CentCount(2.00), neg)
+        neg -= 1.00
+        self.assertEqual(CentCount(1.00), neg)
+        x = 1000 - amount
+        self.assertEqual(CentCount(9.0), x)
+
+    def test_multiplication(self):
+        amount = CentCount(3.00)
+        amount *= 3
+        self.assertEqual(CentCount(9.00), amount)
+        with self.assertRaises(TypeError):
+            another = CentCount(0.33)
+            amount *= another
+
+    def test_division(self):
+        amount = CentCount(10.00)
+        x = amount / 5.0
+        self.assertEqual(CentCount(2.00), x)
+        with self.assertRaises(TypeError):
+            another = CentCount(1.33)
+            amount /= another
+
+    def test_equality(self):
+        usa = CentCount(1.0, 'USD')
+        can = CentCount(1.0, 'CAD')
+        self.assertNotEqual(usa, can)
+        eh = CentCount(1.0, 'CAD')
+        self.assertEqual(can, eh)
+
+    def test_comparison(self):
+        one = CentCount(1.0)
+        two = CentCount(2.0)
+        three = CentCount(3.0)
+        neg_one = CentCount(-1)
+        self.assertLess(one, two)
+        self.assertLess(neg_one, one)
+        self.assertGreater(one, neg_one)
+        self.assertGreater(three, one)
+        looney = CentCount(1.0, 'CAD')
+        with self.assertRaises(TypeError):
+            print(looney < one)
+
+    def test_strict_mode(self):
+        one = CentCount(1.0, strict_mode=True)
+        two = CentCount(2.0, strict_mode=True)
+        with self.assertRaises(TypeError):
+            x = one + 2.4
+        self.assertEqual(CentCount(3.0), one + two)
+        with self.assertRaises(TypeError):
+            x = two - 1.9
+        self.assertEqual(CentCount(1.0), two - one)
+        with self.assertRaises(TypeError):
+            print(one == 1.0)
+        self.assertTrue(CentCount(1.0) == one)
+        with self.assertRaises(TypeError):
+            print(one < 2.0)
+        self.assertTrue(one < two)
+        with self.assertRaises(TypeError):
+            print(two > 1.0)
+        self.assertTrue(two > one)
+
+    def test_truncate_and_round(self):
+        ten = CentCount(10.0)
+        x = ten * 2 / 3
+        x.truncate_fractional_cents()
+        self.assertEqual(CentCount(6.66), x)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/money_test.py b/tests/money_test.py
new file mode 100755 (executable)
index 0000000..57f4637
--- /dev/null
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from type.money import Money
+import unittest_utils as uu
+
+
+class TestMoney(unittest.TestCase):
+
+    def test_basic_utility(self):
+        amount = Money(1.45)
+        another = Money.parse("USD 1.45")
+        self.assertAlmostEqual(amount.amount, another.amount)
+
+    def test_negation(self):
+        amount = Money(1.45)
+        amount = -amount
+        self.assertAlmostEqual(Money(-1.45).amount, amount.amount)
+
+    def test_addition_and_subtraction(self):
+        amount = Money(1.00)
+        another = Money(2.00)
+        total = amount + another
+        self.assertEqual(Money(3.00), total)
+        delta = another - amount
+        self.assertEqual(Money(1.00), delta)
+        neg = amount - another
+        self.assertEqual(Money(-1.00), neg)
+        neg += another
+        self.assertEqual(Money(1.00), neg)
+        neg += 1.00
+        self.assertEqual(Money(2.00), neg)
+        neg -= 1
+        self.assertEqual(Money(1.00), neg)
+        x = 10 - amount
+        self.assertEqual(Money(9.0), x)
+
+    def test_multiplication(self):
+        amount = Money(3.00)
+        amount *= 3
+        self.assertEqual(Money(9.00), amount)
+        with self.assertRaises(TypeError):
+            another = Money(0.33)
+            amount *= another
+
+    def test_division(self):
+        amount = Money(10.00)
+        x = amount / 5.0
+        self.assertEqual(Money(2.00), x)
+        with self.assertRaises(TypeError):
+            another = Money(1.33)
+            amount /= another
+
+    def test_equality(self):
+        usa = Money(1.0, 'USD')
+        can = Money(1.0, 'CAD')
+        self.assertNotEqual(usa, can)
+        eh = Money(1.0, 'CAD')
+        self.assertEqual(can, eh)
+
+    def test_comparison(self):
+        one = Money(1.0)
+        two = Money(2.0)
+        three = Money(3.0)
+        neg_one = Money(-1)
+        self.assertLess(one, two)
+        self.assertLess(neg_one, one)
+        self.assertGreater(one, neg_one)
+        self.assertGreater(three, one)
+        looney = Money(1.0, 'CAD')
+        with self.assertRaises(TypeError):
+            print(looney < one)
+
+    def test_strict_mode(self):
+        one = Money(1.0, strict_mode=True)
+        two = Money(2.0, strict_mode=True)
+        with self.assertRaises(TypeError):
+            x = one + 2.4
+        self.assertEqual(Money(3.0), one + two)
+        with self.assertRaises(TypeError):
+            x = two - 1.9
+        self.assertEqual(Money(1.0), two - one)
+        with self.assertRaises(TypeError):
+            print(one == 1.0)
+        self.assertTrue(Money(1.0) == one)
+        with self.assertRaises(TypeError):
+            print(one < 2.0)
+        self.assertTrue(one < two)
+        with self.assertRaises(TypeError):
+            print(two > 1.0)
+        self.assertTrue(two > one)
+
+    def test_truncate_and_round(self):
+        ten = Money(10.0)
+        x = ten * 2 / 3
+        self.assertEqual(6.66, x.truncate_fractional_cents())
+        x = ten * 2 / 3
+        self.assertEqual(6.67, x.round_fractional_cents())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/rate_test.py b/tests/rate_test.py
new file mode 100755 (executable)
index 0000000..621539b
--- /dev/null
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+import unittest
+
+from type.rate import Rate
+from type.money import Money
+
+import unittest_utils as uu
+
+
+class TestRate(unittest.TestCase):
+    def test_basic_utility(self):
+        my_stock_returns = Rate(percent_change=-20.0)
+        my_portfolio = 1000.0
+        self.assertAlmostEqual(
+            800.0,
+            my_stock_returns.apply_to(my_portfolio)
+        )
+
+        my_bond_returns = Rate(percentage=104.5)
+        my_money = Money(500.0)
+        self.assertAlmostEqual(
+            Money(522.5),
+            my_bond_returns.apply_to(my_money)
+        )
+
+        my_multiplier = Rate(multiplier=1.72)
+        my_nose_length = 3.2
+        self.assertAlmostEqual(
+            5.504,
+            my_multiplier.apply_to(my_nose_length)
+        )
+
+    def test_conversions(self):
+        x = Rate(104.55)
+        s = x.__repr__()
+        y = Rate(s)
+        self.assertAlmostEqual(x, y)
+        f = float(x)
+        z = Rate(f)
+        self.assertAlmostEqual(x, z)
+
+    def test_divide(self):
+        x = Rate(20.0)
+        x /= 2
+        self.assertAlmostEqual(10.0, x)
+        x = Rate(-20.0)
+        x /= 2
+        self.assertAlmostEqual(-10.0, x)
+
+    def test_add(self):
+        x = Rate(5.0)
+        y = Rate(10.0)
+        z = x + y
+        self.assertAlmostEqual(15.0, z)
+        x = Rate(-5.0)
+        x += y
+        self.assertAlmostEqual(5.0, x)
+
+    def test_sub(self):
+        x = Rate(5.0)
+        y = Rate(10.0)
+        z = x - y
+        self.assertAlmostEqual(-5.0, z)
+        z = y - x
+        self.assertAlmostEqual(5.0, z)
+
+    def test_repr(self):
+        x = Rate(percent_change=-50.0)
+        s = x.__repr__(relative=True)
+        self.assertEqual("-50.000%", s)
+        s = x.__repr__()
+        self.assertEqual("+50.000%", s)
+
+
+if __name__ == '__main__':
+    unittest.main()
index 0472daaccaf9a525794df24e79c8ae5f923898f0..cc570364047382c3d0e2aee570674cc37e87c710 100755 (executable)
@@ -180,6 +180,12 @@ class TestStringUtils(unittest.TestCase):
         self.assertFalse(su.is_snake_case('thisIsATest'))
         self.assertTrue(su.is_snake_case('this_is_a_test'))
 
+    def test_sprintf_context(self):
+        with su.SprintfStdout() as buf:
+            print("This is a test.")
+            print("This is another one.")
+        self.assertEqual('This is a test.\nThis is another one.\n', buf())
+
 
 if __name__ == '__main__':
     bootstrap.initialize(unittest.main)()
diff --git a/type/centcount.py b/type/centcount.py
new file mode 100644 (file)
index 0000000..4181721
--- /dev/null
@@ -0,0 +1,226 @@
+#!/usr/bin/env python3
+
+import re
+from typing import Optional, TypeVar, Tuple
+
+import math_utils
+
+
+T = TypeVar('T', bound='CentCount')
+
+
+class CentCount(object):
+    """A class for representing monetary amounts potentially with
+    different currencies.
+    """
+
+    def __init__ (
+            self,
+            centcount,
+            currency: str = 'USD',
+            *,
+            strict_mode = False
+    ):
+        self.strict_mode = strict_mode
+        if isinstance(centcount, str):
+            ret = CentCount._parse(centcount)
+            if ret is None:
+                raise Exception(f'Unable to parse money string "{centcount}"')
+            centcount = ret[0]
+            currency = ret[1]
+        if isinstance(centcount, float):
+            centcount = int(centcount * 100.0)
+        if not isinstance(centcount, int):
+            centcount = int(centcount)
+        self.centcount = centcount
+        if not currency:
+            self.currency: Optional[str] = None
+        else:
+            self.currency: Optional[str] = currency
+
+    def __repr__(self):
+        a = float(self.centcount)
+        a /= 100
+        a = round(a, 2)
+        s = f'{a:,.2f}'
+        if self.currency is not None:
+            return '%s %s' % (s, self.currency)
+        else:
+            return '$%s' % s
+
+    def __pos__(self):
+        return CentCount(centcount=self.centcount, currency=self.currency)
+
+    def __neg__(self):
+        return CentCount(centcount=-self.centcount, currency=self.currency)
+
+    def __add__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return CentCount(
+                    centcount = self.centcount + other.centcount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in add expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return self.__add__(CentCount(other, self.currency))
+
+    def __sub__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return CentCount(
+                    centcount = self.centcount - other.centcount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in add expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return self.__sub__(CentCount(other, self.currency))
+
+    def __mul__(self, other):
+        if isinstance(other, CentCount):
+            raise TypeError('can not multiply monetary quantities')
+        else:
+            return CentCount(
+                centcount = int(self.centcount * float(other)),
+                currency = self.currency
+            )
+
+    def __truediv__(self, other):
+        if isinstance(other, CentCount):
+            raise TypeError('can not divide monetary quantities')
+        else:
+            return CentCount(
+                centcount = int(float(self.centcount) / float(other)),
+                currency = self.currency
+            )
+
+    def __int__(self):
+        return self.centcount.__int__()
+
+    def __float__(self):
+        return self.centcount.__float__() / 100.0
+
+    def truncate_fractional_cents(self):
+        x = int(self)
+        self.centcount = int(math_utils.truncate_float(x))
+        return self.centcount
+
+    def round_fractional_cents(self):
+        x = int(self)
+        self.centcount = int(round(x, 2))
+        return self.centcount
+
+    __radd__ = __add__
+
+    def __rsub__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return CentCount(
+                    centcount = other.centcount - self.centcount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in sub expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return CentCount(
+                    centcount = int(other) - self.centcount,
+                    currency = self.currency
+                )
+
+    __rmul__ = __mul__
+
+    #
+    # Override comparison operators to also compare currency.
+    #
+    def __eq__(self, other):
+        if other is None:
+            return False
+        if isinstance(other, CentCount):
+            return (
+                self.centcount == other.centcount and
+                self.currency == other.currency
+            )
+        if self.strict_mode:
+            raise TypeError("In strict mode only two CentCounts can be compared")
+        else:
+            return self.centcount == int(other)
+
+    def __ne__(self, other):
+        result = self.__eq__(other)
+        if result is NotImplemented:
+            return result
+        return not result
+
+    def __lt__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return self.centcount < other.centcount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two CentCounts can be compated')
+            else:
+                return self.centcount < int(other)
+
+    def __gt__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return self.centcount > other.centcount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two CentCounts can be compated')
+            else:
+                return self.centcount > int(other)
+
+    def __le__(self, other):
+        return self < other or self == other
+
+    def __ge__(self, other):
+        return self > other or self == other
+
+    def __hash__(self):
+        return self.__repr__
+
+    CENTCOUNT_RE = re.compile("^([+|-]?)(\d+)(\.\d+)$")
+    CURRENCY_RE = re.compile("^[A-Z][A-Z][A-Z]$")
+
+    @classmethod
+    def _parse(cls, s: str) -> Optional[Tuple[int, str]]:
+        centcount = None
+        currency = None
+        s = s.strip()
+        chunks = s.split(' ')
+        try:
+            for chunk in chunks:
+                if CentCount.CENTCOUNT_RE.match(chunk) is not None:
+                    centcount = int(float(chunk) * 100.0)
+                elif CentCount.CURRENCY_RE.match(chunk) is not None:
+                    currency = chunk
+        except:
+            pass
+        if centcount is not None and currency is not None:
+            return (centcount, currency)
+        elif centcount is not None:
+            return (centcount, 'USD')
+        return None
+
+    @classmethod
+    def parse(cls, s: str) -> T:
+        chunks = CentCount._parse(s)
+        if chunks is not None:
+            return CentCount(chunks[0], chunks[1])
+        raise Exception(f'Unable to parse money string "{s}"')
diff --git a/type/money.py b/type/money.py
new file mode 100644 (file)
index 0000000..c77a938
--- /dev/null
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+
+from decimal import Decimal
+import re
+from typing import Optional, TypeVar, Tuple
+
+import math_utils
+
+
+T = TypeVar('T', bound='Money')
+
+
+class Money(object):
+    """A class for representing monetary amounts potentially with
+    different currencies.
+    """
+
+    def __init__ (
+            self,
+            amount: Decimal = Decimal("0.0"),
+            currency: str = 'USD',
+            *,
+            strict_mode = False
+    ):
+        self.strict_mode = strict_mode
+        if isinstance(amount, str):
+            ret = Money._parse(amount)
+            if ret is None:
+                raise Exception(f'Unable to parse money string "{amount}"')
+            amount = ret[0]
+            currency = ret[1]
+        if not isinstance(amount, Decimal):
+            amount = Decimal(float(amount))
+        self.amount = amount
+        if not currency:
+            self.currency: Optional[str] = None
+        else:
+            self.currency: Optional[str] = currency
+
+    def __repr__(self):
+        a = float(self.amount)
+        a = round(a, 2)
+        s = f'{a:,.2f}'
+        if self.currency is not None:
+            return '%s %s' % (s, self.currency)
+        else:
+            return '$%s' % s
+
+    def __pos__(self):
+        return Money(amount=self.amount, currency=self.currency)
+
+    def __neg__(self):
+        return Money(amount=-self.amount, currency=self.currency)
+
+    def __add__(self, other):
+        if isinstance(other, Money):
+            if self.currency == other.currency:
+                return Money(
+                    amount = self.amount + other.amount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in add expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return Money(
+                    amount = self.amount + Decimal(float(other)),
+                    currency = self.currency
+                )
+
+    def __sub__(self, other):
+        if isinstance(other, Money):
+            if self.currency == other.currency:
+                return Money(
+                    amount = self.amount - other.amount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in add expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return Money(
+                    amount = self.amount - Decimal(float(other)),
+                    currency = self.currency
+                )
+
+    def __mul__(self, other):
+        if isinstance(other, Money):
+            raise TypeError('can not multiply monetary quantities')
+        else:
+            return Money(
+                amount = self.amount * Decimal(float(other)),
+                currency = self.currency
+            )
+
+    def __truediv__(self, other):
+        if isinstance(other, Money):
+            raise TypeError('can not divide monetary quantities')
+        else:
+            return Money(
+                amount = self.amount / Decimal(float(other)),
+                currency = self.currency
+            )
+
+    def __float__(self):
+        return self.amount.__float__()
+
+    def truncate_fractional_cents(self):
+        x = float(self)
+        self.amount = Decimal(math_utils.truncate_float(x))
+        return self.amount
+
+    def round_fractional_cents(self):
+        x = float(self)
+        self.amount = Decimal(round(x, 2))
+        return self.amount
+
+    __radd__ = __add__
+
+    def __rsub__(self, other):
+        if isinstance(other, Money):
+            if self.currency == other.currency:
+                return Money(
+                    amount = other.amount - self.amount,
+                    currency = self.currency
+                )
+            else:
+                raise TypeError('Incompatible currencies in sub expression')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict_mode only two moneys can be added')
+            else:
+                return Money(
+                    amount = Decimal(float(other)) - self.amount,
+                    currency = self.currency
+                )
+
+    __rmul__ = __mul__
+
+    #
+    # Override comparison operators to also compare currency.
+    #
+    def __eq__(self, other):
+        if other is None:
+            return False
+        if isinstance(other, Money):
+            return (
+                self.amount == other.amount and
+                self.currency == other.currency
+            )
+        if self.strict_mode:
+            raise TypeError("In strict mode only two Moneys can be compared")
+        else:
+            return self.amount == Decimal(float(other))
+
+    def __ne__(self, other):
+        result = self.__eq__(other)
+        if result is NotImplemented:
+            return result
+        return not result
+
+    def __lt__(self, other):
+        if isinstance(other, Money):
+            if self.currency == other.currency:
+                return self.amount < other.amount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two Moneys can be compated')
+            else:
+                return self.amount < Decimal(float(other))
+
+    def __gt__(self, other):
+        if isinstance(other, Money):
+            if self.currency == other.currency:
+                return self.amount > other.amount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two Moneys can be compated')
+            else:
+                return self.amount > Decimal(float(other))
+
+    def __le__(self, other):
+        return self < other or self == other
+
+    def __ge__(self, other):
+        return self > other or self == other
+
+    def __hash__(self):
+        return self.__repr__
+
+    AMOUNT_RE = re.compile("^([+|-]?)(\d+)(\.\d+)$")
+    CURRENCY_RE = re.compile("^[A-Z][A-Z][A-Z]$")
+
+    @classmethod
+    def _parse(cls, s: str) -> Optional[Tuple[Decimal, str]]:
+        amount = None
+        currency = None
+        s = s.strip()
+        chunks = s.split(' ')
+        try:
+            for chunk in chunks:
+                if Money.AMOUNT_RE.match(chunk) is not None:
+                    amount = Decimal(chunk)
+                elif Money.CURRENCY_RE.match(chunk) is not None:
+                    currency = chunk
+        except:
+            pass
+        if amount is not None and currency is not None:
+            return (amount, currency)
+        elif amount is not None:
+            return (amount, 'USD')
+        return None
+
+    @classmethod
+    def parse(cls, s: str) -> T:
+        chunks = Money._parse(s)
+        if chunks is not None:
+            return Money(chunks[0], chunks[1])
+        raise Exception(f'Unable to parse money string "{s}"')
diff --git a/type/rate.py b/type/rate.py
new file mode 100644 (file)
index 0000000..3161131
--- /dev/null
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+
+from typing import Optional
+
+
+class Rate(object):
+    def __init__(
+            self,
+            multiplier: Optional[float] = None,
+            *,
+            percentage: Optional[float] = None,
+            percent_change: Optional[float] = None,
+    ):
+        count = 0
+        if multiplier is not None:
+            if isinstance(multiplier, str):
+                multiplier = multiplier.replace('%', '')
+                m = float(multiplier)
+                m /= 100
+                self.multiplier = m
+            else:
+                self.multiplier = multiplier
+            count += 1
+        if percentage is not None:
+            self.multiplier = percentage / 100
+            count += 1
+        if percent_change is not None:
+            self.multiplier = 1.0 + percent_change / 100
+            count += 1
+        if count != 1:
+            raise Exception(
+                'Exactly one of percentage, percent_change or multiplier is required.'
+            )
+
+    def apply_to(self, other):
+        return self.__mul__(other)
+
+    def of(self, other):
+        return self.__mul__(other)
+
+    def __float__(self):
+        return self.multiplier
+
+    def __mul__(self, other):
+        return self.multiplier * float(other)
+
+    __rmul__ = __mul__
+
+    def __truediv__(self, other):
+        return self.multiplier / float(other)
+
+    def __add__(self, other):
+        return self.multiplier + float(other)
+
+    __radd__ = __add__
+
+    def __sub__(self, other):
+        return self.multiplier - float(other)
+
+    def __eq__(self, other):
+        return self.multiplier == float(other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __lt__(self, other):
+        return self.multiplier < float(other)
+
+    def __gt__(self, other):
+        return self.multiplier > float(other)
+
+    def __le__(self, other):
+        return self < other or self == other
+
+    def __ge__(self, other):
+        return self > other or self == other
+
+    def __hash__(self):
+        return self.multiplier
+
+    def __repr__(self,
+                 *,
+                 relative=False,
+                 places=3):
+        if relative:
+            percentage = (self.multiplier - 1.0) * 100.0
+        else:
+            percentage = self.multiplier * 100.0
+        return f'{percentage:+.{places}f}%'
index 5987da6ba38194bf639b7382347ed599ff5bc0e9..2dc8cfe231e5aad03d877fac8116472e0c4b6c3d 100644 (file)
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
 
 """Helpers for unittests.  Note that when you import this we
-automatically wrap unittest.main() with a call to bootstrap.initialize
-so that we getLogger config, commandline args, logging control,
-etc... this works fine but it's a little hacky so caveat emptor.
+   automatically wrap unittest.main() with a call to
+   bootstrap.initialize so that we getLogger config, commandline args,
+   logging control, etc... this works fine but it's a little hacky so
+   caveat emptor.
 """
 
 import contextlib
@@ -170,7 +171,7 @@ class RecordStdout(object):
     def __exit__(self, *args) -> bool:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return True
+        return None
 
 
 class RecordStderr(object):
@@ -192,7 +193,7 @@ class RecordStderr(object):
     def __exit__(self, *args) -> bool:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return True
+        return None
 
 
 class RecordMultipleStreams(object):