#!/usr/bin/env python3

"""Account classes: a common Account interface, an age-restricted
tax-deferred account (with an optional in-plan Roth partition), a pure Roth
account and a taxable brokerage account."""

from abc import abstractmethod
import logging
from typing import Optional

from type.money import Money
from person import Person
import data
from taxman import TaxCollector

logger = logging.getLogger(__name__)


class Account(object):
    """Common interface shared by every account type in this module.

    NOTE: this class does not use ABCMeta, so the @abstractmethod markers
    below are not enforced at instantiation time; a subclass that forgets
    to override one of them silently inherits a method that returns None.
    """

    def __init__(self, name: str, owner: Person):
        self.name = name
        self.owner = owner

    def get_name(self) -> str:
        """Returns the account's name."""
        return self.name

    def get_owner(self) -> Person:
        """Returns the Person who owns the account."""
        return self.owner

    def belongs_to_scott(self) -> bool:
        """Is the owner Person.SCOTT?"""
        return self.get_owner() == Person.SCOTT

    def belongs_to_lynn(self) -> bool:
        """Is the owner Person.LYNN?"""
        return self.get_owner() == Person.LYNN

    @abstractmethod
    def get_balance(self) -> Money:
        """Returns the account's current balance"""
        pass

    @abstractmethod
    def appreciate(self, rate: float) -> Money:
        """Grow/shrink the balance using rate (<1 = shrink, 1.0=identity, >1 =
        grow).  Return the new balance.  Raise on error.
        """
        pass

    @abstractmethod
    def withdraw(
        self, amount: Money, taxes: Optional[TaxCollector] = None
    ) -> Money:
        """Withdraw money from the account and return the Money.  Raise
        on error.  If the TaxCollector is passed, the money will be
        recorded as income if applicable per the account type.
        """
        pass

    @abstractmethod
    def is_age_restricted(self) -> bool:
        """Is this account age restricted.  Subclasses should implement."""
        pass

    @abstractmethod
    def has_rmd(self) -> bool:
        """Does this account have a required minimum distribution?
        Subclasses should implement."""
        pass

    @abstractmethod
    def do_rmd_withdrawal(
        self, owner_age: int, taxes: Optional[TaxCollector]
    ) -> Money:
        """Compute the magnitude of any RMD, withdraw it from the account and
        return it.  If TaxCollector is provided, record the distribution as
        income if applicable.  Raise on error."""
        pass

    @abstractmethod
    def has_roth(self) -> bool:
        """Does this account have a Roth balance?"""
        pass

    @abstractmethod
    def dump_final_report(self) -> None:
        """Produce a simulation final report."""
        pass


