From 4faa994d32223c8d560d9dad0ca90a3f7eb10d6a Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Wed, 28 Jul 2021 22:13:43 -0700 Subject: [PATCH] Money, Rate, CentCount and a bunch of bugfixes. --- ansi.py | 36 +++++- bootstrap.py | 70 +++++++++--- file_utils.py | 25 ++++ google_assistant.py | 13 ++- histogram.py | 15 +-- list_utils.py | 6 + math_utils.py | 16 +++ stopwatch.py | 2 +- string_utils.py | 20 +++- tests/ansi_test.py | 19 ++++ tests/centcount_test.py | 102 +++++++++++++++++ tests/money_test.py | 103 +++++++++++++++++ tests/rate_test.py | 77 +++++++++++++ tests/string_utils_test.py | 6 + type/centcount.py | 226 ++++++++++++++++++++++++++++++++++++ type/money.py | 227 +++++++++++++++++++++++++++++++++++++ type/rate.py | 89 +++++++++++++++ unittest_utils.py | 11 +- 18 files changed, 1027 insertions(+), 36 deletions(-) create mode 100755 tests/ansi_test.py create mode 100755 tests/centcount_test.py create mode 100755 tests/money_test.py create mode 100755 tests/rate_test.py create mode 100644 type/centcount.py create mode 100644 type/money.py create mode 100644 type/rate.py diff --git a/ansi.py b/ansi.py index 4c580c0..769b29c 100755 --- a/ansi.py +++ b/ansi.py @@ -1,9 +1,12 @@ #!/usr/bin/env python3 +from abc import abstractmethod import difflib +import io import logging +import re import sys -from typing import Dict, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Tuple logger = logging.getLogger(__name__) @@ -1846,6 +1849,37 @@ def bg(name: Optional[str] = "", return bg_24bit(red, green, blue) +class StdoutInterceptor(io.TextIOBase): + def __init__(self): + self.saved_stdout: Optional[io.TextIOBase] = None + self.buf = '' + + @abstractmethod + def write(self, s): + pass + + def __enter__(self) -> None: + self.saved_stdout = sys.stdout + sys.stdout = self + return None + + def __exit__(self, *args) -> bool: + sys.stdout = self.saved_stdout + print(self.buf) + return None + + +class ProgrammableColorizer(StdoutInterceptor): + def __init__(self, patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]]): + super().__init__() + self.patterns = [_ for _ in patterns] + + def write(self, s: str): + for pattern in self.patterns: + s = pattern[0].sub(pattern[1], s) + self.buf += s + + if __name__ == '__main__': def main() -> None: name = " ".join(sys.argv[1:]) diff --git a/bootstrap.py b/bootstrap.py index 3b03b3a..3489b8a 100644 --- a/bootstrap.py +++ b/bootstrap.py @@ -3,9 +3,7 @@ import functools import logging import os -import pdb import sys -import traceback # This module is commonly used by others in here and should avoid # taking any unnecessary dependencies back on them. @@ -25,21 +23,49 @@ args.add_argument( default=False, help='Break into pdb on top level unhandled exceptions.' ) +args.add_argument( + '--show_random_seed', + action=ActionNoYes, + default=False, + help='Should we display (and log.debug) the global random seed?' +) +args.add_argument( + '--set_random_seed', + type=int, + nargs=1, + default=None, + metavar='SEED_INT', + help='Override the global random seed with a particular number.' +) + +original_hook = sys.excepthook -def handle_uncaught_exception( - exc_type, - exc_value, - exc_traceback): +def handle_uncaught_exception(exc_type, exc_value, exc_tb): + global original_hook + msg = f'Unhandled top level exception {exc_type}' + logger.exception(msg) + print(msg, file=sys.stderr) if issubclass(exc_type, KeyboardInterrupt): - sys.__excepthook__(exc_type, exc_value, exc_traceback) + sys.__excepthook__(exc_type, exc_value, exc_tb) return - logger.exception(f'Unhandled top level {exc_type}', - exc_info=(exc_type, exc_value, exc_traceback)) - traceback.print_exception(exc_type, exc_value, exc_traceback) - if config.config['debug_unhandled_exceptions']: - logger.info("Invoking the debugger...") - pdb.pm() + else: + if ( + not sys.stderr.isatty() or + not sys.stdin.isatty() + ): + # stdin or stderr is redirected, just do the normal thing + original_hook(exc_type, exc_value, exc_tb) + else: + # a terminal is attached and stderr is not redirected, debug. + if config.config['debug_unhandled_exceptions']: + import traceback + import pdb + traceback.print_exception(exc_type, exc_value, exc_tb) + logger.info("Invoking the debugger...") + pdb.pm() + else: + original_hook(exc_type, exc_value, exc_tb) def initialize(entry_point): @@ -47,7 +73,8 @@ def initialize(entry_point): """Remember to initialize config and logging before running main.""" @functools.wraps(entry_point) def initialize_wrapper(*args, **kwargs): - sys.excepthook = handle_uncaught_exception + if sys.excepthook == sys.__excepthook__: + sys.excepthook = handle_uncaught_exception if ( '__globals__' in entry_point.__dict__ and '__file__' in entry_point.__globals__ @@ -60,6 +87,21 @@ def initialize(entry_point): config.late_logging() + # Allow programs that don't bother to override the random seed + # to be replayed via the commandline. + import random + random_seed = config.config['set_random_seed'] + if random_seed is not None: + random_seed = random_seed[0] + else: + random_seed = int.from_bytes(os.urandom(4), 'little') + + if config.config['show_random_seed']: + msg = f'Global random seed is: {random_seed}' + print(msg) + logger.debug(msg) + random.seed(random_seed) + logger.debug(f'Starting {entry_point.__name__} (program entry point)') ret = None diff --git a/file_utils.py b/file_utils.py index 464b0e7..525a1af 100644 --- a/file_utils.py +++ b/file_utils.py @@ -7,11 +7,14 @@ import errno import hashlib import logging import os +import io import pathlib import time from typing import Optional import glob from os.path import isfile, join, exists +from uuid import uuid4 + logger = logging.getLogger(__name__) @@ -249,3 +252,25 @@ def get_files_recursive(directory: str): for subdir in get_directories(directory): for file_or_directory in get_files_recursive(subdir): yield file_or_directory + + +class FileWriter(object): + def __init__(self, filename: str) -> None: + self.filename = filename + uuid = uuid4() + self.tempfile = f'{filename}-{uuid}.tmp' + self.handle = None + + def __enter__(self) -> io.TextIOWrapper: + assert not does_path_exist(self.tempfile) + self.handle = open(self.tempfile, mode="w") + return self.handle + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + if self.handle is not None: + self.handle.close() + cmd = f'/bin/mv -f {self.tempfile} {self.filename}' + ret = os.system(cmd) + if (ret >> 8) != 0: + raise Exception(f'{cmd} failed, exit value {ret>>8}') + return None diff --git a/google_assistant.py b/google_assistant.py index 71301e4..a50003c 100644 --- a/google_assistant.py +++ b/google_assistant.py @@ -79,10 +79,15 @@ def ask_google(cmd: str, *, recognize_speech=True) -> GoogleResponse: sample_rate=24000, sample_width=2, ) - audio_transcription = recognizer.recognize_google( - speech, - ) - logger.debug(f"Transcription: '{audio_transcription}'") + try: + audio_transcription = recognizer.recognize_google( + speech, + ) + logger.debug(f"Transcription: '{audio_transcription}'") + except sr.UnknownValueError as e: + logger.exception(e) + logger.warning('Unable to parse Google assistant\'s response.') + audio_transcription = None else: logger.error( f'HTTP request to {url} with {payload} failed; code {r.status_code}' diff --git a/histogram.py b/histogram.py index b98e848..0368376 100644 --- a/histogram.py +++ b/histogram.py @@ -66,7 +66,8 @@ class SimpleHistogram(Generic[T]): all_true = all_true and self.add_item(item) return all_true - def __repr__(self) -> str: + def __repr__(self, + label_formatter='%10s') -> str: from text_utils import bar_graph max_population: Optional[int] = None for bucket in self.buckets: @@ -86,17 +87,11 @@ class SimpleHistogram(Generic[T]): bar = bar_graph( (pop / max_population), include_text = False, - width = 70, + width = 58, left_end = "", right_end = "") - label = f'{start}..{end}' - txt += f'{label:12}: ' + bar + f"({pop}) ({len(bar)})\n" + label = f'{label_formatter}..{label_formatter}' % (start, end) + txt += f'{label:20}: ' + bar + f"({pop/self.count*100.0:5.2f}% n={pop})\n" if start == last_bucket_start: break - - txt = txt + f'''{self.count} item(s) -{self.maximum} max -{self.minimum} min -{self.sigma/self.count:.3f} mean -{self.median.get_median()} median''' return txt diff --git a/list_utils.py b/list_utils.py index 74f1cf3..7d3355c 100644 --- a/list_utils.py +++ b/list_utils.py @@ -21,3 +21,9 @@ def flatten(lst: List[Any]) -> List[Any]: if isinstance(lst[0], list): return flatten(lst[0]) + flatten(lst[1:]) return lst[:1] + flatten(lst[1:]) + + +def prepend(item: Any, lst: List[Any]) -> List[Any]: + """Prepend an item to a list.""" + lst = list.insert(0, item) + return lst diff --git a/math_utils.py b/math_utils.py index 56fb707..6277123 100644 --- a/math_utils.py +++ b/math_utils.py @@ -61,6 +61,22 @@ def truncate_float(n: float, decimals: int = 2): return int(n * multiplier) / multiplier +def percentage_to_multiplier(percent: float) -> float: + multiplier = percent / 100 + multiplier += 1.0 + return multiplier + + +def multiplier_to_percent(multiplier: float) -> float: + percent = multiplier + if percent > 0.0: + percent -= 1.0 + else: + percent = 1.0 - percent + percent *= 100.0 + return percent + + @functools.lru_cache(maxsize=1024, typed=True) def is_prime(n: int) -> bool: """Returns True if n is prime and False otherwise""" diff --git a/stopwatch.py b/stopwatch.py index d54af87..1326cb1 100644 --- a/stopwatch.py +++ b/stopwatch.py @@ -24,4 +24,4 @@ class Timer(object): def __exit__(self, *args) -> bool: self.end = time.perf_counter() - return True + return None # don't suppress exceptions diff --git a/string_utils.py b/string_utils.py index 911008d..6fc257d 100644 --- a/string_utils.py +++ b/string_utils.py @@ -1,13 +1,15 @@ #!/usr/bin/env python3 +import contextlib import datetime +import io from itertools import zip_longest import json import logging import random import re import string -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional import unicodedata from uuid import uuid4 @@ -921,6 +923,22 @@ def sprintf(*args, **kwargs) -> str: return ret +class SprintfStdout(object): + def __init__(self) -> None: + self.destination = io.StringIO() + self.recorder = None + + def __enter__(self) -> Callable[[], str]: + self.recorder = contextlib.redirect_stdout(self.destination) + self.recorder.__enter__() + return lambda: self.destination.getvalue() + + def __exit__(self, *args) -> None: + self.recorder.__exit__(*args) + self.destination.seek(0) + return None # don't suppress exceptions + + def is_are(n: int) -> str: if n == 1: return "is" diff --git a/tests/ansi_test.py b/tests/ansi_test.py new file mode 100755 index 0000000..4c1f449 --- /dev/null +++ b/tests/ansi_test.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +import unittest + +import ansi +import unittest_utils as uu + + +class TestAnsi(unittest.TestCase): + + def test_colorizer(self): + with ansi.Colorizer() as c: + print("testing...") + print("Section:") + print(" This is some detail.") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/centcount_test.py b/tests/centcount_test.py new file mode 100755 index 0000000..3122b98 --- /dev/null +++ b/tests/centcount_test.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 + +import unittest + +from type.centcount import CentCount +import unittest_utils as uu + + +class TestCentCount(unittest.TestCase): + + def test_basic_utility(self): + amount = CentCount(1.45) + another = CentCount.parse("USD 1.45") + self.assertEqual(amount, another) + + def test_negation(self): + amount = CentCount(1.45) + amount = -amount + self.assertEqual(CentCount(-1.45), amount) + + def test_addition_and_subtraction(self): + amount = CentCount(1.00) + another = CentCount(2.00) + total = amount + another + self.assertEqual(CentCount(3.00), total) + delta = another - amount + self.assertEqual(CentCount(1.00), delta) + neg = amount - another + self.assertEqual(CentCount(-1.00), neg) + neg += another + self.assertEqual(CentCount(1.00), neg) + neg += 1.00 + self.assertEqual(CentCount(2.00), neg) + neg -= 1.00 + self.assertEqual(CentCount(1.00), neg) + x = 1000 - amount + self.assertEqual(CentCount(9.0), x) + + def test_multiplication(self): + amount = CentCount(3.00) + amount *= 3 + self.assertEqual(CentCount(9.00), amount) + with self.assertRaises(TypeError): + another = CentCount(0.33) + amount *= another + + def test_division(self): + amount = CentCount(10.00) + x = amount / 5.0 + self.assertEqual(CentCount(2.00), x) + with self.assertRaises(TypeError): + another = CentCount(1.33) + amount /= another + + def test_equality(self): + usa = CentCount(1.0, 'USD') + can = CentCount(1.0, 'CAD') + self.assertNotEqual(usa, can) + eh = CentCount(1.0, 'CAD') + self.assertEqual(can, eh) + + def test_comparison(self): + one = CentCount(1.0) + two = CentCount(2.0) + three = CentCount(3.0) + neg_one = CentCount(-1) + self.assertLess(one, two) + self.assertLess(neg_one, one) + self.assertGreater(one, neg_one) + self.assertGreater(three, one) + looney = CentCount(1.0, 'CAD') + with self.assertRaises(TypeError): + print(looney < one) + + def test_strict_mode(self): + one = CentCount(1.0, strict_mode=True) + two = CentCount(2.0, strict_mode=True) + with self.assertRaises(TypeError): + x = one + 2.4 + self.assertEqual(CentCount(3.0), one + two) + with self.assertRaises(TypeError): + x = two - 1.9 + self.assertEqual(CentCount(1.0), two - one) + with self.assertRaises(TypeError): + print(one == 1.0) + self.assertTrue(CentCount(1.0) == one) + with self.assertRaises(TypeError): + print(one < 2.0) + self.assertTrue(one < two) + with self.assertRaises(TypeError): + print(two > 1.0) + self.assertTrue(two > one) + + def test_truncate_and_round(self): + ten = CentCount(10.0) + x = ten * 2 / 3 + x.truncate_fractional_cents() + self.assertEqual(CentCount(6.66), x) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/money_test.py b/tests/money_test.py new file mode 100755 index 0000000..57f4637 --- /dev/null +++ b/tests/money_test.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +import unittest + +from type.money import Money +import unittest_utils as uu + + +class TestMoney(unittest.TestCase): + + def test_basic_utility(self): + amount = Money(1.45) + another = Money.parse("USD 1.45") + self.assertAlmostEqual(amount.amount, another.amount) + + def test_negation(self): + amount = Money(1.45) + amount = -amount + self.assertAlmostEqual(Money(-1.45).amount, amount.amount) + + def test_addition_and_subtraction(self): + amount = Money(1.00) + another = Money(2.00) + total = amount + another + self.assertEqual(Money(3.00), total) + delta = another - amount + self.assertEqual(Money(1.00), delta) + neg = amount - another + self.assertEqual(Money(-1.00), neg) + neg += another + self.assertEqual(Money(1.00), neg) + neg += 1.00 + self.assertEqual(Money(2.00), neg) + neg -= 1 + self.assertEqual(Money(1.00), neg) + x = 10 - amount + self.assertEqual(Money(9.0), x) + + def test_multiplication(self): + amount = Money(3.00) + amount *= 3 + self.assertEqual(Money(9.00), amount) + with self.assertRaises(TypeError): + another = Money(0.33) + amount *= another + + def test_division(self): + amount = Money(10.00) + x = amount / 5.0 + self.assertEqual(Money(2.00), x) + with self.assertRaises(TypeError): + another = Money(1.33) + amount /= another + + def test_equality(self): + usa = Money(1.0, 'USD') + can = Money(1.0, 'CAD') + self.assertNotEqual(usa, can) + eh = Money(1.0, 'CAD') + self.assertEqual(can, eh) + + def test_comparison(self): + one = Money(1.0) + two = Money(2.0) + three = Money(3.0) + neg_one = Money(-1) + self.assertLess(one, two) + self.assertLess(neg_one, one) + self.assertGreater(one, neg_one) + self.assertGreater(three, one) + looney = Money(1.0, 'CAD') + with self.assertRaises(TypeError): + print(looney < one) + + def test_strict_mode(self): + one = Money(1.0, strict_mode=True) + two = Money(2.0, strict_mode=True) + with self.assertRaises(TypeError): + x = one + 2.4 + self.assertEqual(Money(3.0), one + two) + with self.assertRaises(TypeError): + x = two - 1.9 + self.assertEqual(Money(1.0), two - one) + with self.assertRaises(TypeError): + print(one == 1.0) + self.assertTrue(Money(1.0) == one) + with self.assertRaises(TypeError): + print(one < 2.0) + self.assertTrue(one < two) + with self.assertRaises(TypeError): + print(two > 1.0) + self.assertTrue(two > one) + + def test_truncate_and_round(self): + ten = Money(10.0) + x = ten * 2 / 3 + self.assertEqual(6.66, x.truncate_fractional_cents()) + x = ten * 2 / 3 + self.assertEqual(6.67, x.round_fractional_cents()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/rate_test.py b/tests/rate_test.py new file mode 100755 index 0000000..621539b --- /dev/null +++ b/tests/rate_test.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +import unittest + +from type.rate import Rate +from type.money import Money + +import unittest_utils as uu + + +class TestRate(unittest.TestCase): + def test_basic_utility(self): + my_stock_returns = Rate(percent_change=-20.0) + my_portfolio = 1000.0 + self.assertAlmostEqual( + 800.0, + my_stock_returns.apply_to(my_portfolio) + ) + + my_bond_returns = Rate(percentage=104.5) + my_money = Money(500.0) + self.assertAlmostEqual( + Money(522.5), + my_bond_returns.apply_to(my_money) + ) + + my_multiplier = Rate(multiplier=1.72) + my_nose_length = 3.2 + self.assertAlmostEqual( + 5.504, + my_multiplier.apply_to(my_nose_length) + ) + + def test_conversions(self): + x = Rate(104.55) + s = x.__repr__() + y = Rate(s) + self.assertAlmostEqual(x, y) + f = float(x) + z = Rate(f) + self.assertAlmostEqual(x, z) + + def test_divide(self): + x = Rate(20.0) + x /= 2 + self.assertAlmostEqual(10.0, x) + x = Rate(-20.0) + x /= 2 + self.assertAlmostEqual(-10.0, x) + + def test_add(self): + x = Rate(5.0) + y = Rate(10.0) + z = x + y + self.assertAlmostEqual(15.0, z) + x = Rate(-5.0) + x += y + self.assertAlmostEqual(5.0, x) + + def test_sub(self): + x = Rate(5.0) + y = Rate(10.0) + z = x - y + self.assertAlmostEqual(-5.0, z) + z = y - x + self.assertAlmostEqual(5.0, z) + + def test_repr(self): + x = Rate(percent_change=-50.0) + s = x.__repr__(relative=True) + self.assertEqual("-50.000%", s) + s = x.__repr__() + self.assertEqual("+50.000%", s) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/string_utils_test.py b/tests/string_utils_test.py index 0472daa..cc57036 100755 --- a/tests/string_utils_test.py +++ b/tests/string_utils_test.py @@ -180,6 +180,12 @@ class TestStringUtils(unittest.TestCase): self.assertFalse(su.is_snake_case('thisIsATest')) self.assertTrue(su.is_snake_case('this_is_a_test')) + def test_sprintf_context(self): + with su.SprintfStdout() as buf: + print("This is a test.") + print("This is another one.") + self.assertEqual('This is a test.\nThis is another one.\n', buf()) + if __name__ == '__main__': bootstrap.initialize(unittest.main)() diff --git a/type/centcount.py b/type/centcount.py new file mode 100644 index 0000000..4181721 --- /dev/null +++ b/type/centcount.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 + +import re +from typing import Optional, TypeVar, Tuple + +import math_utils + + +T = TypeVar('T', bound='CentCount') + + +class CentCount(object): + """A class for representing monetary amounts potentially with + different currencies. + """ + + def __init__ ( + self, + centcount, + currency: str = 'USD', + *, + strict_mode = False + ): + self.strict_mode = strict_mode + if isinstance(centcount, str): + ret = CentCount._parse(centcount) + if ret is None: + raise Exception(f'Unable to parse money string "{centcount}"') + centcount = ret[0] + currency = ret[1] + if isinstance(centcount, float): + centcount = int(centcount * 100.0) + if not isinstance(centcount, int): + centcount = int(centcount) + self.centcount = centcount + if not currency: + self.currency: Optional[str] = None + else: + self.currency: Optional[str] = currency + + def __repr__(self): + a = float(self.centcount) + a /= 100 + a = round(a, 2) + s = f'{a:,.2f}' + if self.currency is not None: + return '%s %s' % (s, self.currency) + else: + return '$%s' % s + + def __pos__(self): + return CentCount(centcount=self.centcount, currency=self.currency) + + def __neg__(self): + return CentCount(centcount=-self.centcount, currency=self.currency) + + def __add__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount = self.centcount + other.centcount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return self.__add__(CentCount(other, self.currency)) + + def __sub__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount = self.centcount - other.centcount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return self.__sub__(CentCount(other, self.currency)) + + def __mul__(self, other): + if isinstance(other, CentCount): + raise TypeError('can not multiply monetary quantities') + else: + return CentCount( + centcount = int(self.centcount * float(other)), + currency = self.currency + ) + + def __truediv__(self, other): + if isinstance(other, CentCount): + raise TypeError('can not divide monetary quantities') + else: + return CentCount( + centcount = int(float(self.centcount) / float(other)), + currency = self.currency + ) + + def __int__(self): + return self.centcount.__int__() + + def __float__(self): + return self.centcount.__float__() / 100.0 + + def truncate_fractional_cents(self): + x = int(self) + self.centcount = int(math_utils.truncate_float(x)) + return self.centcount + + def round_fractional_cents(self): + x = int(self) + self.centcount = int(round(x, 2)) + return self.centcount + + __radd__ = __add__ + + def __rsub__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount = other.centcount - self.centcount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in sub expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return CentCount( + centcount = int(other) - self.centcount, + currency = self.currency + ) + + __rmul__ = __mul__ + + # + # Override comparison operators to also compare currency. + # + def __eq__(self, other): + if other is None: + return False + if isinstance(other, CentCount): + return ( + self.centcount == other.centcount and + self.currency == other.currency + ) + if self.strict_mode: + raise TypeError("In strict mode only two CentCounts can be compared") + else: + return self.centcount == int(other) + + def __ne__(self, other): + result = self.__eq__(other) + if result is NotImplemented: + return result + return not result + + def __lt__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return self.centcount < other.centcount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two CentCounts can be compated') + else: + return self.centcount < int(other) + + def __gt__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return self.centcount > other.centcount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two CentCounts can be compated') + else: + return self.centcount > int(other) + + def __le__(self, other): + return self < other or self == other + + def __ge__(self, other): + return self > other or self == other + + def __hash__(self): + return self.__repr__ + + CENTCOUNT_RE = re.compile("^([+|-]?)(\d+)(\.\d+)$") + CURRENCY_RE = re.compile("^[A-Z][A-Z][A-Z]$") + + @classmethod + def _parse(cls, s: str) -> Optional[Tuple[int, str]]: + centcount = None + currency = None + s = s.strip() + chunks = s.split(' ') + try: + for chunk in chunks: + if CentCount.CENTCOUNT_RE.match(chunk) is not None: + centcount = int(float(chunk) * 100.0) + elif CentCount.CURRENCY_RE.match(chunk) is not None: + currency = chunk + except: + pass + if centcount is not None and currency is not None: + return (centcount, currency) + elif centcount is not None: + return (centcount, 'USD') + return None + + @classmethod + def parse(cls, s: str) -> T: + chunks = CentCount._parse(s) + if chunks is not None: + return CentCount(chunks[0], chunks[1]) + raise Exception(f'Unable to parse money string "{s}"') diff --git a/type/money.py b/type/money.py new file mode 100644 index 0000000..c77a938 --- /dev/null +++ b/type/money.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +from decimal import Decimal +import re +from typing import Optional, TypeVar, Tuple + +import math_utils + + +T = TypeVar('T', bound='Money') + + +class Money(object): + """A class for representing monetary amounts potentially with + different currencies. + """ + + def __init__ ( + self, + amount: Decimal = Decimal("0.0"), + currency: str = 'USD', + *, + strict_mode = False + ): + self.strict_mode = strict_mode + if isinstance(amount, str): + ret = Money._parse(amount) + if ret is None: + raise Exception(f'Unable to parse money string "{amount}"') + amount = ret[0] + currency = ret[1] + if not isinstance(amount, Decimal): + amount = Decimal(float(amount)) + self.amount = amount + if not currency: + self.currency: Optional[str] = None + else: + self.currency: Optional[str] = currency + + def __repr__(self): + a = float(self.amount) + a = round(a, 2) + s = f'{a:,.2f}' + if self.currency is not None: + return '%s %s' % (s, self.currency) + else: + return '$%s' % s + + def __pos__(self): + return Money(amount=self.amount, currency=self.currency) + + def __neg__(self): + return Money(amount=-self.amount, currency=self.currency) + + def __add__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money( + amount = self.amount + other.amount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount = self.amount + Decimal(float(other)), + currency = self.currency + ) + + def __sub__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money( + amount = self.amount - other.amount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount = self.amount - Decimal(float(other)), + currency = self.currency + ) + + def __mul__(self, other): + if isinstance(other, Money): + raise TypeError('can not multiply monetary quantities') + else: + return Money( + amount = self.amount * Decimal(float(other)), + currency = self.currency + ) + + def __truediv__(self, other): + if isinstance(other, Money): + raise TypeError('can not divide monetary quantities') + else: + return Money( + amount = self.amount / Decimal(float(other)), + currency = self.currency + ) + + def __float__(self): + return self.amount.__float__() + + def truncate_fractional_cents(self): + x = float(self) + self.amount = Decimal(math_utils.truncate_float(x)) + return self.amount + + def round_fractional_cents(self): + x = float(self) + self.amount = Decimal(round(x, 2)) + return self.amount + + __radd__ = __add__ + + def __rsub__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money( + amount = other.amount - self.amount, + currency = self.currency + ) + else: + raise TypeError('Incompatible currencies in sub expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount = Decimal(float(other)) - self.amount, + currency = self.currency + ) + + __rmul__ = __mul__ + + # + # Override comparison operators to also compare currency. + # + def __eq__(self, other): + if other is None: + return False + if isinstance(other, Money): + return ( + self.amount == other.amount and + self.currency == other.currency + ) + if self.strict_mode: + raise TypeError("In strict mode only two Moneys can be compared") + else: + return self.amount == Decimal(float(other)) + + def __ne__(self, other): + result = self.__eq__(other) + if result is NotImplemented: + return result + return not result + + def __lt__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return self.amount < other.amount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two Moneys can be compated') + else: + return self.amount < Decimal(float(other)) + + def __gt__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return self.amount > other.amount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two Moneys can be compated') + else: + return self.amount > Decimal(float(other)) + + def __le__(self, other): + return self < other or self == other + + def __ge__(self, other): + return self > other or self == other + + def __hash__(self): + return self.__repr__ + + AMOUNT_RE = re.compile("^([+|-]?)(\d+)(\.\d+)$") + CURRENCY_RE = re.compile("^[A-Z][A-Z][A-Z]$") + + @classmethod + def _parse(cls, s: str) -> Optional[Tuple[Decimal, str]]: + amount = None + currency = None + s = s.strip() + chunks = s.split(' ') + try: + for chunk in chunks: + if Money.AMOUNT_RE.match(chunk) is not None: + amount = Decimal(chunk) + elif Money.CURRENCY_RE.match(chunk) is not None: + currency = chunk + except: + pass + if amount is not None and currency is not None: + return (amount, currency) + elif amount is not None: + return (amount, 'USD') + return None + + @classmethod + def parse(cls, s: str) -> T: + chunks = Money._parse(s) + if chunks is not None: + return Money(chunks[0], chunks[1]) + raise Exception(f'Unable to parse money string "{s}"') diff --git a/type/rate.py b/type/rate.py new file mode 100644 index 0000000..3161131 --- /dev/null +++ b/type/rate.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +from typing import Optional + + +class Rate(object): + def __init__( + self, + multiplier: Optional[float] = None, + *, + percentage: Optional[float] = None, + percent_change: Optional[float] = None, + ): + count = 0 + if multiplier is not None: + if isinstance(multiplier, str): + multiplier = multiplier.replace('%', '') + m = float(multiplier) + m /= 100 + self.multiplier = m + else: + self.multiplier = multiplier + count += 1 + if percentage is not None: + self.multiplier = percentage / 100 + count += 1 + if percent_change is not None: + self.multiplier = 1.0 + percent_change / 100 + count += 1 + if count != 1: + raise Exception( + 'Exactly one of percentage, percent_change or multiplier is required.' + ) + + def apply_to(self, other): + return self.__mul__(other) + + def of(self, other): + return self.__mul__(other) + + def __float__(self): + return self.multiplier + + def __mul__(self, other): + return self.multiplier * float(other) + + __rmul__ = __mul__ + + def __truediv__(self, other): + return self.multiplier / float(other) + + def __add__(self, other): + return self.multiplier + float(other) + + __radd__ = __add__ + + def __sub__(self, other): + return self.multiplier - float(other) + + def __eq__(self, other): + return self.multiplier == float(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + return self.multiplier < float(other) + + def __gt__(self, other): + return self.multiplier > float(other) + + def __le__(self, other): + return self < other or self == other + + def __ge__(self, other): + return self > other or self == other + + def __hash__(self): + return self.multiplier + + def __repr__(self, + *, + relative=False, + places=3): + if relative: + percentage = (self.multiplier - 1.0) * 100.0 + else: + percentage = self.multiplier * 100.0 + return f'{percentage:+.{places}f}%' diff --git a/unittest_utils.py b/unittest_utils.py index 5987da6..2dc8cfe 100644 --- a/unittest_utils.py +++ b/unittest_utils.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 """Helpers for unittests. Note that when you import this we -automatically wrap unittest.main() with a call to bootstrap.initialize -so that we getLogger config, commandline args, logging control, -etc... this works fine but it's a little hacky so caveat emptor. + automatically wrap unittest.main() with a call to + bootstrap.initialize so that we getLogger config, commandline args, + logging control, etc... this works fine but it's a little hacky so + caveat emptor. """ import contextlib @@ -170,7 +171,7 @@ class RecordStdout(object): def __exit__(self, *args) -> bool: self.recorder.__exit__(*args) self.destination.seek(0) - return True + return None class RecordStderr(object): @@ -192,7 +193,7 @@ class RecordStderr(object): def __exit__(self, *args) -> bool: self.recorder.__exit__(*args) self.destination.seek(0) - return True + return None class RecordMultipleStreams(object): -- 2.47.1