class AgeRestrictedTaxDeferredAccount(Account):
    """Account object used to represent an age-restricted tax deferred
    account such as a 401(k), IRA or annuity."""

    # Minimum owner age at which do_rmd_withdrawal takes a distribution.
    RMD_START_AGE = 72

    def __init__(
        self,
        name: str,
        owner: Person,
        *,
        total_balance: Money = Money(0),
        roth_subbalance: Money = Money(0),
    ):
        """C'tor for this class takes a roth_subbalance, the number of
        dollars in the account that are in-plan Roth and can be taken out
        tax free.  It keeps this pile separate from the pretax pile
        (total_balance - roth_subbalance) and tries to estimate taxes."""
        super().__init__(name, owner)
        self.roth: Money = roth_subbalance
        self.pretax: Money = total_balance - roth_subbalance
        # Running totals start as Money so reports are formatted uniformly
        # even if a total is never touched.
        self.total_roth_withdrawals: Money = Money(0)
        self.total_pretax_withdrawals: Money = Money(0)
        self.total_investment_gains: Money = Money(0)
        self.total_roth_conversions: Money = Money(0)

    def is_age_restricted(self) -> bool:
        return True

    def withdraw(
        self, amount: Money, taxes: Optional[TaxCollector] = None
    ) -> Money:
        """Assume that money withdrawn from this account will be a mixture
        of pretax funds (which count as ordinary income) and Roth funds
        (which are available tax-free).  Assume that the ratio of pretax
        to Roth in this overall account determines the amount from each
        partition in this withdrawal.

        Returns amount.  Raises Exception("Insufficient funds") if the
        balance is smaller than amount.  Income is reported to taxes only
        when a TaxCollector is supplied.
        """
        balance = self.get_balance()
        if balance < amount:
            raise Exception("Insufficient funds")

        # An empty account can only reach this point for a zero-sized
        # withdrawal; avoid dividing zero by zero in that case.
        total = float(balance)
        ratio = float(self.roth) / total if total else 0.0
        roth_part = amount * ratio
        pretax_part = amount - roth_part

        if roth_part > 0:
            self.roth -= roth_part
            self.total_roth_withdrawals += roth_part
            logger.debug(
                f'Account {self.name} satisfying {amount} with {roth_part} Roth.'
            )
            if taxes is not None:
                taxes.record_roth_income(roth_part)

        self.pretax -= pretax_part
        self.total_pretax_withdrawals += pretax_part
        logger.debug(
            f'Account {self.name} satisfying {amount} with {pretax_part} pretax.'
        )
        if taxes is not None:
            taxes.record_ordinary_income(pretax_part)
        return amount

    def appreciate(self, rate: float) -> Money:
        """Scale both the pretax and Roth partitions by rate, add the
        resulting change to total_investment_gains and return the new
        overall balance."""
        old_pretax = self.pretax
        self.pretax *= rate
        self.total_investment_gains += self.pretax - old_pretax

        old_roth = self.roth
        self.roth *= rate
        self.total_investment_gains += self.roth - old_roth
        return self.get_balance()

    def get_balance(self) -> Money:
        """The balance is the sum of the pretax and Roth partitions, which
        are tracked separately."""
        return self.pretax + self.roth

    def has_rmd(self) -> bool:
        return True

    def do_rmd_withdrawal(
        self, owner_age: int, taxes: Optional[TaxCollector]
    ) -> Money:
        """If the account is non-empty and the owner is at least
        RMD_START_AGE, withdraw balance / data.get_actuary_number_years_to_live
        (owner_age) and return that amount; otherwise return Money(0)."""
        balance = self.get_balance()
        if balance > 0 and owner_age >= self.RMD_START_AGE:
            rmd_factor = data.get_actuary_number_years_to_live(owner_age)
            amount = balance / rmd_factor
            self.withdraw(amount, taxes)
            return amount
        return Money(0)

    def has_roth(self) -> bool:
        return True

    def do_roth_conversion(self, amount: Money) -> Money:
        """Move up to amount from the pretax partition into the Roth
        partition.  Returns the amount actually converted, which is capped
        at the pretax balance and is Money(0) when amount is not positive
        or there is no pretax money."""
        if amount <= 0:
            return Money(0)
        if not self.pretax > 0:
            return Money(0)

        if self.pretax >= amount:
            converted = amount
            self.pretax -= converted
        else:
            # Not enough pretax money: convert everything that is left.
            converted = self.pretax
            self.pretax = Money(0)
        self.roth += converted
        self.total_roth_conversions += converted
        logger.debug(
            f'Account {self.name} executed pre-tax --> Roth conversion of '
            + f'{converted}'
        )
        return converted

    def dump_final_report(self) -> None:
        """Print the ending balance and the running totals to stdout."""
        print(f'Account: {self.name}:')
        rows = (
            ("Ending balance", self.get_balance()),
            ("Total investment gains", self.total_investment_gains),
            ("Total Roth withdrawals", self.total_roth_withdrawals),
            ("Total pre-tax withdrawals", self.total_pretax_withdrawals),
            ("Total pre-tax converted to Roth", self.total_roth_conversions),
        )
        for label, value in rows:
            print(" %-50s: %18s" % (label, value))


class AgeRestrictedRothAccount(AgeRestrictedTaxDeferredAccount):
    """This is an object to represent a Roth account like a Roth IRA.
    All money in here is tax free.  Most of the parent class works here,
    including its implementation of withdraw(): because the whole balance
    is held in the Roth partition, none of a withdrawal is pretax."""

    def __init__(
        self, name: str, owner: Person, *, total_balance: Money = Money(0)
    ):
        """The entire total_balance is treated as Roth money."""
        super().__init__(
            name, owner, total_balance=total_balance, roth_subbalance=total_balance
        )

    def has_rmd(self) -> bool:
        return False

    def do_rmd_withdrawal(
        self, owner_age: int, taxes: Optional[TaxCollector]
    ) -> Money:
        """Always raises: this account type has no RMDs."""
        raise Exception("This account has no RMDs")

    def do_roth_conversion(self, amount: Money) -> Money:
        """Nothing to convert here; always returns Money(0)."""
        return Money(0)


class BrokerageAccount(Account):
    """A class to represent money in a taxable brokerage account."""

    def __init__(
        self,
        name: str,
        owner: Person,
        *,
        total_balance: Money = Money(0),
        cost_basis: Money = Money(0),
    ):
        """The c'tor of this class partitions balance into three pieces:
        the cost_basis (i.e. how much money was invested in the account),
        the short_term_gain (i.e. appreciation that has been around for
        less than a year) and long_term_gain (i.e. appreciation that has
        been around for more than a year).  We separate these because
        taxes on long_term_gain (and qualified dividends, which are not
        modeled) are usually lower than short_term_gain.  Today those are
        taxed at 15% and as ordinary income, respectively.

        Everything in total_balance above cost_basis starts out as
        long_term_gain; short_term_gain starts at zero."""
        super().__init__(name, owner)
        self.cost_basis = cost_basis
        self.short_term_gain = Money(0)
        self.long_term_gain = total_balance - cost_basis
        self.total_cost_basis_withdrawals = Money(0)
        self.total_long_term_gain_withdrawals = Money(0)
        self.total_short_term_gain_withdrawals = Money(0)
        self.total_investment_gains = Money(0)

    def withdraw(
        self, amount: Money, taxes: Optional[TaxCollector] = None
    ) -> Money:
        """Withdraw amount from the three piles of money and return it.

        When you sell securities to get money out of this account the
        gains are taxed (and the cost_basis part isn't).  When the account
        holds both cost basis and net gains, assume that the ratio of
        cost_basis to overall balance determines how much of the
        withdrawal will be taxed (and how).  Otherwise drain the piles one
        after another.

        Raises Exception("Insufficient funds") if the balance is smaller
        than amount.
        """
        balance = self.get_balance()
        if balance < amount:
            raise Exception("Insufficient funds")
        has_gains = (self.short_term_gain + self.long_term_gain) > 0
        if self.cost_basis > 0 and has_gains:
            return self._withdraw_with_ratio(amount, taxes)
        return self._withdraw_waterfall(amount, taxes)

    def _withdraw_short_term_gain(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Take amount out of short_term_gain, reporting it to taxes (if
        given) as a short term gain.  Raises if the pile is too small."""
        if self.short_term_gain >= amount:
            self.short_term_gain -= amount
            self.total_short_term_gain_withdrawals += amount
            if taxes is not None:
                taxes.record_short_term_gain(amount)
            return amount
        raise Exception('Insufficient funds')

    def _withdraw_long_term_gain(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Take amount out of long_term_gain, reporting it to taxes (if
        given) as a dividend or long term gain.  Raises if the pile is too
        small."""
        if self.long_term_gain >= amount:
            self.long_term_gain -= amount
            self.total_long_term_gain_withdrawals += amount
            if taxes is not None:
                taxes.record_dividend_or_long_term_gain(amount)
            return amount
        raise Exception('Insufficient funds')

    def _withdraw_cost_basis(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Take amount out of cost_basis.  Nothing is reported to taxes.
        Raises if the pile is too small."""
        if self.cost_basis >= amount:
            self.cost_basis -= amount
            self.total_cost_basis_withdrawals += amount
            return amount
        raise Exception('Insufficient funds')

    def _withdraw_with_ratio(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Split amount between cost basis and gains in proportion to the
        share of the balance that is cost basis, then withdraw both parts.
        The caller guarantees amount <= balance and a non-zero balance."""
        ratio = float(self.cost_basis) / float(self.get_balance())
        invested_capital_part = amount * ratio
        invested_capital_part.truncate_fractional_cents()

        # Truncation and float rounding can push either part slightly past
        # the pile it is drawn from.  Clamp the gains part to the range that
        # keeps both withdrawals feasible; because amount <= balance, the
        # lower bound never exceeds the upper bound.  The two parts always
        # sum to exactly amount.
        gains_part = amount - invested_capital_part
        least_gains_needed = amount - self.cost_basis
        available_gains = self.short_term_gain + self.long_term_gain
        if gains_part < least_gains_needed:
            gains_part = least_gains_needed
        if gains_part > available_gains:
            gains_part = available_gains
        invested_capital_part = amount - gains_part

        self._withdraw_cost_basis(invested_capital_part, taxes)
        logger.debug(
            f'Account {self.name}: satisfying {invested_capital_part} from cost basis funds.'
        )
        logger.debug(
            f'Account {self.name}: satisfying {gains_part} from investment gains...'
        )
        self._withdraw_from_gains(gains_part, taxes)
        return amount

    def _withdraw_waterfall(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Satisfy amount by draining short term gains, then long term
        gains, then cost basis, skipping any pile that is not positive.
        Each piece goes through the per-pile helper so that totals and
        taxes are recorded.  Returns amount; raises if the piles cannot
        cover it."""
        to_find = amount
        nothing = Money(0)

        if self.short_term_gain > 0 and to_find != nothing:
            if to_find < self.short_term_gain:
                part = to_find
            else:
                part = self.short_term_gain
            # Compute the remainder before the helper shrinks the pile.
            to_find = to_find - part
            self._withdraw_short_term_gain(part, taxes)

        if self.long_term_gain > 0 and to_find != nothing:
            if to_find < self.long_term_gain:
                part = to_find
            else:
                part = self.long_term_gain
            to_find = to_find - part
            self._withdraw_long_term_gain(part, taxes)

        if self.cost_basis > 0 and to_find != nothing:
            if to_find < self.cost_basis:
                part = to_find
            else:
                part = self.cost_basis
            to_find = to_find - part
            self._withdraw_cost_basis(part, taxes)

        if to_find != nothing:
            raise Exception('Insufficient funds')
        return amount

    def _withdraw_from_gains(
        self, amount: Money, taxes: Optional[TaxCollector]
    ) -> Money:
        """Withdraw some money from gains.  Prefer the long term ones if
        possible; whatever they cannot cover comes from short term gains.
        Raises if the two gain piles together are smaller than amount."""
        to_find = amount
        if to_find > (self.long_term_gain + self.short_term_gain):
            raise Exception("Insufficient funds")

        if self.long_term_gain >= to_find:
            self._withdraw_long_term_gain(to_find, taxes)
            logger.debug(
                f'Account {self.name}: satisfying {to_find} from long term gains.'
            )
            return to_find

        logger.debug(
            f'Account {self.name}: satisfying {self.long_term_gain} from long term gains '
            + '(exhausting long term gains)'
        )
        # Work out the short term share before the long term pile is
        # emptied; afterwards self.long_term_gain is zero.
        long_part = self.long_term_gain
        to_find = to_find - long_part
        self._withdraw_long_term_gain(long_part, taxes)
        self._withdraw_short_term_gain(to_find, taxes)
        logger.debug(
            f'Account {self.name}: satisfying {to_find} from short term gains'
        )
        return amount

    def get_balance(self) -> Money:
        """The balance is tracked as three separate partitions of money:
        cost basis, long term gain and short term gain."""
        return self.cost_basis + self.long_term_gain + self.short_term_gain

    def appreciate(self, rate: float) -> Money:
        """Appreciate... another year has passed so short_term_gains turn
        into long_term_gains and the appreciation is our new
        short_term_gains.  Returns the new balance."""
        balance = self.get_balance()
        gain = balance * (rate - 1.0)  # Note: rate is something like 1.04
        self.total_investment_gains += gain
        self.long_term_gain += self.short_term_gain
        self.short_term_gain = gain
        return self.get_balance()

    def is_age_restricted(self) -> bool:
        return False

    def has_rmd(self) -> bool:
        return False

    def has_roth(self) -> bool:
        return False

    def do_roth_conversion(self, amount: Money) -> Money:
        """Nothing to convert here; always returns Money(0)."""
        return Money(0)

    def dump_final_report(self) -> None:
        """Print the ending balance and the running totals to stdout."""
        print(f'Account {self.name}:')
        rows = (
            ("Ending balance", self.get_balance()),
            ("Total investment gains", self.total_investment_gains),
            ("Total cost basis withdrawals", self.total_cost_basis_withdrawals),
            ("Total long term gain withdrawals",
             self.total_long_term_gain_withdrawals),
            ("Total short term gain withdrawals",
             self.total_short_term_gain_withdrawals),
        )
        for label, value in rows:
            print(" %-50s: %18s" % (label, value))