Lots of changes.
authorScott Gasch <[email protected]>
Fri, 9 Jul 2021 02:44:27 +0000 (19:44 -0700)
committerScott Gasch <[email protected]>
Fri, 9 Jul 2021 02:44:27 +0000 (19:44 -0700)
27 files changed:
ansi.py
argparse_utils.py
bootstrap.py
config.py
dateparse/dateparse_utils.g4
dateparse/dateparse_utils.py
datetime_utils.py
decorator_utils.py
dict_utils.py
executors.py
file_utils.py
light_utils.py
list_utils.py
logging_utils.py
misc_utils.py
ml_model_trainer.py
presence.py
remote_worker.py
simple_acl.py [new file with mode: 0644]
smart_future.py
state_tracker.py
string_utils.py
tests/dateparse_utils_test.py [new file with mode: 0755]
tests/dict_utils_test.py [new file with mode: 0755]
tests/simple_acl_test.py [new file with mode: 0755]
tests/string_utils_test.py
unittest_utils.py [new file with mode: 0644]

diff --git a/ansi.py b/ansi.py
index dc9a31542f4551a73d9ab1de341b7fce11e85f47..089863931e6f8fc7b7f000af56c37274e33916af 100755 (executable)
--- a/ansi.py
+++ b/ansi.py
@@ -1842,18 +1842,16 @@ def bg(name: Optional[str] = "",
     return bg_24bit(red, green, blue)
 
 
-def main() -> None:
-    name = " ".join(sys.argv[1:])
-    for possibility in COLOR_NAMES_TO_RGB:
-        if name in possibility:
-            f = fg(possibility)
-            b = bg(possibility)
-            _ = pick_contrasting_color(possibility)
-            xf = fg(None, _[0], _[1], _[2])
-            xb = bg(None, _[0], _[1], _[2])
-            print(f'{f}{xb}{possibility}{reset()}\t\t\t'
-                  f'{b}{xf}{possibility}{reset()}')
-
-
 if __name__ == '__main__':
+    def main() -> None:
+        name = " ".join(sys.argv[1:])
+        for possibility in COLOR_NAMES_TO_RGB:
+            if name in possibility:
+                f = fg(possibility)
+                b = bg(possibility)
+                _ = pick_contrasting_color(possibility)
+                xf = fg(None, _[0], _[1], _[2])
+                xb = bg(None, _[0], _[1], _[2])
+                print(f'{f}{xb}{possibility}{reset()}\t\t\t'
+                      f'{b}{xf}{possibility}{reset()}')
     main()
index 75bec0475ce25f4246d7ce2e2f4d05aee306c4c5..80046dee7b2c4619406bfec618d131df24d1e5c3 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 
 import argparse
+import datetime
 import logging
 import os
 
@@ -79,6 +80,7 @@ def valid_mac(mac: str) -> str:
 
 
 def valid_percentage(num: str) -> float:
+    num = num.strip('%')
     n = float(num)
     if 0.0 <= n <= 100.0:
         return n
@@ -94,3 +96,21 @@ def valid_filename(filename: str) -> str:
     msg = f"{filename} was not found and is therefore invalid."
     logger.warning(msg)
     raise argparse.ArgumentTypeError(msg)
+
+
+def valid_date(txt: str) -> datetime.date:
+    date = string_utils.to_date(txt)
+    if date is not None:
+        return date
+    msg = f'Cannot parse argument as a date: {txt}'
+    logger.warning(msg)
+    raise argparse.ArgumentTypeError(msg)
+
+
+def valid_datetime(txt: str) -> datetime.datetime:
+    dt = string_utils.to_datetime(txt)
+    if dt is not None:
+        return dt
+    msg = f'Cannot parse argument as datetime: {txt}'
+    logger.warning(msg)
+    raise argparse.ArgumentTypeError(msg)
index d1233e9344f54be3f84555a9b392faf1d5fe2054..0d37dcbd7215067ad17801e48e8eafab64c63c61 100644 (file)
@@ -47,7 +47,8 @@ def initialize(funct):
         sys.excepthook = handle_uncaught_exception
         config.parse()
         logging_utils.initialize_logging(logging.getLogger())
-        logger.debug(f"About to invoke {funct}...")
+        config.late_logging()
+        logger.debug(f'Starting {funct.__name__}')
         start = time.perf_counter()
         ret = funct(*args, **kwargs)
         end = time.perf_counter()
@@ -59,6 +60,9 @@ def initialize(funct):
                      f'child system: {cstime}s\n'
                      f'elapsed: {elapsed_time}s\n'
                      f'walltime: {end - start}s\n')
-        logger.info(f'Exit {ret}')
+        if ret != 0:
+            logger.info(f'Exit {ret}')
+        else:
+            logger.debug(f'Exit {ret}')
         sys.exit(ret)
     return initialize_wrapper
index bc0dcdf14f23b5edd8582a4fd8afcd9425a5bd36..9b4a53d120f8c4c5ee73f393f8080bb1f456e137 100644 (file)
--- a/config.py
+++ b/config.py
@@ -55,17 +55,23 @@ different modules).  Usage:
       --dry_run
                    Should we really do the thing?
 
-    Arguments themselves should be accessed via config.config['arg_name'].  e.g.
+    Arguments themselves should be accessed via
+    config.config['arg_name'].  e.g.
 
     if not config.config['dry_run']:
         module.do_the_thing()
+
 """
 
 import argparse
+import logging
+import os
 import pprint
 import re
 import sys
-from typing import Dict, Any
+from typing import Any, Dict, List
+
+import string_utils
 
 # Note: at this point in time, logging hasn't been configured and
 # anything we log will come out the root logger.
@@ -96,7 +102,7 @@ class LoadFromFile(argparse.Action):
 args = argparse.ArgumentParser(
     description=f"This program uses config.py ({__file__}) for global, cross-module configuration.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    fromfile_prefix_chars="@"
+    fromfile_prefix_chars="@",
 )
 config_parse_called = False
 
@@ -104,6 +110,9 @@ config_parse_called = False
 # It is also this variable that modules use to access parsed arguments
 config: Dict[str, Any] = {}
 
+# Defer logging messages until later when logging has been initialized.
+saved_messages: List[str] = []
+
 
 def add_commandline_args(title: str, description: str = ""):
     """Create a new context for arguments and return a handle."""
@@ -137,11 +146,68 @@ group.add_argument(
 )
 
 
+def is_flag_already_in_argv(var: str):
+    """Is a particular flag passed on the commandline?"""
+    for _ in sys.argv:
+        if var in _:
+            return True
+    return False
+
+
 def parse() -> Dict[str, Any]:
     """Main program should call this early in main()"""
     global config_parse_called
+    if config_parse_called:
+        return
     config_parse_called = True
-    config.update(vars(args.parse_args()))
+    global saved_messages
+
+    # Examine the environment variables to settings that match
+    # known flags.
+    usage_message = args.format_usage()
+    optional = False
+    var = ''
+    for x in usage_message.split():
+        if x[0] == '[':
+            optional = True
+        if optional:
+            var += f'{x} '
+            if x[-1] == ']':
+                optional = False
+                var = var.strip()
+                var = var.strip('[')
+                var = var.strip(']')
+                chunks = var.split()
+                if len(chunks) > 1:
+                    var = var.split()[0]
+
+                # Environment vars the same as flag names without
+                # the initial -'s and in UPPERCASE.
+                env = var.strip('-').upper()
+                if env in os.environ:
+                    if not is_flag_already_in_argv(var):
+                        value = os.environ[env]
+                        saved_messages.append(
+                            f'Initialized from environment: {var} = {value}'
+                        )
+                        if len(chunks) == 1 and string_utils.to_bool(value):
+                            sys.argv.append(var)
+                        elif len(chunks) > 1:
+                            sys.argv.append(var)
+                            sys.argv.append(value)
+                var = ''
+                env = ''
+        else:
+            next
+
+    # Parse (possibly augmented) commandline args with argparse normally.
+    #config.update(vars(args.parse_args()))
+    known, unknown = args.parse_known_args()
+    config.update(vars(known))
+
+    # Reconstruct the argv with unrecognized flags for the benefit of
+    # future argument parsers.
+    sys.argv = sys.argv[:1] + unknown
 
     if config['config_savefile']:
         with open(config['config_savefile'], 'w') as wf:
@@ -149,11 +215,11 @@ def parse() -> Dict[str, Any]:
 
     if config['config_dump']:
         dump_config()
-
     return config
 
 
 def has_been_parsed() -> bool:
+    """Has the global config been parsed yet?"""
     global config_parse_called
     return config_parse_called
 
@@ -162,3 +228,11 @@ def dump_config():
     """Print the current config to stdout."""
     print("Global Configuration:", file=sys.stderr)
     pprint.pprint(config, stream=sys.stderr)
+
+
+def late_logging():
+    """Log messages saved earlier now that logging has been initialized."""
+    logger = logging.getLogger(__name__)
+    global saved_messages
+    for _ in saved_messages:
+        logger.debug(_)
index 86a9cd40c4dfb7e2e13aae870551033cc248673c..d45b18768adf6efec877803b24cb674a51d959c2 100644 (file)
@@ -1,5 +1,4 @@
 // antlr4 -Dlanguage=Python3 ./dateparse_utils.g4
-
 // Hi, self.  In ANTLR grammars, there are two separate types of symbols: those
 // for the lexer and those for the parser.  The former begin with a CAPITAL
 // whereas the latter begin with lowercase.  The order of the lexer symbols
 
 grammar dateparse_utils;
 
-parse: dateExpr ;
+parse
+    : SPACE* dateExpr
+    | SPACE* timeExpr
+    | SPACE* dateExpr SPACE* dtdiv? SPACE* timeExpr
+    | SPACE* timeExpr SPACE* tddiv? SPACE+ dateExpr
+    ;
 
 dateExpr
     : singleDateExpr
     | baseAndOffsetDateExpr
     ;
 
+timeExpr
+    : singleTimeExpr
+    | baseAndOffsetTimeExpr
+    ;
+
+singleTimeExpr
+    : twentyFourHourTimeExpr
+    | twelveHourTimeExpr
+    | specialTimeExpr
+    ;
+
+twentyFourHourTimeExpr
+    : hour ((SPACE|tdiv)+ minute ((SPACE|tdiv)+ second ((SPACE|tdiv)+ micros)? )? )? SPACE* tzExpr?
+    ;
+
+twelveHourTimeExpr
+    : hour ((SPACE|tdiv)+ minute ((SPACE|tdiv)+ second ((SPACE|tdiv)+ micros)? )? )? SPACE* ampm SPACE* tzExpr?
+    ;
+
+ampm: ('a'|'am'|'p'|'pm'|'AM'|'PM'|'A'|'P');
+
 singleDateExpr
     : monthDayMaybeYearExpr
     | dayMonthMaybeYearExpr
@@ -28,45 +53,91 @@ singleDateExpr
     | specialDateMaybeYearExpr
     | nthWeekdayInMonthMaybeYearExpr
     | firstLastWeekdayInMonthMaybeYearExpr
+    | deltaDateExprRelativeToTodayImplied
+    | dayName
     ;
 
 monthDayMaybeYearExpr
-    : monthExpr DIV* dayOfMonth (DIV* year)?
+    : monthExpr (SPACE|ddiv)+ dayOfMonth ((SPACE|ddiv)+ year)?
     ;
 
 dayMonthMaybeYearExpr
-    : dayOfMonth DIV* monthName (DIV* year)?
+    : dayOfMonth (SPACE|ddiv)+ monthName ((SPACE|ddiv)+ year)?
     ;
 
 yearMonthDayExpr
-    : year DIV* monthName DIV* dayOfMonth
+    : year (SPACE|ddiv)+ monthExpr (SPACE|ddiv)+ dayOfMonth
     ;
 
 nthWeekdayInMonthMaybeYearExpr
-    : nth dayName ('in'|'of') monthName (DIV* year)?
+    : nth SPACE+ dayName SPACE+ ('in'|'of') SPACE+ monthName ((ddiv|SPACE)+ year)?
     ;
 
 firstLastWeekdayInMonthMaybeYearExpr
-    : firstOrLast dayName ('in'|'of'|DIV)? monthName (DIV* year)?
+    : firstOrLast SPACE+ dayName (SPACE+ ('in'|'of'))? SPACE+ monthName ((ddiv|SPACE)+ year)?
     ;
 
 specialDateMaybeYearExpr
-    : specialDate (DIV* year)?
+    : (thisNextLast SPACE+)? specialDate ((SPACE|ddiv)+ year)?
     ;
 
+thisNextLast: (THIS|NEXT|LAST) ;
+
 baseAndOffsetDateExpr
-    : baseDate deltaPlusMinusExpr
-    | deltaPlusMinusExpr baseDate
+    : baseDate SPACE+ deltaPlusMinusExpr
+    | deltaPlusMinusExpr SPACE+ baseDate
+    ;
+
+deltaDateExprRelativeToTodayImplied
+    : nFoosFromTodayAgoExpr
+    | deltaRelativeToTodayExpr
+    ;
+
+deltaRelativeToTodayExpr
+    : thisNextLast SPACE+ deltaUnit
+    ;
+
+nFoosFromTodayAgoExpr
+    : unsignedInt SPACE+ deltaUnit SPACE+ AGO_FROM_NOW
     ;
 
 baseDate: singleDateExpr ;
 
-deltaPlusMinusExpr: deltaInt deltaUnit deltaBeforeAfter? ;
+baseAndOffsetTimeExpr
+    : deltaPlusMinusTimeExpr SPACE+ baseTime
+    | baseTime SPACE+ deltaPlusMinusTimeExpr
+    ;
+
+baseTime: singleTimeExpr ;
+
+deltaPlusMinusExpr
+    : nth SPACE+ deltaUnit (SPACE+ deltaBeforeAfter)?
+    ;
+
+deltaNextLast
+    : (NEXT|LAST) ;
+
+deltaPlusMinusTimeExpr
+    : countUnitsBeforeAfterTimeExpr
+    | fractionBeforeAfterTimeExpr
+    ;
+
+countUnitsBeforeAfterTimeExpr
+    : nth (SPACE+ deltaTimeUnit)? SPACE+ deltaTimeBeforeAfter
+    ;
+
+fractionBeforeAfterTimeExpr
+    : deltaTimeFraction SPACE+ deltaTimeBeforeAfter
+    ;
 
-deltaUnit: (WEEK|DAY|SUN|WEEKDAY) ;
+deltaUnit: (YEAR|MONTH|WEEK|DAY|WEEKDAY|WORKDAY) ;
+
+deltaTimeUnit: (SECOND|MINUTE|HOUR|WORKDAY) ;
 
 deltaBeforeAfter: (BEFORE|AFTER) ;
 
+deltaTimeBeforeAfter: (BEFORE|AFTER) ;
+
 monthExpr
     : monthName
     | monthNumber
@@ -76,137 +147,318 @@ year: DIGIT DIGIT DIGIT DIGIT ;
 
 specialDate: SPECIAL_DATE ;
 
-dayOfMonth: DIGIT? DIGIT ('st'|'nd'|'rd'|'th')? ;
+dayOfMonth
+    : DIGIT DIGIT? ('st'|'ST'|'nd'|'ND'|'rd'|'RD'|'th'|'TH')?
+    | KALENDS (SPACE+ 'of')?
+    | IDES (SPACE+ 'of')?
+    | NONES (SPACE+ 'of')?
+    ;
 
 firstOrLast: (FIRST|LAST) ;
 
-nth: DIGIT ('st'|'nd'|'rd'|'th')? ;
+nth: (DASH|PLUS)? DIGIT+ ('st'|'ST'|'nd'|'ND'|'rd'|'RD'|'th'|'TH')? ;
+
+unsignedInt: DIGIT+ ;
+
+deltaTimeFraction: DELTA_TIME_FRACTION ;
 
-deltaInt: ('+'|'-')? DIGIT+ ;
+specialTimeExpr: specialTime (SPACE+ tzExpr)? ;
+
+specialTime: SPECIAL_TIME ;
 
 dayName: WEEKDAY ;
 
-monthName: MONTH ;
+monthName: MONTH_NAME ;
+
+monthNumber: DIGIT DIGIT? ;
+
+hour: DIGIT DIGIT? ;
+
+minute: DIGIT DIGIT ;
 
-monthNumber: DIGIT? DIGIT ;
+second: DIGIT DIGIT ;
+
+micros: DIGIT DIGIT? DIGIT? DIGIT? DIGIT? DIGIT? DIGIT? ;
+
+ddiv: (SLASH|DASH|',') ;
+
+tdiv: (COLON|DOT) ;
+
+dtdiv: ('T'|'t'|'at'|'AT'|','|';') ;
+
+tddiv: ('on'|'ON'|','|';') ;
+
+tzExpr
+    : ntz
+    | ltz
+    ;
+
+ntz: (PLUS|DASH) DIGIT DIGIT COLON? DIGIT DIGIT ;
+
+ltz: UPPERCASE_STRING ;
 
 // ----------------------------------
 
+SPACE: [ \t\r\n] ;
+
 COMMENT: '#' ~[\r\n]* -> skip ;
 
-SPACE: [ \t\r\n] -> skip ;
+THE: ('the'|'The') SPACE* -> skip ;
 
-THE: 'the' -> skip ;
+DASH: '-' ;
 
-DIV: ('/'|','|'.') ;
+PLUS: '+' ;
 
-MONTH: (JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC) ;
+SLASH: '/' ;
+
+DOT: '.' ;
+
+COLON: ':' ;
+
+MONTH_NAME: (JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC) ;
 
 JAN : 'jan'
+    | 'Jan'
+    | 'JAN'
+    | 'January'
     | 'january'
     ;
 
 FEB : 'feb'
+    | 'Feb'
+    | 'FEB'
+    | 'February'
     | 'february'
     ;
 
 MAR : 'mar'
+    | 'Mar'
+    | 'MAR'
+    | 'March'
     | 'march'
     ;
 
 APR : 'apr'
+    | 'Apr'
+    | 'APR'
+    | 'April'
     | 'april'
     ;
 
 MAY : 'may'
+    | 'May'
+    | 'MAY'
     ;
 
 JUN : 'jun'
+    | 'Jun'
+    | 'JUN'
+    | 'June'
     | 'june'
     ;
 
 JUL : 'jul'
+    | 'Jul'
+    | 'JUL'
+    | 'July'
     | 'july'
     ;
 
 AUG : 'aug'
+    | 'Aug'
+    | 'AUG'
+    | 'August'
     | 'august'
     ;
 
 SEP : 'sep'
+    | 'Sep'
+    | 'SEP'
     | 'sept'
+    | 'Sept'
+    | 'SEPT'
+    | 'September'
     | 'september'
     ;
 
 OCT : 'oct'
+    | 'Oct'
+    | 'OCT'
+    | 'October'
     | 'october'
     ;
 
 NOV : 'nov'
+    | 'Nov'
+    | 'NOV'
+    | 'November'
     | 'november'
     ;
 
 DEC : 'dec'
+    | 'Dec'
+    | 'DEC'
+    | 'December'
     | 'december'
     ;
 
 WEEKDAY: (SUN|MON|TUE|WED|THU|FRI|SAT) ;
 
 SUN : 'sun'
+    | 'Sun'
+    | 'SUN'
     | 'suns'
+    | 'Suns'
+    | 'SUNS'
     | 'sunday'
+    | 'Sunday'
     | 'sundays'
+    | 'Sundays'
     ;
 
 MON : 'mon'
+    | 'Mon'
+    | 'MON'
     | 'mons'
+    | 'Mons'
+    | 'MONS'
     | 'monday'
+    | 'Monday'
     | 'mondays'
+    | 'Mondays'
     ;
 
 TUE : 'tue'
+    | 'Tue'
+    | 'TUE'
     | 'tues'
+    | 'Tues'
+    | 'TUES'
     | 'tuesday'
+    | 'Tuesday'
     | 'tuesdays'
+    | 'Tuesdays'
     ;
 
 WED : 'wed'
+    | 'Wed'
+    | 'WED'
     | 'weds'
+    | 'Weds'
+    | 'WEDS'
     | 'wednesday'
+    | 'Wednesday'
     | 'wednesdays'
+    | 'Wednesdays'
     ;
 
 THU : 'thu'
+    | 'Thu'
+    | 'THU'
     | 'thur'
+    | 'Thur'
+    | 'THUR'
     | 'thurs'
+    | 'Thurs'
+    | 'THURS'
     | 'thursday'
+    | 'Thursday'
     | 'thursdays'
+    | 'Thursdays'
     ;
 
 FRI : 'fri'
+    | 'Fri'
+    | 'FRI'
     | 'fris'
+    | 'Fris'
+    | 'FRIS'
     | 'friday'
+    | 'Friday'
     | 'fridays'
+    | 'Fridays'
     ;
 
 SAT : 'sat'
+    | 'Sat'
+    | 'SAT'
     | 'sats'
+    | 'Sats'
+    | 'SATS'
     | 'saturday'
+    | 'Saturday'
     | 'saturdays'
+    | 'Saturdays'
     ;
 
 WEEK
     : 'week'
+    | 'Week'
     | 'weeks'
+    | 'Weeks'
+    | 'wks'
     ;
 
 DAY
     : 'day'
+    | 'Day'
     | 'days'
+    | 'Days'
+    ;
+
+HOUR
+    : 'hour'
+    | 'Hour'
+    | 'hours'
+    | 'Hours'
+    | 'hrs'
+    ;
+
+MINUTE
+    : 'min'
+    | 'Min'
+    | 'MIN'
+    | 'mins'
+    | 'Mins'
+    | 'MINS'
+    | 'minute'
+    | 'Minute'
+    | 'minutes'
+    | 'Minutes'
+    ;
+
+SECOND
+    : 'sec'
+    | 'Sec'
+    | 'SEC'
+    | 'secs'
+    | 'Secs'
+    | 'SECS'
+    | 'second'
+    | 'Second'
+    | 'seconds'
+    | 'Seconds'
+    ;
+
+MONTH
+    : 'month'
+    | 'Month'
+    | 'months'
+    | 'Months'
+    ;
+
+YEAR
+    : 'year'
+    | 'Year'
+    | 'years'
+    | 'Years'
+    | 'yrs'
     ;
 
 SPECIAL_DATE
     : TODAY
+    | YESTERDAY
+    | TOMORROW
     | NEW_YEARS_EVE
     | NEW_YEARS_DAY
     | MARTIN_LUTHER_KING_DAY
@@ -217,73 +469,120 @@ SPECIAL_DATE
     | LABOR_DAY
     | COLUMBUS_DAY
     | VETERANS_DAY
+    | HALLOWEEN
     | THANKSGIVING_DAY
     | CHRISTMAS_EVE
     | CHRISTMAS
     ;
 
+SPECIAL_TIME
+    : NOON
+    | MIDNIGHT
+    ;
+
+NOON
+    : ('noon'|'Noon'|'midday'|'Midday')
+    ;
+
+MIDNIGHT
+    : ('midnight'|'Midnight')
+    ;
+
 // today
 TODAY
-    : 'today'
+    : ('today'|'Today'|'now'|'Now')
+    ;
+
+// yeste
+YESTERDAY
+    : ('yesterday'|'Yesterday')
+    ;
+
+// tomor
+TOMORROW
+    : ('tomorrow'|'Tomorrow')
     ;
 
 // easte
 EASTER
-    : 'easter'
-    | 'easter sunday'
+    : 'easter' SUN?
+    | 'Easter' SUN?
     ;
 
 // newye
 NEW_YEARS_DAY
     : 'new years'
+    | 'New Years'
     | 'new years day'
+    | 'New Years Day'
     | 'new year\'s'
+    | 'New Year\'s'
     | 'new year\'s day'
+    | 'New year\'s Day'
     ;
 
 // newyeeve
 NEW_YEARS_EVE
     : 'nye'
+    | 'NYE'
     | 'new years eve'
+    | 'New Years Eve'
     | 'new year\'s eve'
+    | 'New Year\'s Eve'
     ;
 
 // chris
 CHRISTMAS
     : 'christmas'
+    | 'Christmas'
     | 'christmas day'
+    | 'Christmas Day'
     | 'xmas'
+    | 'Xmas'
     | 'xmas day'
+    | 'Xmas Day'
     ;
 
 // chriseve
 CHRISTMAS_EVE
     : 'christmas eve'
+    | 'Christmas Eve'
     | 'xmas eve'
+    | 'Xmas Eve'
     ;
 
 // mlk
 MARTIN_LUTHER_KING_DAY
     : 'martin luther king day'
+    | 'Martin Luther King Day'
     | 'mlk day'
+    | 'MLK Day'
+    | 'MLK day'
     | 'mlk'
+    | 'MLK'
     ;
 
 // memor
 MEMORIAL_DAY
     : 'memorial'
+    | 'Memorial'
     | 'memorial day'
+    | 'Memorial Day'
     ;
 
 // indep
 INDEPENDENCE_DAY
     : 'independence day'
+    | 'Independence day'
+    | 'Independence Day'
     ;
 
 // labor
 LABOR_DAY
     : 'labor'
+    | 'Labor'
     | 'labor day'
+    | 'Labor Day'
     ;
 
 // presi
@@ -294,6 +593,12 @@ PRESIDENTS_DAY
     | 'presidents'
     | 'president\'s'
     | 'presidents\''
+    | 'Presidents\' Day'
+    | 'President\'s Day'
+    | 'Presidents Day'
+    | 'Presidents'
+    | 'President\'s'
+    | 'Presidents\''
     ;
 
 // colum
@@ -302,6 +607,10 @@ COLUMBUS_DAY
     | 'columbus day'
     | 'indiginous peoples day'
     | 'indiginous peoples\' day'
+    | 'Columbus'
+    | 'Columbus Day'
+    | 'Indiginous Peoples Day'
+    | 'Indiginous Peoples\' Day'
     ;
 
 // veter
@@ -309,20 +618,62 @@ VETERANS_DAY
     : 'veterans'
     | 'veterans day'
     | 'veterans\' day'
+    | 'Veterans'
+    | 'Veterans Day'
+    | 'Veterans\' Day'
+    ;
+
+// hallo
+HALLOWEEN
+    : 'halloween'
+    | 'Halloween'
     ;
 
 // thank
 THANKSGIVING_DAY
     : 'thanksgiving'
     | 'thanksgiving day'
+    | 'Thanksgiving'
+    | 'Thanksgiving Day'
     ;
 
-FIRST: 'first' ;
+FIRST: ('first'|'First') ;
+
+LAST: ('last'|'Last'|'this past') ;
+
+THIS: ('this'|'This'|'this coming') ;
+
+NEXT: ('next'|'Next') ;
 
-LAST: 'last' ;
+AGO_FROM_NOW: (AGO|FROM_NOW) ;
 
-BEFORE: 'before' ;
+AGO: ('ago'|'Ago'|'back'|'Back') ;
 
-AFTER: ('after'|'from') ;
+FROM_NOW: ('from now'|'From Now') ;
+
+BEFORE: ('to'|'To'|'before'|'Before'|'til'|'until'|'Until') ;
+
+AFTER: ('after'|'After'|'from'|'From'|'past'|'Past') ;
+
+DELTA_TIME_FRACTION: ('quarter'|'Quarter'|'half'|'Half') ;
 
 DIGIT: ('0'..'9') ;
+
+IDES: ('ides'|'Ides') ;
+
+NONES: ('nones'|'Nones') ;
+
+KALENDS: ('kalends'|'Kalends') ;
+
+WORKDAY
+    : 'workday'
+    | 'workdays'
+    | 'work days'
+    | 'working days'
+    | 'Workday'
+    | 'Workdays'
+    | 'Work Days'
+    | 'Working Days'
+    ;
+
+UPPERCASE_STRING: [A-Z]+ ;
index fb55eef6b973fd9ad67f8a19c74c2d184172272b..fc5408fb8cd9a252a2f21165aea960d0c3fe0a3e 100755 (executable)
 #!/usr/bin/env python3
 
-import antlr4  # type: ignore
 import datetime
-import dateutil.easter
+import functools
 import holidays  # type: ignore
+import logging
 import re
 import sys
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional
 
+import antlr4  # type: ignore
+import dateutil.easter
+import dateutil.tz
+import pytz
+
+import bootstrap
+import datetime_utils
+import decorator_utils
 from dateparse.dateparse_utilsLexer import dateparse_utilsLexer  # type: ignore
 from dateparse.dateparse_utilsListener import dateparse_utilsListener  # type: ignore
 from dateparse.dateparse_utilsParser import dateparse_utilsParser  # type: ignore
+import simple_acl as acl
+
+
+logger = logging.getLogger(__name__)
+
+
+def debug_parse(enter_or_exit_f: Callable[[Any, Any], None]):
+    @functools.wraps(enter_or_exit_f)
+    def debug_parse_wrapper(*args, **kwargs):
+        slf = args[0]
+        ctx = args[1]
+        depth = ctx.depth()
+        logger.debug(
+            '  ' * (depth-1) +
+            f'Entering {enter_or_exit_f.__name__} ({ctx.invokingState} / {ctx.exception})'
+        )
+        for c in ctx.getChildren():
+            logger.debug(
+                '  ' * (depth-1) +
+                f'{c} {type(c)}'
+            )
+        retval = enter_or_exit_f(*args, **kwargs)
+        return retval
+    return debug_parse_wrapper
 
 
 class ParseException(Exception):
+    """An exception thrown during parsing because of unrecognized input."""
     def __init__(self, message: str) -> None:
+        logger.error(message)
         self.message = message
 
 
+class RaisingErrorListener(antlr4.DiagnosticErrorListener):
+    """An error listener that raises ParseExceptions."""
+    def syntaxError(
+            self, recognizer, offendingSymbol, line, column, msg, e
+    ):
+        logger.error(msg)
+        raise ParseException(msg)
+
+    def reportAmbiguity(
+            self, recognizer, dfa, startIndex, stopIndex, exact,
+            ambigAlts, configs
+    ):
+        pass
+
+    def reportAttemptingFullContext(
+            self, recognizer, dfa, startIndex, stopIndex, conflictingAlts,
+            configs
+    ):
+        pass
+
+    def reportContextSensitivity(
+            self, recognizer, dfa, startIndex, stopIndex, prediction, configs
+    ):
+        pass
+
+
+@decorator_utils.decorate_matching_methods_with(debug_parse,
+                                                acl=acl.StringWildcardBasedACL(
+                                                    allowed_patterns=[
+                                                        'enter*',
+                                                        'exit*',
+                                                    ],
+                                                    denied_patterns=None,
+                                                    order_to_check_allow_deny=acl.ACL_ORDER_DENY_ALLOW,
+                                                    default_answer=False
+                                                ))
 class DateParser(dateparse_utilsListener):
     PARSE_TYPE_SINGLE_DATE_EXPR = 1
     PARSE_TYPE_BASE_AND_OFFSET_EXPR = 2
-    CONSTANT_DAYS = 7
-    CONSTANT_WEEKS = 8
-    CONSTANT_MONTHS = 9
-    CONSTANT_YEARS = 10
+    PARSE_TYPE_SINGLE_TIME_EXPR = 3
+    PARSE_TYPE_BASE_AND_OFFSET_TIME_EXPR = 4
 
-    def __init__(self):
+    def __init__(
+            self,
+            *,
+            override_now_for_test_purposes = None
+    ) -> None:
+        """C'tor.  Passing a value to override_now_for_test_purposes can be
+        used to force this instance to use a custom date/time for its
+        idea of "now" so that the code can be more easily unittested.
+        Leave as None for real use cases.
+        """
         self.month_name_to_number = {
-            "jan": 1,
-            "feb": 2,
-            "mar": 3,
-            "apr": 4,
-            "may": 5,
-            "jun": 6,
-            "jul": 7,
-            "aug": 8,
-            "sep": 9,
-            "oct": 10,
-            "nov": 11,
-            "dec": 12,
+            'jan': 1,
+            'feb': 2,
+            'mar': 3,
+            'apr': 4,
+            'may': 5,
+            'jun': 6,
+            'jul': 7,
+            'aug': 8,
+            'sep': 9,
+            'oct': 10,
+            'nov': 11,
+            'dec': 12,
         }
+
+        # Used only for ides/nones.  Month length on a non-leap year.
+        self.typical_days_per_month = {
+            1: 31,
+            2: 28,
+            3: 31,
+            4: 30,
+            5: 31,
+            6: 30,
+            7: 31,
+            8: 31,
+            9: 30,
+            10: 31,
+            11: 30,
+            12: 31
+        }
+
+        # N.B. day number is also synched with datetime_utils.TimeUnit values
+        # which allows expressions like "3 wednesdays from now" to work.
         self.day_name_to_number = {
-            "mon": 0,
-            "tue": 1,
-            "wed": 2,
-            "thu": 3,
-            "fri": 4,
-            "sat": 5,
-            "sun": 6,
+            'mon': 0,
+            'tue': 1,
+            'wed': 2,
+            'thu': 3,
+            'fri': 4,
+            'sat': 5,
+            'sun': 6,
+        }
+
+        # These TimeUnits are defined in datetime_utils and are used as params
+        # to datetime_utils.n_timeunits_from_base.
+        self.time_delta_unit_to_constant = {
+            'hou': datetime_utils.TimeUnit.HOURS,
+            'min': datetime_utils.TimeUnit.MINUTES,
+            'sec': datetime_utils.TimeUnit.SECONDS,
         }
         self.delta_unit_to_constant = {
-            "day": DateParser.CONSTANT_DAYS,
-            "wee": DateParser.CONSTANT_WEEKS,
+            'day': datetime_utils.TimeUnit.DAYS,
+            'wor': datetime_utils.TimeUnit.WORKDAYS,
+            'wee': datetime_utils.TimeUnit.WEEKS,
+            'mon': datetime_utils.TimeUnit.MONTHS,
+            'yea': datetime_utils.TimeUnit.YEARS,
         }
-        self.date: Optional[datetime.date] = None
+        self.override_now_for_test_purposes = override_now_for_test_purposes
+        self._reset()
+
+    def parse(self, date_string: str) -> Optional[datetime.datetime]:
+        """Parse a date/time expression and return a timezone agnostic
+        datetime on success.  Also sets self.datetime, self.date and
+        self.time which can each be accessed other methods on the
+        class: get_datetime(), get_date() and get_time().  Raises a
+        ParseException with a helpful(?) message on parse error or
+        confusion.
+
+        To get an idea of what expressions can be parsed, check out
+        the unittest and the grammar.
+
+        Usage:
 
-    def parse_date_string(self, date_string: str) -> Optional[datetime.date]:
+        txt = '3 weeks before last tues at 9:15am'
+        dp = DateParser()
+        dt1 = dp.parse(txt)
+        dt2 = dp.get_datetime(tz=pytz.timezone('US/Pacific'))
+
+        # dt1 and dt2 will be identical other than the fact that
+        # the latter's tzinfo will be set to PST/PDT.
+
+        This is the main entrypoint to this class for caller code.
+        """
+        self._reset()
+        listener = RaisingErrorListener()
         input_stream = antlr4.InputStream(date_string)
         lexer = dateparse_utilsLexer(input_stream)
+        lexer.removeErrorListeners()
+        lexer.addErrorListener(listener)
         stream = antlr4.CommonTokenStream(lexer)
         parser = dateparse_utilsParser(stream)
+        parser.removeErrorListeners()
+        parser.addErrorListener(listener)
         tree = parser.parse()
         walker = antlr4.ParseTreeWalker()
         walker.walk(self, tree)
-        return self.get_date()
+        return self.datetime
 
     def get_date(self) -> Optional[datetime.date]:
+        """Return the date part or None."""
         return self.date
 
-    def enterDateExpr(self, ctx: dateparse_utilsParser.DateExprContext):
-        self.date = None
+    def get_time(self) -> Optional[datetime.time]:
+        """Return the time part or None."""
+        return self.time
+
+    def get_datetime(self, *, tz=None) -> Optional[datetime.datetime]:
+        """Return as a datetime.  Parsed date expressions without any time
+        part return hours = minutes = seconds = microseconds = 0 (i.e. at
+        midnight that day).  Parsed time expressions without any date part
+        default to date = today.
+
+        The optional tz param allows the caller to request the datetime be
+        timezone aware and sets the tzinfo to the indicated zone.  Defaults
+        to timezone naive (i.e. tzinfo = None).
+        """
+        dt = self.datetime
+        if tz is not None:
+            dt = dt.replace(tzinfo=None).astimezone(tz=tz)
+        return dt
+
+    # -- helpers --
+
+    def _reset(self):
+        """Reset at init and between parses."""
+        if self.override_now_for_test_purposes is None:
+            self.now_datetime = datetime.datetime.now()
+            self.today = datetime.date.today()
+        else:
+            self.now_datetime = self.override_now_for_test_purposes
+            self.today = datetime_utils.datetime_to_date(
+                self.override_now_for_test_purposes
+            )
+        self.date: Optional[datetime.date] = None
+        self.time: Optional[datetime.time] = None
+        self.datetime: Optional[datetime.datetime] = None
         self.context: Dict[str, Any] = {}
-        if ctx.singleDateExpr() is not None:
-            self.main_type = DateParser.PARSE_TYPE_SINGLE_DATE_EXPR
-        elif ctx.baseAndOffsetDateExpr() is not None:
-            self.main_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_EXPR
+        self.timedelta = datetime.timedelta(seconds=0)
 
     @staticmethod
-    def normalize_special_day_name(name: str) -> str:
+    def _normalize_special_day_name(name: str) -> str:
+        """String normalization / canonicalization for date expressions."""
         name = name.lower()
-        name = name.replace("'", "")
-        name = name.replace("xmas", "christmas")
-        name = name.replace("mlk", "martin luther king")
-        name = name.replace(" ", "")
-        eve = "eve" if name[-3:] == "eve" else ""
+        name = name.replace("'", '')
+        name = name.replace('xmas', 'christmas')
+        name = name.replace('mlk', 'martin luther king')
+        name = name.replace(' ', '')
+        eve = 'eve' if name[-3:] == 'eve' else ''
         name = name[:5] + eve
-        name = name.replace("washi", "presi")
+        name = name.replace('washi', 'presi')
         return name
 
-    def parse_special(self, name: str) -> Optional[datetime.date]:
-        today = datetime.date.today()
-        year = self.context.get("year", today.year)
-        name = DateParser.normalize_special_day_name(self.context["special"])
-        if name == "today":
+    def _figure_out_date_unit(self, orig: str) -> int:
+        """Figure out what unit a date expression piece is talking about."""
+        if 'month' in orig:
+            return datetime_utils.TimeUnit.MONTHS
+        txt = orig.lower()[:3]
+        if txt in self.day_name_to_number:
+            return(self.day_name_to_number[txt])
+        elif txt in self.delta_unit_to_constant:
+            return(self.delta_unit_to_constant[txt])
+        raise ParseException(f'Invalid date unit: {orig}')
+
+    def _figure_out_time_unit(self, orig: str) -> int:
+        """Figure out what unit a time expression piece is talking about."""
+        txt = orig.lower()[:3]
+        if txt in self.time_delta_unit_to_constant:
+            return(self.time_delta_unit_to_constant[txt])
+        raise ParseException(f'Invalid time unit: {orig}')
+
+    def _parse_special_date(self, name: str) -> Optional[datetime.date]:
+        """Parse what we think is a special date name and return its datetime
+        (or None if it can't be parsed).
+        """
+        today = self.today
+        year = self.context.get('year', today.year)
+        name = DateParser._normalize_special_day_name(self.context['special'])
+
+        # Yesterday, today, tomorrow -- ignore any next/last
+        if name == 'today' or name == 'now':
             return today
-        if name == "easte":
+        if name == 'yeste':
+            return today + datetime.timedelta(days=-1)
+        if name == 'tomor':
+            return today + datetime.timedelta(days=+1)
+
+        next_last = self.context.get('special_next_last', '')
+        if next_last == 'next':
+            year += 1
+        elif next_last == 'last':
+            year -= 1
+
+        # Holiday names
+        if name == 'easte':
             return dateutil.easter.easter(year=year)
+        elif name == 'hallo':
+            return datetime.date(year=year, month=10, day=31)
+
         for holiday_date, holiday_name in sorted(
             holidays.US(years=year).items()
         ):
-            if "Observed" not in holiday_name:
-                holiday_name = DateParser.normalize_special_day_name(
+            if 'Observed' not in holiday_name:
+                holiday_name = DateParser._normalize_special_day_name(
                     holiday_name
                 )
                 if name == holiday_name:
                     return holiday_date
-        if name == "chriseve":
+        if name == 'chriseve':
             return datetime.date(year=year, month=12, day=24)
-        elif name == "newyeeve":
+        elif name == 'newyeeve':
             return datetime.date(year=year, month=12, day=31)
         return None
 
-    def parse_normal(self) -> datetime.date:
-        if "month" not in self.context:
-            raise ParseException("Missing month")
-        if "day" not in self.context:
-            raise ParseException("Missing day")
-        if "year" not in self.context:
-            today = datetime.date.today()
-            self.context["year"] = today.year
+    def _resolve_ides_nones(self, day: str, month_number: int) -> int:
+        """Handle date expressions like "the ides of March" which require
+        both the "ides" and the month since the definition of the "ides"
+        changes based on the length of the month.
+        """
+        assert 'ide' in day or 'non' in day
+        assert month_number in self.typical_days_per_month
+        typical_days_per_month = self.typical_days_per_month[month_number]
+
+        # "full" month
+        if typical_days_per_month == 31:
+            if self.context['day'] == 'ide':
+                return 15
+            else:
+                return 7
+
+        # "hollow" month
+        else:
+            if self.context['day'] == 'ide':
+                return 13
+            else:
+                return 5
+
+    def _parse_normal_date(self) -> datetime.date:
+        if 'dow' in self.context:
+            d = self.today
+            while d.weekday() != self.context['dow']:
+                d += datetime.timedelta(days=1)
+            return d
+
+        if 'month' not in self.context:
+            raise ParseException('Missing month')
+        if 'day' not in self.context:
+            raise ParseException('Missing day')
+        if 'year' not in self.context:
+            self.context['year'] = self.today.year
+
+        # Handling "ides" and "nones" requires both the day and month.
+        if (
+                self.context['day'] == 'ide' or
+                self.context['day'] == 'non'
+        ):
+            self.context['day'] = self._resolve_ides_nones(
+                self.context['day'], self.context['month']
+            )
+
         return datetime.date(
-            year=int(self.context["year"]),
-            month=int(self.context["month"]),
-            day=int(self.context["day"]),
+            year=self.context['year'],
+            month=self.context['month'],
+            day=self.context['day'],
+        )
+
+    def _parse_tz(self, txt: str) -> Any:
+        if txt == 'Z':
+            txt = 'UTC'
+
+        # Try pytz
+        try:
+            tz = pytz.timezone(txt)
+            if tz is not None:
+                return tz
+        except:
+            pass
+
+        # Try dateutil
+        try:
+            tz = dateutil.tz.gettz(txt)
+            if tz is not None:
+                return tz
+        except:
+            pass
+
+        # Try constructing an offset in seconds
+        try:
+            sign = txt[0]
+            if sign == '-' or sign == '+':
+                sign = +1 if sign == '+' else -1
+                hour = int(txt[1:3])
+                minute = int(txt[-2:])
+                offset = sign * (hour * 60 * 60) + sign * (minute * 60)
+                tzoffset = dateutil.tz.tzoffset(txt, offset)
+                return tzoffset
+        except:
+            pass
+        return None
+
+    def _get_int(self, txt: str) -> int:
+        while not txt[0].isdigit() and txt[0] != '-' and txt[0] != '+':
+            txt = txt[1:]
+        while not txt[-1].isdigit():
+            txt = txt[:-1]
+        return int(txt)
+
+    # -- overridden methods invoked by parse walk --
+
+    def visitErrorNode(self, node: antlr4.ErrorNode) -> None:
+        pass
+
+    def visitTerminal(self, node: antlr4.TerminalNode) -> None:
+        pass
+
+    def exitParse(self, ctx: dateparse_utilsParser.ParseContext) -> None:
+        """Populate self.datetime."""
+        if self.date is None:
+            self.date = self.today
+        year = self.date.year
+        month = self.date.month
+        day = self.date.day
+
+        if self.time is None:
+            self.time = datetime.time(0, 0, 0)
+        hour = self.time.hour
+        minute = self.time.minute
+        second = self.time.second
+        micros = self.time.microsecond
+
+        self.datetime = datetime.datetime(
+            year, month, day, hour, minute, second, micros,
+            tzinfo=self.time.tzinfo
         )
 
+        # Apply resudual adjustments to times here when we have a
+        # datetime.
+        self.datetime = self.datetime + self.timedelta
+        self.time = datetime.time(
+            self.datetime.hour,
+            self.datetime.minute,
+            self.datetime.second,
+            self.datetime.microsecond,
+            self.datetime.tzinfo
+        )
+
+    def enterDateExpr(self, ctx: dateparse_utilsParser.DateExprContext):
+        self.date = None
+        if ctx.singleDateExpr() is not None:
+            self.main_type = DateParser.PARSE_TYPE_SINGLE_DATE_EXPR
+        elif ctx.baseAndOffsetDateExpr() is not None:
+            self.main_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_EXPR
+
+    def enterTimeExpr(self, ctx: dateparse_utilsParser.TimeExprContext):
+        self.time = None
+        if ctx.singleTimeExpr() is not None:
+            self.time_type = DateParser.PARSE_TYPE_SINGLE_TIME_EXPR
+        elif ctx.baseAndOffsetTimeExpr() is not None:
+            self.time_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_TIME_EXPR
+
     def exitDateExpr(self, ctx: dateparse_utilsParser.DateExprContext) -> None:
         """When we leave the date expression, populate self.date."""
-        if "special" in self.context:
-            self.date = self.parse_special(self.context["special"])
+        if 'special' in self.context:
+            self.date = self._parse_special_date(self.context['special'])
         else:
-            self.date = self.parse_normal()
+            self.date = self._parse_normal_date()
         assert self.date is not None
 
         # For a single date, just return the date we pulled out.
@@ -140,74 +477,198 @@ class DateParser(dateparse_utilsListener):
 
         # Otherwise treat self.date as a base date that we're modifying
         # with an offset.
-        if not "delta_int" in self.context:
-            raise ParseException("Missing delta_int?!")
-        count = self.context["delta_int"]
+        if 'delta_int' not in self.context:
+            raise ParseException('Missing delta_int?!')
+        count = self.context['delta_int']
         if count == 0:
             return
 
         # Adjust count's sign based on the presence of 'before' or 'after'.
-        if "delta_before_after" in self.context:
-            before_after = self.context["delta_before_after"].lower()
-            if before_after == "before":
+        if 'delta_before_after' in self.context:
+            before_after = self.context['delta_before_after'].lower()
+            if (
+                    before_after == 'before' or
+                    before_after == 'until' or
+                    before_after == 'til' or
+                    before_after == 'to'
+            ):
                 count = -count
 
         # What are we counting units of?
-        if "delta_unit" not in self.context:
-            raise ParseException("Missing delta_unit?!")
-        unit = self.context["delta_unit"]
-        if unit == DateParser.CONSTANT_DAYS:
-            timedelta = datetime.timedelta(days=count)
-            self.date = self.date + timedelta
-        elif unit == DateParser.CONSTANT_WEEKS:
-            timedelta = datetime.timedelta(weeks=count)
-            self.date = self.date + timedelta
+        if 'delta_unit' not in self.context:
+            raise ParseException('Missing delta_unit?!')
+        unit = self.context['delta_unit']
+        dt = datetime_utils.n_timeunits_from_base(
+            count,
+            unit,
+            datetime_utils.date_to_datetime(self.date)
+        )
+        self.date = datetime_utils.datetime_to_date(dt)
+
+    def exitTimeExpr(self, ctx: dateparse_utilsParser.TimeExprContext) -> None:
+        # Simple time?
+        self.time = datetime.time(
+            self.context['hour'],
+            self.context['minute'],
+            self.context['seconds'],
+            self.context['micros'],
+            tzinfo=self.context.get('tz', None),
+        )
+        if self.time_type == DateParser.PARSE_TYPE_SINGLE_TIME_EXPR:
+            return
+
+        # If we get here there (should be) a relative adjustment to
+        # the time.
+        if 'nth' in self.context:
+            count = self.context['nth']
+        elif 'time_delta_int' in self.context:
+            count = self.context['time_delta_int']
         else:
-            direction = 1 if count > 0 else -1
-            count = abs(count)
-            timedelta = datetime.timedelta(days=direction)
+            raise ParseException('Missing delta in relative time.')
+        if count == 0:
+            return
 
-            while True:
-                dow = self.date.weekday()
-                if dow == unit:
-                    count -= 1
-                    if count == 0:
-                        return
-                self.date = self.date + timedelta
+        # Adjust count's sign based on the presence of 'before' or 'after'.
+        if 'time_delta_before_after' in self.context:
+            before_after = self.context['time_delta_before_after'].lower()
+            if (
+                    before_after == 'before' or
+                    before_after == 'until' or
+                    before_after == 'til' or
+                    before_after == 'to'
+            ):
+                count = -count
+
+        # What are we counting units of... assume minutes.
+        if 'time_delta_unit' not in self.context:
+            self.timedelta += datetime.timedelta(minutes=count)
+        else:
+            unit = self.context['time_delta_unit']
+            if unit == datetime_utils.TimeUnit.SECONDS:
+                self.timedelta += datetime.timedelta(seconds=count)
+            elif unit == datetime_utils.TimeUnit.MINUTES:
+                self.timedelta = datetime.timedelta(minutes=count)
+            elif unit == datetime_utils.TimeUnit.HOURS:
+                self.timedelta = datetime.timedelta(hours=count)
+            else:
+                raise ParseException()
 
-    def enterDeltaInt(self, ctx: dateparse_utilsParser.DeltaIntContext) -> None:
+    def exitDeltaPlusMinusExpr(
+        self, ctx: dateparse_utilsParser.DeltaPlusMinusExprContext
+    ) -> None:
         try:
-            i = int(ctx.getText())
+            n = ctx.nth()
+            if n is None:
+                raise ParseException(
+                    f'Bad N in Delta +/- Expr: {ctx.getText()}'
+                )
+            n = n.getText()
+            n = self._get_int(n)
+            unit = self._figure_out_date_unit(
+                ctx.deltaUnit().getText().lower()
+            )
         except:
-            raise ParseException(f"Bad delta int: {ctx.getText()}")
+            raise ParseException(f'Invalid Delta +/-: {ctx.getText()}')
         else:
-            self.context["delta_int"] = i
+            self.context['delta_int'] = n
+            self.context['delta_unit'] = unit
 
-    def enterDeltaUnit(
+    def exitNextLastUnit(
         self, ctx: dateparse_utilsParser.DeltaUnitContext
     ) -> None:
         try:
-            txt = ctx.getText().lower()[:3]
-            if txt in self.day_name_to_number:
-                txt = self.day_name_to_number[txt]
-            elif txt in self.delta_unit_to_constant:
-                txt = self.delta_unit_to_constant[txt]
+            unit = self._figure_out_date_unit(ctx.getText().lower())
+        except:
+            raise ParseException(f'Bad delta unit: {ctx.getText()}')
+        else:
+            self.context['delta_unit'] = unit
+
+    def exitDeltaNextLast(
+            self, ctx: dateparse_utilsParser.DeltaNextLastContext
+    ) -> None:
+        try:
+            txt = ctx.getText().lower()
+        except:
+            raise ParseException(f'Bad next/last: {ctx.getText()}')
+        if (
+                'month' in self.context or
+                'day' in self.context or
+                'year' in self.context
+        ):
+            raise ParseException(
+                'Next/last expression expected to be relative to today.'
+            )
+        if txt[:4] == 'next':
+            self.context['delta_int'] = +1
+            self.context['day'] = self.now_datetime.day
+            self.context['month'] = self.now_datetime.month
+            self.context['year'] = self.now_datetime.year
+        elif txt[:4] == 'last':
+            self.context['delta_int'] = -1
+            self.context['day'] = self.now_datetime.day
+            self.context['month'] = self.now_datetime.month
+            self.context['year'] = self.now_datetime.year
+        else:
+            raise ParseException(f'Bad next/last: {ctx.getText()}')
+
+    def exitCountUnitsBeforeAfterTimeExpr(
+        self, ctx: dateparse_utilsParser.CountUnitsBeforeAfterTimeExprContext
+    ) -> None:
+        if 'nth' not in self.context:
+            raise ParseException(
+                f'Bad count expression: {ctx.getText()}'
+            )
+        try:
+            unit = self._figure_out_time_unit(
+                ctx.deltaTimeUnit().getText().lower()
+            )
+            self.context['time_delta_unit'] = unit
+        except:
+            raise ParseException(f'Bad delta unit: {ctx.getText()}')
+        if 'time_delta_before_after' not in self.context:
+            raise ParseException(
+                f'Bad Before/After: {ctx.getText()}'
+            )
+
+    def exitDeltaTimeFraction(
+            self, ctx: dateparse_utilsParser.DeltaTimeFractionContext
+    ) -> None:
+        try:
+            txt = ctx.getText().lower()[:4]
+            if txt == 'quar':
+                self.context['time_delta_int'] = 15
+                self.context[
+                    'time_delta_unit'
+                ] = datetime_utils.TimeUnit.MINUTES
+            elif txt == 'half':
+                self.context['time_delta_int'] = 30
+                self.context[
+                    'time_delta_unit'
+                ] = datetime_utils.TimeUnit.MINUTES
             else:
-                raise ParseException(f"Bad delta unit: {ctx.getText()}")
+                raise ParseException(f'Bad time fraction {ctx.getText()}')
         except:
-            raise ParseException(f"Bad delta unit: {ctx.getText()}")
+            raise ParseException(f'Bad time fraction {ctx.getText()}')
+
+    def exitDeltaBeforeAfter(
+        self, ctx: dateparse_utilsParser.DeltaBeforeAfterContext
+    ) -> None:
+        try:
+            txt = ctx.getText().lower()
+        except:
+            raise ParseException(f'Bad delta before|after: {ctx.getText()}')
         else:
-            self.context["delta_unit"] = txt
+            self.context['delta_before_after'] = txt
 
-    def enterDeltaBeforeAfter(
+    def exitDeltaTimeBeforeAfter(
         self, ctx: dateparse_utilsParser.DeltaBeforeAfterContext
     ) -> None:
         try:
             txt = ctx.getText().lower()
         except:
-            raise ParseException(f"Bad delta before|after: {ctx.getText()}")
+            raise ParseException(f'Bad delta before|after: {ctx.getText()}')
         else:
-            self.context["delta_before_after"] = txt
+            self.context['time_delta_before_after'] = txt
 
     def exitNthWeekdayInMonthMaybeYearExpr(
         self, ctx: dateparse_utilsParser.NthWeekdayInMonthMaybeYearExprContext
@@ -220,25 +681,25 @@ class DateParser(dateparse_utilsListener):
         ...into base + offset expressions instead.
         """
         try:
-            if "nth" not in self.context:
-                raise ParseException(f"Missing nth number: {ctx.getText()}")
-            n = self.context["nth"]
+            if 'nth' not in self.context:
+                raise ParseException(f'Missing nth number: {ctx.getText()}')
+            n = self.context['nth']
             if n < 1 or n > 5:  # months never have more than 5 Foodays
                 if n != -1:
-                    raise ParseException(f"Invalid nth number: {ctx.getText()}")
-            del self.context["nth"]
-            self.context["delta_int"] = n
+                    raise ParseException(f'Invalid nth number: {ctx.getText()}')
+            del self.context['nth']
+            self.context['delta_int'] = n
 
-            year = self.context.get("year", datetime.date.today().year)
-            if "month" not in self.context:
+            year = self.context.get('year', self.today.year)
+            if 'month' not in self.context:
                 raise ParseException(
-                    f"Missing month expression: {ctx.getText()}"
+                    f'Missing month expression: {ctx.getText()}'
                 )
-            month = self.context["month"]
+            month = self.context['month']
 
-            dow = self.context["dow"]
-            del self.context["dow"]
-            self.context["delta_unit"] = dow
+            dow = self.context['dow']
+            del self.context['dow']
+            self.context['delta_unit'] = dow
 
             # For the nth Fooday in Month, start at the 1st of the
             # month and count ahead N Foodays.  For the last Fooday in
@@ -252,21 +713,21 @@ class DateParser(dateparse_utilsListener):
                 tmp_date = datetime.date(year=year, month=month, day=1)
                 tmp_date = tmp_date - datetime.timedelta(days=1)
 
-                self.context["year"] = tmp_date.year
-                self.context["month"] = tmp_date.month
-                self.context["day"] = tmp_date.day
+                self.context['year'] = tmp_date.year
+                self.context['month'] = tmp_date.month
+                self.context['day'] = tmp_date.day
 
                 # The delta adjustment code can handle the case where
                 # the last day of the month is the day we're looking
                 # for already.
             else:
-                self.context["year"] = year
-                self.context["month"] = month
-                self.context["day"] = 1
+                self.context['year'] = year
+                self.context['month'] = month
+                self.context['day'] = 1
             self.main_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_EXPR
         except:
             raise ParseException(
-                f"Invalid nthWeekday expression: {ctx.getText()}"
+                f'Invalid nthWeekday expression: {ctx.getText()}'
             )
 
     def exitFirstLastWeekdayInMonthMaybeYearExpr(
@@ -275,120 +736,330 @@ class DateParser(dateparse_utilsListener):
     ) -> None:
         self.exitNthWeekdayInMonthMaybeYearExpr(ctx)
 
-    def enterNth(self, ctx: dateparse_utilsParser.NthContext) -> None:
+    def exitNth(self, ctx: dateparse_utilsParser.NthContext) -> None:
         try:
-            i = ctx.getText()
-            m = re.match("\d+[a-z][a-z]", i)
-            if m is not None:
-                i = i[:-2]
-            i = int(i)
+            i = self._get_int(ctx.getText())
         except:
-            raise ParseException(f"Bad nth expression: {ctx.getText()}")
+            raise ParseException(f'Bad nth expression: {ctx.getText()}')
         else:
-            self.context["nth"] = i
+            self.context['nth'] = i
 
-    def enterFirstOrLast(
+    def exitFirstOrLast(
         self, ctx: dateparse_utilsParser.FirstOrLastContext
     ) -> None:
         try:
             txt = ctx.getText()
-            if txt == "first":
+            if txt == 'first':
                 txt = 1
-            elif txt == "last":
+            elif txt == 'last':
                 txt = -1
             else:
                 raise ParseException(
-                    f"Bad first|last expression: {ctx.getText()}"
+                    f'Bad first|last expression: {ctx.getText()}'
                 )
         except:
-            raise ParseException(f"Bad first|last expression: {ctx.getText()}")
+            raise ParseException(f'Bad first|last expression: {ctx.getText()}')
         else:
-            self.context["nth"] = txt
+            self.context['nth'] = txt
 
-    def enterDayName(self, ctx: dateparse_utilsParser.DayNameContext) -> None:
+    def exitDayName(self, ctx: dateparse_utilsParser.DayNameContext) -> None:
         try:
             dow = ctx.getText().lower()[:3]
             dow = self.day_name_to_number.get(dow, None)
         except:
-            raise ParseException("Bad day of week")
+            raise ParseException('Bad day of week')
         else:
-            self.context["dow"] = dow
+            self.context['dow'] = dow
 
-    def enterDayOfMonth(
+    def exitDayOfMonth(
         self, ctx: dateparse_utilsParser.DayOfMonthContext
     ) -> None:
         try:
-            day = int(ctx.getText())
+            day = ctx.getText().lower()
+            if day[:3] == 'ide':
+                self.context['day'] = 'ide'
+                return
+            if day[:3] == 'non':
+                self.context['day'] = 'non'
+                return
+            if day[:3] == 'kal':
+                self.context['day'] = 1
+                return
+            day = self._get_int(day)
             if day < 1 or day > 31:
                 raise ParseException(
-                    f"Bad dayOfMonth expression: {ctx.getText()}"
+                    f'Bad dayOfMonth expression: {ctx.getText()}'
                 )
         except:
-            raise ParseException(f"Bad dayOfMonth expression: {ctx.getText()}")
-        self.context["day"] = day
+            raise ParseException(f'Bad dayOfMonth expression: {ctx.getText()}')
+        self.context['day'] = day
 
-    def enterMonthName(
+    def exitMonthName(
         self, ctx: dateparse_utilsParser.MonthNameContext
     ) -> None:
         try:
             month = ctx.getText()
-            month = month.lower()[:3]
+            while month[0] == '/' or month[0] == '-':
+                month = month[1:]
+            month = month[:3].lower()
             month = self.month_name_to_number.get(month, None)
             if month is None:
                 raise ParseException(
-                    f"Bad monthName expression: {ctx.getText()}"
+                    f'Bad monthName expression: {ctx.getText()}'
                 )
         except:
-            raise ParseException(f"Bad monthName expression: {ctx.getText()}")
+            raise ParseException(f'Bad monthName expression: {ctx.getText()}')
         else:
-            self.context["month"] = month
+            self.context['month'] = month
 
-    def enterMonthNumber(
+    def exitMonthNumber(
         self, ctx: dateparse_utilsParser.MonthNumberContext
     ) -> None:
         try:
-            month = int(ctx.getText())
+            month = self._get_int(ctx.getText())
             if month < 1 or month > 12:
                 raise ParseException(
-                    f"Bad monthNumber expression: {ctx.getText()}"
+                    f'Bad monthNumber expression: {ctx.getText()}'
                 )
         except:
-            raise ParseException(f"Bad monthNumber expression: {ctx.getText()}")
+            raise ParseException(
+                f'Bad monthNumber expression: {ctx.getText()}'
+            )
         else:
-            self.context["month"] = month
+            self.context['month'] = month
 
-    def enterYear(self, ctx: dateparse_utilsParser.YearContext) -> None:
+    def exitYear(self, ctx: dateparse_utilsParser.YearContext) -> None:
         try:
-            year = int(ctx.getText())
+            year = self._get_int(ctx.getText())
             if year < 1:
-                raise ParseException(f"Bad year expression: {ctx.getText()}")
+                raise ParseException(f'Bad year expression: {ctx.getText()}')
         except:
-            raise ParseException(f"Bad year expression: {ctx.getText()}")
+            raise ParseException(f'Bad year expression: {ctx.getText()}')
         else:
-            self.context["year"] = year
+            self.context['year'] = year
 
-    def enterSpecialDate(
-        self, ctx: dateparse_utilsParser.SpecialDateContext
+    def exitSpecialDateMaybeYearExpr(
+        self, ctx: dateparse_utilsParser.SpecialDateMaybeYearExprContext
     ) -> None:
         try:
-            txt = ctx.getText().lower()
+            special = ctx.specialDate().getText().lower()
+            self.context['special'] = special
+        except:
+            raise ParseException(
+                f'Bad specialDate expression: {ctx.specialDate().getText()}'
+            )
+        try:
+            mod = ctx.thisNextLast()
+            if mod is not None:
+                if mod.THIS() is not None:
+                    self.context['special_next_last'] = 'this'
+                elif mod.NEXT() is not None:
+                    self.context['special_next_last'] = 'next'
+                elif mod.LAST() is not None:
+                    self.context['special_next_last'] = 'last'
+        except:
+            raise ParseException(
+                f'Bad specialDateNextLast expression: {ctx.getText()}'
+            )
+
+    def exitNFoosFromTodayAgoExpr(
+        self, ctx: dateparse_utilsParser.NFoosFromTodayAgoExprContext
+    ) -> None:
+        d = self.now_datetime
+        try:
+            count = self._get_int(ctx.unsignedInt().getText())
+            unit = ctx.deltaUnit().getText().lower()
+            ago_from_now = ctx.AGO_FROM_NOW().getText()
         except:
-            raise ParseException(f"Bad specialDate expression: {ctx.getText()}")
+            raise ParseException(
+                f'Bad NFoosFromTodayAgoExpr: {ctx.getText()}'
+            )
+
+        if "ago" in ago_from_now or "back" in ago_from_now:
+            count = -count
+
+        unit = self._figure_out_date_unit(unit)
+        d = datetime_utils.n_timeunits_from_base(
+            count,
+            unit,
+            d)
+        self.context['year'] = d.year
+        self.context['month'] = d.month
+        self.context['day'] = d.day
+
+    def exitDeltaRelativeToTodayExpr(
+        self, ctx: dateparse_utilsParser.DeltaRelativeToTodayExprContext
+    ) -> None:
+        d = self.now_datetime
+        try:
+            mod = ctx.thisNextLast()
+            if mod.LAST():
+                count = -1
+            elif mod.THIS():
+                count = +1
+            elif mod.NEXT():
+                count = +2
+            else:
+                raise ParseException(
+                    f'Bad This/Next/Last modifier: {mod}'
+                )
+            unit = ctx.deltaUnit().getText().lower()
+        except:
+            raise ParseException(
+                f'Bad DeltaRelativeToTodayExpr: {ctx.getText()}'
+            )
+        unit = self._figure_out_date_unit(unit)
+        d = datetime_utils.n_timeunits_from_base(
+            count,
+            unit,
+            d)
+        self.context['year'] = d.year
+        self.context['month'] = d.month
+        self.context['day'] = d.day
+
+    def exitSpecialTimeExpr(
+        self, ctx: dateparse_utilsParser.SpecialTimeExprContext
+    ) -> None:
+        try:
+            txt = ctx.specialTime().getText().lower()
+        except:
+            raise ParseException(
+                f'Bad special time expression: {ctx.getText()}'
+            )
         else:
-            self.context["special"] = txt
+            if txt == 'noon' or txt == 'midday':
+                self.context['hour'] = 12
+                self.context['minute'] = 0
+                self.context['seconds'] = 0
+                self.context['micros'] = 0
+            elif txt == 'midnight':
+                self.context['hour'] = 0
+                self.context['minute'] = 0
+                self.context['seconds'] = 0
+                self.context['micros'] = 0
+            else:
+                raise ParseException(f'Bad special time expression: {txt}')
+
+        try:
+            tz = ctx.tzExpr().getText()
+            self.context['tz'] = self._parse_tz(tz)
+        except:
+            pass
+
+    def exitTwelveHourTimeExpr(
+        self, ctx: dateparse_utilsParser.TwelveHourTimeExprContext
+    ) -> None:
+        try:
+            hour = ctx.hour().getText()
+            while not hour[-1].isdigit():
+                hour = hour[:-1]
+            hour = self._get_int(hour)
+        except:
+            raise ParseException(f'Bad hour: {ctx.hour().getText()}')
+        if hour <= 0 or hour > 12:
+            raise ParseException(f'Bad hour (out of range): {hour}')
+
+        try:
+            minute = self._get_int(ctx.minute().getText())
+        except:
+            minute = 0
+        if minute < 0 or minute > 59:
+            raise ParseException(f'Bad minute (out of range): {minute}')
+        self.context['minute'] = minute
+
+        try:
+            seconds = self._get_int(ctx.second().getText())
+        except:
+            seconds = 0
+        if seconds < 0 or seconds > 59:
+            raise ParseException(f'Bad second (out of range): {seconds}')
+        self.context['seconds'] = seconds
+
+        try:
+            micros = self._get_int(ctx.micros().getText())
+        except:
+            micros = 0
+        if micros < 0 or micros > 1000000:
+            raise ParseException(f'Bad micros (out of range): {micros}')
+        self.context['micros'] = micros
+
+        try:
+            ampm = ctx.ampm().getText()
+        except:
+            raise ParseException(f'Bad ampm: {ctx.ampm().getText()}')
+        if hour == 12:
+            hour = 0
+        if ampm[0] == 'p':
+            hour += 12
+        self.context['hour'] = hour
+
+        try:
+            tz = ctx.tzExpr().getText()
+            self.context['tz'] = self._parse_tz(tz)
+        except:
+            pass
+
+    def exitTwentyFourHourTimeExpr(
+        self, ctx: dateparse_utilsParser.TwentyFourHourTimeExprContext
+    ) -> None:
+        try:
+            hour = ctx.hour().getText()
+            while not hour[-1].isdigit():
+                hour = hour[:-1]
+            hour = self._get_int(hour)
+        except:
+            raise ParseException(f'Bad hour: {ctx.hour().getText()}')
+        if hour < 0 or hour > 23:
+            raise ParseException(f'Bad hour (out of range): {hour}')
+        self.context['hour'] = hour
+
+        try:
+            minute = self._get_int(ctx.minute().getText())
+        except:
+            minute = 0
+        if minute < 0 or minute > 59:
+            raise ParseException(f'Bad minute (out of range): {ctx.getText()}')
+        self.context['minute'] = minute
+
+        try:
+            seconds = self._get_int(ctx.second().getText())
+        except:
+            seconds = 0
+        if seconds < 0 or seconds > 59:
+            raise ParseException(f'Bad second (out of range): {ctx.getText()}')
+        self.context['seconds'] = seconds
+
+        try:
+            micros = self._get_int(ctx.micros().getText())
+        except:
+            micros = 0
+        if micros < 0 or micros >= 1000000:
+            raise ParseException(f'Bad micros (out of range): {ctx.getText()}')
+        self.context['micros'] = micros
+
+        try:
+            tz = ctx.tzExpr().getText()
+            self.context['tz'] = self._parse_tz(tz)
+        except:
+            pass
 
 
 def main() -> None:
     parser = DateParser()
     for line in sys.stdin:
         line = line.strip()
-        line = line.lower()
         line = re.sub(r"#.*$", "", line)
         if re.match(r"^ *$", line) is not None:
             continue
-        print(parser.parse_date_string(line))
+        try:
+            dt = parser.parse(line)
+        except Exception as e:
+            print("Unrecognized.")
+        else:
+            print(dt.strftime('%A %Y/%m/%d %H:%M:%S.%f %Z(%z)'))
     sys.exit(0)
 
 
 if __name__ == "__main__":
+    main = bootstrap.initialize(main)
     main()
index d70bf4a79008effe5e2a3aec8684d3352be6c78d..0b94283b01df595300ecc448f108303c2dba2b2e 100644 (file)
@@ -3,10 +3,12 @@
 """Utilities related to dates and times and datetimes."""
 
 import datetime
+import enum
 import logging
 import re
-from typing import NewType
+from typing import NewType, Tuple
 
+import holidays  # type: ignore
 import pytz
 
 import constants
@@ -14,29 +16,173 @@ import constants
 logger = logging.getLogger(__name__)
 
 
-def now_pst() -> datetime.datetime:
-    return datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
+def replace_timezone(dt: datetime.datetime,
+                     tz: datetime.tzinfo) -> datetime.datetime:
+    return dt.replace(tzinfo=None).astimezone(tz=tz)
 
 
 def now() -> datetime.datetime:
     return datetime.datetime.now()
 
 
-def datetime_to_string(
-    dt: datetime.datetime,
-    *,
-    date_time_separator=" ",
-    include_timezone=True,
-    include_dayname=False,
-    include_seconds=True,
-    include_fractional=False,
-    twelve_hour=True,
+def now_pst() -> datetime.datetime:
+    return replace_timezone(now(), pytz.timezone("US/Pacific"))
+
+
+def date_to_datetime(date: datetime.date) -> datetime.datetime:
+    return datetime.datetime(
+        date.year,
+        date.month,
+        date.day,
+        0, 0, 0, 0
+    )
+
+
+def date_and_time_to_datetime(date: datetime.date,
+                              time: datetime.time) -> datetime.datetime:
+    return datetime.datetime(
+        date.year,
+        date.month,
+        date.day,
+        time.hour,
+        time.minute,
+        time.second,
+        time.millisecond
+    )
+
+
+def datetime_to_date(date: datetime.datetime) -> datetime.date:
+    return datetime.date(
+        date.year,
+        date.month,
+        date.day
+    )
+
+
+# An enum to represent units with which we can compute deltas.
+class TimeUnit(enum.Enum):
+    MONDAYS = 0
+    TUESDAYS = 1
+    WEDNESDAYS = 2
+    THURSDAYS = 3
+    FRIDAYS = 4
+    SATURDAYS = 5
+    SUNDAYS = 6
+    SECONDS = 7
+    MINUTES = 8
+    HOURS = 9
+    DAYS = 10
+    WORKDAYS = 11
+    WEEKS = 12
+    MONTHS = 13
+    YEARS = 14
+
+
+def n_timeunits_from_base(
+    count: int,
+    unit: TimeUnit,
+    base: datetime.datetime
+) -> datetime.datetime:
+    if count == 0:
+        return base
+
+    # N days from base
+    if unit == TimeUnit.DAYS:
+        timedelta = datetime.timedelta(days=count)
+        return base + timedelta
+
+    # N workdays from base
+    elif unit == TimeUnit.WORKDAYS:
+        if count < 0:
+            count = abs(count)
+            timedelta = datetime.timedelta(days=-1)
+        else:
+            timedelta = datetime.timedelta(days=1)
+        skips = holidays.US(years=base.year).keys()
+        while count > 0:
+            old_year = base.year
+            base += timedelta
+            if base.year != old_year:
+                skips = holidays.US(years=base.year).keys()
+            if (
+                    base.weekday() < 5 and
+                    datetime.date(base.year,
+                                  base.month,
+                                  base.day) not in skips
+            ):
+                count -= 1
+        return base
+
+    # N weeks from base
+    elif unit == TimeUnit.WEEKS:
+        timedelta = datetime.timedelta(weeks=count)
+        base = base + timedelta
+        return base
+
+    # N months from base
+    elif unit == TimeUnit.MONTHS:
+        month_term = count % 12
+        year_term = count // 12
+        new_month = base.month + month_term
+        if new_month > 12:
+            new_month %= 12
+            year_term += 1
+        new_year = base.year + year_term
+        return datetime.datetime(
+            new_year,
+            new_month,
+            base.day,
+            base.hour,
+            base.minute,
+            base.second,
+            base.microsecond,
+        )
+
+    # N years from base
+    elif unit == TimeUnit.YEARS:
+        new_year = base.year + count
+        return datetime.datetime(
+            new_year,
+            base.month,
+            base.day,
+            base.hour,
+            base.minute,
+            base.second,
+            base.microsecond,
+        )
+
+    # N weekdays from base (e.g. 4 wednesdays from today)
+    direction = 1 if count > 0 else -1
+    count = abs(count)
+    timedelta = datetime.timedelta(days=direction)
+    start = base
+    while True:
+        dow = base.weekday()
+        if dow == unit and start != base:
+            count -= 1
+            if count == 0:
+                return base
+        base = base + timedelta
+
+
+def get_format_string(
+        *,
+        date_time_separator=" ",
+        include_timezone=True,
+        include_dayname=False,
+        use_month_abbrevs=False,
+        include_seconds=True,
+        include_fractional=False,
+        twelve_hour=True,
 ) -> str:
-    """A nice way to convert a datetime into a string."""
     fstring = ""
     if include_dayname:
         fstring += "%a/"
-    fstring = f"%Y/%b/%d{date_time_separator}"
+
+    if use_month_abbrevs:
+        fstring = f"%Y/%b/%d{date_time_separator}"
+    else:
+        fstring = f"%Y/%m/%d{date_time_separator}"
     if twelve_hour:
         fstring += "%I:%M"
         if include_seconds:
@@ -50,9 +196,58 @@ def datetime_to_string(
         fstring += ".%f"
     if include_timezone:
         fstring += "%z"
+    return fstring
+
+
+def datetime_to_string(
+    dt: datetime.datetime,
+    *,
+    date_time_separator=" ",
+    include_timezone=True,
+    include_dayname=False,
+    use_month_abbrevs=False,
+    include_seconds=True,
+    include_fractional=False,
+    twelve_hour=True,
+) -> str:
+    """A nice way to convert a datetime into a string."""
+    fstring = get_format_string(
+        date_time_separator=date_time_separator,
+        include_timezone=include_timezone,
+        include_dayname=include_dayname,
+        include_seconds=include_seconds,
+        include_fractional=include_fractional,
+        twelve_hour=twelve_hour)
     return dt.strftime(fstring).strip()
 
 
+def string_to_datetime(
+        txt: str,
+        *,
+        date_time_separator=" ",
+        include_timezone=True,
+        include_dayname=False,
+        use_month_abbrevs=False,
+        include_seconds=True,
+        include_fractional=False,
+        twelve_hour=True,
+) -> Tuple[datetime.datetime, str]:
+    """A nice way to convert a string into a datetime.  Also consider
+    dateparse.dateparse_utils for a full parser.
+    """
+    fstring = get_format_string(
+        date_time_separator=date_time_separator,
+        include_timezone=include_timezone,
+        include_dayname=include_dayname,
+        include_seconds=include_seconds,
+        include_fractional=include_fractional,
+        twelve_hour=twelve_hour)
+    return (
+        datetime.datetime.strptime(txt, fstring),
+        fstring
+    )
+
+
 def timestamp() -> str:
     """Return a timestamp for now in Pacific timezone."""
     ts = datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
@@ -104,7 +299,9 @@ def datetime_to_minute_number(dt: datetime.datetime) -> MinuteOfDay:
 
 
 def minute_number_to_time_string(minute_num: MinuteOfDay) -> str:
-    """Convert minute number from start of day into hour:minute am/pm string."""
+    """Convert minute number from start of day into hour:minute am/pm
+    string.
+    """
     hour = minute_num // 60
     minute = minute_num % 60
     ampm = "a"
@@ -171,5 +368,6 @@ def describe_duration_briefly(age: int) -> str:
         descr = f"{int(days[0])}d "
     if hours[0] > 0:
         descr = descr + f"{int(hours[0])}h "
-    descr = descr + f"{int(minutes[0])}m"
-    return descr
+    if minutes[0] > 0 or len(descr) == 0:
+        descr = descr + f"{int(minutes[0])}m"
+    return descr.strip()
index 03e7c880433fad5d359a2bb3acc29a4266204e65..c07023b1205950cad1d89edbbe13ad457c45f678 100644 (file)
@@ -5,6 +5,7 @@
 import datetime
 import enum
 import functools
+import inspect
 import logging
 import math
 import multiprocessing
@@ -20,6 +21,7 @@ import warnings
 import exceptions
 import thread_utils
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -517,3 +519,19 @@ def call_with_sample_rate(sample_rate: float) -> Callable:
                 )
         return _call_with_sample_rate
     return decorator
+
+
+def decorate_matching_methods_with(decorator, acl=None):
+    """Apply decorator to all methods in a class whose names begin with
+    prefix.  If prefix is None (default), decorate all methods in the
+    class.
+    """
+    def decorate_the_class(cls):
+        for name, m in inspect.getmembers(cls, inspect.isfunction):
+            if acl is None:
+                setattr(cls, name, decorator(m))
+            else:
+                if acl(name):
+                    setattr(cls, name, decorator(m))
+        return cls
+    return decorate_the_class
index 29a5cd0c5b10fb84a61d078698109b8ebd31ffb6..2362749b12c91b7f2d2b6519d3c9db1506a12fe8 100644 (file)
@@ -1,7 +1,9 @@
 #!/usr/bin/env python3
 
 from itertools import islice
-from typing import Any, Callable, Dict, Iterator
+from typing import Any, Callable, Dict, Iterator, Tuple
+
+import list_utils
 
 
 def init_or_inc(
@@ -24,11 +26,39 @@ def shard(d: Dict[Any, Any], size: int) -> Iterator[Dict[Any, Any]]:
         yield {key: value for (key, value) in islice(items, x, x + size)}
 
 
-def item_with_max_value(d: Dict[Any, Any]) -> Any:
+def coalesce_by_creating_list(key, v1, v2):
+    return list_utils.flatten([v1, v2])
+
+
+def coalesce_by_creating_set(key, v1, v2):
+    return set(coalesce_by_creating_list(key, v1, v2))
+
+
+def raise_on_duplicated_keys(key, v1, v2):
+    raise Exception(f'Key {key} is duplicated in more than one input dict.')
+
+
+def coalesce(
+        inputs: Iterator[Dict[Any, Any]],
+        *,
+        aggregation_function: Callable[[Any, Any, Any], Any] = coalesce_by_creating_list
+) -> Dict[Any, Any]:
+    out = {}
+    for d in inputs:
+        for key in d:
+            if key in out:
+                value = aggregation_function(d[key], out[key])
+            else:
+                value = d[key]
+            out[key] = value
+    return out
+
+
+def item_with_max_value(d: Dict[Any, Any]) -> Tuple[Any, Any]:
     return max(d.items(), key=lambda _: _[1])
 
 
-def item_with_min_value(d: Dict[Any, Any]) -> Any:
+def item_with_min_value(d: Dict[Any, Any]) -> Tuple[Any, Any]:
     return min(d.items(), key=lambda _: _[1])
 
 
@@ -54,19 +84,3 @@ def max_key(d: Dict[Any, Any]) -> Any:
 
 def min_key(d: Dict[Any, Any]) -> Any:
     return min(d.keys())
-
-
-def merge(a: Dict[Any, Any], b: Dict[Any, Any], path=None) -> Dict[Any, Any]:
-    if path is None:
-        path = []
-    for key in b:
-        if key in a:
-            if isinstance(a[key], dict) and isinstance(b[key], dict):
-                merge(a[key], b[key], path + [str(key)])
-            elif a[key] == b[key]:
-                pass
-            else:
-                raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
-        else:
-            a[key] = b[key]
-    return a
index b9c0748391f733e0719a744d9b30280c34bb30ee..2b2f0252ce326fd34800080735f5324f91bfc315 100644 (file)
@@ -52,12 +52,17 @@ parser.add_argument(
     action=argparse_utils.ActionNoYes,
     help='Should we schedule duplicative backup work if a remote bundle is slow',
 )
+parser.add_argument(
+    '--executors_max_bundle_failures',
+    type=int,
+    default=3,
+    metavar='#FAILURES',
+    help='Maximum number of failures before giving up on a bundle',
+)
 
-rsync = 'rsync -q --no-motd -W --ignore-existing --timeout=60 --size-only -z'
-ssh = 'ssh -oForwardX11=no'
-
-
-hist = histogram.SimpleHistogram(
+RSYNC = 'rsync -q --no-motd -W --ignore-existing --timeout=60 --size-only -z'
+SSH = 'ssh -oForwardX11=no'
+HIST = histogram.SimpleHistogram(
     histogram.SimpleHistogram.n_evenly_spaced_buckets(
         int(0), int(500), 25
     )
@@ -71,7 +76,7 @@ def run_local_bundle(fun, *args, **kwargs):
     end = time.time()
     duration = end - start
     logger.debug(f"{fun.__name__} finished; used {duration:.1f}s")
-    hist.add_item(duration)
+    HIST.add_item(duration)
     return result
 
 
@@ -144,7 +149,7 @@ class ThreadExecutor(BaseExecutor):
     def shutdown(self,
                  wait = True) -> None:
         logger.debug("Shutting down threadpool executor.")
-        print(hist)
+        print(HIST)
         self._thread_pool_executor.shutdown(wait)
 
 
@@ -177,7 +182,7 @@ class ProcessExecutor(BaseExecutor):
 
     def shutdown(self, wait=True) -> None:
         logger.debug('Shutting down processpool executor')
-        print(hist)
+        print(HIST)
         self._process_executor.shutdown(wait)
 
 
@@ -214,6 +219,7 @@ class BundleDetails:
     is_cancelled: threading.Event
     was_cancelled: bool
     backup_bundles: Optional[List[BundleDetails]]
+    failure_count: int
 
 
 class RemoteExecutorStatus:
@@ -559,10 +565,12 @@ class RemoteExecutor(BaseExecutor):
         return False
 
     def launch(self, bundle: BundleDetails) -> Any:
-        # Find a worker for bundle or block until one is available.
+        """Find a worker for bundle or block until one is available."""
         uuid = bundle.uuid
         hostname = bundle.hostname
         avoid_machine = None
+
+        # Try not to schedule a backup on the same host as the original.
         if bundle.src_bundle is not None:
             avoid_machine = bundle.src_bundle.machine
         worker = None
@@ -574,29 +582,33 @@ class RemoteExecutor(BaseExecutor):
         self.status.record_acquire_worker(worker, uuid)
         logger.debug(f'Running bundle {uuid} on {worker}...')
 
-        # Before we do work, make sure it's still viable.
+        # Before we do any work, make sure the bundle is still viable.
         if self.check_if_cancelled(bundle):
-            return self.post_launch_work(bundle)
+            try:
+                return self.post_launch_work(bundle)
+            except Exception as e:
+                logger.exception(e)
+                logger.info(f"Bundle {uuid} seems to have failed?!")
+                if bundle.failure_count < config.config['executors_max_bundle_failures']:
+                    return self.launch(bundle)
+                else:
+                    logger.info(f"Bundle {uuid} is poison, giving up on it.")
+                    return None
 
         # Send input to machine if it's not local.
         if hostname not in machine:
-            cmd = f'{rsync} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
-            logger.debug(f"Copying work to {worker} via {cmd}")
+            cmd = f'{RSYNC} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+            logger.info(f"Copying work to {worker} via {cmd}")
             exec_utils.run_silently(cmd)
 
-        # Before we do more work, make sure it's still viable.
-        if self.check_if_cancelled(bundle):
-            return self.post_launch_work(bundle)
-
-        # Fucking Apple has a python3 binary in /usr/sbin that is not
-        # the one we want and is protected by the OS so make sure that
-        # /usr/local/bin is early in the path.
-        cmd = (f'{ssh} {bundle.username}@{bundle.machine} '
-               f'"export PATH=/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/sbin:/home/scott/bin:/home/scott/.local/bin; /home/scott/lib/python_modules/remote_worker.py'
+        # Do it.
+        cmd = (f'{SSH} {bundle.username}@{bundle.machine} '
+               f'"source remote-execution/bin/activate &&'
+               f' /home/scott/lib/python_modules/remote_worker.py'
                f' --code_file {bundle.code_file} --result_file {bundle.result_file}"')
         p = exec_utils.cmd_in_background(cmd, silent=True)
         bundle.pid = pid = p.pid
-        logger.debug(f"Running {cmd} in the background as process {pid}")
+        logger.info(f"Running {cmd} in the background as process {pid}")
 
         while True:
             try:
@@ -614,7 +626,16 @@ class RemoteExecutor(BaseExecutor):
                     f"{pid}/{bundle.uuid} has finished its work normally."
                 )
                 break
-        return self.post_launch_work(bundle)
+
+        try:
+            return self.post_launch_work(bundle)
+        except Exception as e:
+            logger.exception(e)
+            logger.info(f"Bundle {uuid} seems to have failed?!")
+            if bundle.failure_count < config.config['executors_max_bundle_failures']:
+                return self.launch(bundle)
+            logger.info(f"Bundle {uuid} is poison, giving up on it.")
+            return None
 
     def post_launch_work(self, bundle: BundleDetails) -> Any:
         with self.status.lock:
@@ -631,15 +652,15 @@ class RemoteExecutor(BaseExecutor):
             if not was_cancelled:
                 assert bundle.machine is not None
                 if bundle.hostname not in bundle.machine:
-                    cmd = f'{rsync} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
-                    logger.debug(
+                    cmd = f'{RSYNC} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
+                    logger.info(
                         f"Fetching results from {username}@{machine} via {cmd}"
                     )
                     try:
                         exec_utils.run_silently(cmd)
                     except subprocess.CalledProcessError:
                         pass
-                    exec_utils.run_silently(f'{ssh} {username}@{machine}'
+                    exec_utils.run_silently(f'{SSH} {username}@{machine}'
                                             f' "/bin/rm -f {code_file} {result_file}"')
             bundle.end_ts = time.time()
             assert bundle.worker is not None
@@ -650,15 +671,31 @@ class RemoteExecutor(BaseExecutor):
             )
             if not was_cancelled:
                 dur = bundle.end_ts - bundle.start_ts
-                hist.add_item(dur)
+                HIST.add_item(dur)
+
+        # Original or not, the results should be back on the local
+        # machine.  Are they?
+        if not os.path.exists(result_file):
+            msg = f'{result_file} unexpectedly missing, wtf?!'
+            logger.critical(msg)
+            bundle.failure_count += 1
+            self.release_worker(bundle.worker)
+            raise Exception(msg)
 
         # Only the original worker should unpickle the file contents
         # though since it's the only one whose result matters.
         if is_original:
             logger.debug(f"Unpickling {result_file}.")
-            with open(f'{result_file}', 'rb') as rb:
-                serialized = rb.read()
-            result = cloudpickle.loads(serialized)
+            try:
+                with open(f'{result_file}', 'rb') as rb:
+                    serialized = rb.read()
+                result = cloudpickle.loads(serialized)
+            except Exception as e:
+                msg = f'Failed to load {result_file}'
+                logger.critical(msg)
+                bundle.failure_count += 1
+                self.release_worker(bundle.worker)
+                raise Exception(e)
             os.remove(f'{result_file}')
             os.remove(f'{code_file}')
 
@@ -718,6 +755,7 @@ class RemoteExecutor(BaseExecutor):
             is_cancelled = threading.Event(),
             was_cancelled = False,
             backup_bundles = [],
+            failure_count = 0,
         )
         self.status.record_bundle_details(bundle)
         logger.debug(f'Created original bundle {uuid}')
@@ -746,6 +784,7 @@ class RemoteExecutor(BaseExecutor):
             is_cancelled = threading.Event(),
             was_cancelled = False,
             backup_bundles = None,    # backup backups not allowed
+            failure_count = 0,
         )
         src_bundle.backup_bundles.append(backup_bundle)
         self.status.record_bundle_details_already_locked(backup_bundle)
@@ -779,7 +818,7 @@ class RemoteExecutor(BaseExecutor):
 
     def shutdown(self, wait=True) -> None:
         self._helper_executor.shutdown(wait)
-        print(hist)
+        print(HIST)
 
 
 @singleton
@@ -815,8 +854,8 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username = 'scott',
                         machine = 'cheetah.house',
-                        weight = 10,
-                        count = 6,
+                        weight = 12,
+                        count = 4,
                     ),
                 )
             if self.ping('video.house'):
@@ -824,7 +863,7 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username = 'scott',
                         machine = 'video.house',
-                        weight = 2,
+                        weight = 1,
                         count = 4,
                     ),
                 )
@@ -851,8 +890,8 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username = 'scott',
                         machine = 'backup.house',
-                        weight = 3,
-                        count = 2,
+                        weight = 1,
+                        count = 4,
                     ),
                 )
             if self.ping('puma.cabin'):
@@ -860,8 +899,8 @@ class DefaultExecutors(object):
                     RemoteWorkerRecord(
                         username = 'scott',
                         machine = 'puma.cabin',
-                        weight = 10,
-                        count = 6,
+                        weight = 12,
+                        count = 4,
                     ),
                 )
             policy = WeightedRandomRemoteWorkerSelectionPolicy()
index 7cc8b632ac692d47a6f272403100c6806dc19136..7108d6a0db74842a74587e8f9c53a1973fb0bde8 100644 (file)
@@ -244,5 +244,5 @@ def get_files_recursive(directory: str):
     for filename in get_files(directory):
         yield filename
     for subdir in get_directories(directory):
-        for filename in get_files_recursive(subdir):
-            yield filename
+        for file_or_directory in get_files_recursive(subdir):
+            yield file_or_directory
index 8c21b86312dcd25c5771fcd2be44c8ec0ca81dc4..f63ba0b9b75193d2a076ed5fcfa9c015b65dc621 100644 (file)
@@ -79,6 +79,18 @@ class Light(ABC):
     def turn_off(self) -> bool:
         pass
 
+    @abstractmethod
+    def is_on(self) -> bool:
+        pass
+
+    @abstractmethod
+    def is_off(self) -> bool:
+        pass
+
+    @abstractmethod
+    def get_dimmer_level(self) -> Optional[int]:
+        pass
+
     @abstractmethod
     def set_dimmer_level(self, level: int) -> bool:
         pass
@@ -119,11 +131,42 @@ class GoogleLight(Light):
             goog.ask_google(f"turn {self.goog_name()} off")
         )
 
+    def is_on(self) -> bool:
+        r = goog.ask_google(f"is {self.goog_name()} on?")
+        if not r.success:
+            return False
+        return 'is on' in r.audio_transcription
+
+    def is_off(self) -> bool:
+        return not self.is_on()
+
+    def get_dimmer_level(self) -> Optional[int]:
+        if not self.has_keyword("dimmer"):
+            return False
+        r = goog.ask_google(f'how bright is {self.goog_name()}?')
+        if not r.success:
+            return None
+
+        # the bookcase one is set to 40% bright
+        txt = r.audio_transcription
+        m = re.search(r"(\d+)% bright", txt)
+        if m is not None:
+            return int(m.group(1))
+        if "is off" in txt:
+            return 0
+        return None
+
     def set_dimmer_level(self, level: int) -> bool:
+        if not self.has_keyword("dimmer"):
+            return False
         if 0 <= level <= 100:
-            return GoogleLight.parse_google_response(
-                goog.ask_google(f"set {self.goog_name()} to {level} percent")
-            )
+            was_on = self.is_on()
+            r = goog.ask_google(f"set {self.goog_name()} to {level} percent")
+            if not r.success:
+                return False
+            if not was_on:
+                self.turn_off()
+            return True
         return False
 
     def make_color(self, color: str) -> bool:
@@ -177,6 +220,12 @@ class TPLinkLight(Light):
     def turn_off(self, child: str = None) -> bool:
         return self.command("off", child)
 
+    def is_on(self) -> bool:
+        return self.get_on_duration_seconds() > 0
+
+    def is_off(self) -> bool:
+        return not self.is_on()
+
     def make_color(self, color: str) -> bool:
         raise NotImplementedError
 
@@ -220,6 +269,14 @@ class TPLinkLight(Light):
                 return int(m.group(1)) * 60
         return None
 
+    def get_dimmer_level(self) -> Optional[int]:
+        if not self.has_keyword("dimmer"):
+            return False
+        self.info = self.get_info()
+        if self.info is None:
+            return None
+        return int(self.info.get("brightness", "0"))
+
     def set_dimmer_level(self, level: int) -> bool:
         if not self.has_keyword("dimmer"):
             return False
index 9a5d4fde0dcad7936a16ffc6f53a7a50ba0b67fa..74f1cf3078457d371194deb33ddf5ad6410ed599 100644 (file)
@@ -15,7 +15,6 @@ def flatten(lst: List[Any]) -> List[Any]:
 
         >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
         [1, 2, 3, 4, 5, 6, 7, 8, 9]
-
     """
     if len(lst) == 0:
         return lst
index 03a23d9ad063a352d4f1728e585fc6a83421958b..a24f1c9fb520f37ad7c13ec423e65c08b86074ae 100644 (file)
@@ -15,17 +15,17 @@ import config
 import string_utils as su
 import thread_utils as tu
 
-parser = config.add_commandline_args(
+cfg = config.add_commandline_args(
     f'Logging ({__file__})',
     'Args related to logging')
-parser.add_argument(
+cfg.add_argument(
     '--logging_config_file',
     type=argparse_utils.valid_filename,
     default=None,
     metavar='FILENAME',
     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_level',
     type=str,
     default='INFO',
@@ -33,59 +33,59 @@ parser.add_argument(
     metavar='LEVEL',
     help='The level below which to squelch log messages.',
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_format',
     type=str,
     default='%(levelname)s:%(asctime)s: %(message)s',
     help='The format for lines logged via the logger module.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_date_format',
     type=str,
     default='%Y/%m/%dT%H:%M:%S.%f%z',
     metavar='DATEFMT',
     help='The format of any dates in --logging_format.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_console',
     action=argparse_utils.ActionNoYes,
     default=True,
     help='Should we log to the console (stderr)',
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_filename',
     type=str,
     default=None,
     metavar='FILENAME',
     help='The filename of the logfile to write.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_filename_maxsize',
     type=int,
     default=(1024*1024),
     metavar='#BYTES',
     help='The maximum size (in bytes) to write to the logging_filename.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_filename_count',
     type=int,
     default=2,
     metavar='COUNT',
     help='The number of logging_filename copies to keep before deleting.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_syslog',
     action=argparse_utils.ActionNoYes,
     default=False,
     help='Should we log to localhost\'s syslog.'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_debug_threads',
     action=argparse_utils.ActionNoYes,
     default=False,
     help='Should we prepend pid/tid data to all log messages?'
 )
-parser.add_argument(
+cfg.add_argument(
     '--logging_info_is_print',
     action=argparse_utils.ActionNoYes,
     default=False,
index 62e579846b52bed06ade1a37b3e1959bc5c3fa0c..3775d3f4f1c72644e1f66f61d14c5cb6c4336b10 100644 (file)
@@ -7,29 +7,3 @@ import string_utils
 
 def is_running_as_root() -> bool:
     return os.geteuid() == 0
-
-
-def is_are(n: int) -> str:
-    if n == 1:
-        return "is"
-    return "are"
-
-
-def pluralize(n: int) -> str:
-    if n == 1:
-        return ""
-    return "s"
-
-
-def thify(n: int) -> str:
-    digit = str(n)
-    assert string_utils.is_integer_number(digit)
-    digit = digit[-1:]
-    if digit == "1":
-        return "st"
-    elif digit == "2":
-        return "nd"
-    elif digit == "3":
-        return "rd"
-    else:
-        return "th"
index edddcc0c9f794232a5d2ee6593793518abd2ef12..b0a9a1bcce5a51bd8f7d875a5f9a46fe9a7efe76 100644 (file)
@@ -142,35 +142,36 @@ class TrainingBlueprint(ABC):
         best_test_score = None
         best_training_score = None
         best_params = None
-        for model in smart_future.wait_many(models):
+        for model in smart_future.wait_any(models):
             params = modelid_to_params[model.get_id()]
             if isinstance(model, smart_future.SmartFuture):
                 model = model._resolve()
-            training_score, test_score = self.evaluate_model(
-                model,
-                self.X_train_scaled,
-                self.y_train,
-                self.X_test_scaled,
-                self.y_test,
-            )
-            score = (training_score + test_score * 20) / 21
-            if not self.spec.quiet:
-                print(
-                    f"{bold()}{params}{reset()}: "
-                    f"Training set score={training_score:.2f}%, "
-                    f"test set score={test_score:.2f}%",
-                    file=sys.stderr,
+            if model is not None:
+                training_score, test_score = self.evaluate_model(
+                    model,
+                    self.X_train_scaled,
+                    self.y_train,
+                    self.X_test_scaled,
+                    self.y_test,
                 )
-            if best_score is None or score > best_score:
-                best_score = score
-                best_test_score = test_score
-                best_training_score = training_score
-                best_model = model
-                best_params = params
+                score = (training_score + test_score * 20) / 21
                 if not self.spec.quiet:
                     print(
-                        f"New best score {best_score:.2f}% with params {params}"
+                        f"{bold()}{params}{reset()}: "
+                        f"Training set score={training_score:.2f}%, "
+                        f"test set score={test_score:.2f}%",
+                        file=sys.stderr,
                     )
+                if best_score is None or score > best_score:
+                    best_score = score
+                    best_test_score = test_score
+                    best_training_score = training_score
+                    best_model = model
+                    best_params = params
+                    if not self.spec.quiet:
+                        print(
+                            f"New best score {best_score:.2f}% with params {params}"
+                        )
 
         if not self.spec.quiet:
             msg = f"Done training; best test set score was: {best_test_score:.1f}%"
@@ -279,7 +280,7 @@ class TrainingBlueprint(ABC):
             file_list = list(files)
             results.append(self.read_files_from_list(file_list, n))
 
-        for result in smart_future.wait_many(results, callback=self.make_progress_graph):
+        for result in smart_future.wait_any(results, callback=self.make_progress_graph):
             result = result._resolve()
             for z in result[0]:
                 X.append(z)
index 02f85d6bde44013260f791491e63b6e5981a10d6..b6e9fc304be31188b7fd01bc7ac8616057a203ba 100644 (file)
@@ -5,11 +5,9 @@ from collections import defaultdict
 import enum
 import logging
 import re
-import sys
 from typing import Dict, List
 
 import argparse_utils
-import bootstrap
 import config
 import dict_utils
 import exec_utils
@@ -82,6 +80,9 @@ class PresenceDetection(object):
             Location, Dict[str, datetime.datetime]
         ] = defaultdict(dict)
         self.names_by_mac: Dict[str, str] = {}
+        self.update()
+
+    def update(self) -> None:
         persisted_macs = config.config['presence_macs_file']
         self.read_persisted_macs_file(persisted_macs, Location.HOUSE)
         raw = exec_utils.cmd(
@@ -168,20 +169,3 @@ class PresenceDetection(object):
             item = dict_utils.item_with_max_value(votes)
             return item[0]
         return Location.UNKNOWN
-
-
-def main() -> None:
-    config.parse()
-    p = PresenceDetection()
-
-    for loc in Location:
-        print(f'{loc}: {p.is_anyone_in_location_now(loc)}')
-
-    for u in Person:
-        print(f'{u}: {p.where_is_person_now(u)}')
-    sys.exit(0)
-
-
-if __name__ == '__main__':
-    main()
index ebd510040d15ac377165281a75c20c8ce63a8474..43b841589c670b758d52c777e716835b32863c51 100755 (executable)
@@ -14,6 +14,7 @@ import time
 import cloudpickle  # type: ignore
 import psutil  # type: ignore
 
+import argparse_utils
 import bootstrap
 import config
 from thread_utils import background_thread
@@ -37,6 +38,12 @@ cfg.add_argument(
     metavar='FILENAME',
     help='The location where we should write the computation results.'
 )
+cfg.add_argument(
+    '--watch_for_cancel',
+    action=argparse_utils.ActionNoYes,
+    default=False,
+    help='Should we watch for the cancellation of our parent ssh process?'
+)
 
 
 @background_thread
@@ -50,7 +57,6 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
             if 'ssh' in name or 'Ssh' in name:
                 saw_sshd = True
                 break
-
         if not saw_sshd:
             os.system('pstree')
             os.kill(os.getpid(), signal.SIGTERM)
@@ -59,33 +65,38 @@ def watch_for_cancel(terminate_event: threading.Event) -> None:
         time.sleep(1.0)
 
 
-def main() -> None:
-    hostname = platform.node()
-
-    # Windows-Linux is retarded.
-    if hostname != 'VIDEO-COMPUTER':
-        (thread, terminate_event) = watch_for_cancel()
-
-    in_file = config.config['code_file']
-    out_file = config.config['result_file']
-
-    with open(in_file, 'rb') as rb:
-        serialized = rb.read()
-
-    fun, args, kwargs = cloudpickle.loads(serialized)
-    ret = fun(*args, **kwargs)
-
-    serialized = cloudpickle.dumps(ret)
-    with open(out_file, 'wb') as wb:
-        wb.write(serialized)
-
-    # Windows-Linux is retarded.
-    if hostname != 'VIDEO-COMPUTER':
-        terminate_event.set()
-        thread.join()
-    sys.exit(0)
-
-
 if __name__ == '__main__':
+    @bootstrap.initialize
+    def main() -> None:
+        hostname = platform.node()
+
+        # Windows-Linux is retarded.
+    #    if (
+    #            hostname != 'VIDEO-COMPUTER' and
+    #            config.config['watch_for_cancel']
+    #    ):
+    #        (thread, terminate_event) = watch_for_cancel()
+
+        in_file = config.config['code_file']
+        out_file = config.config['result_file']
+
+        with open(in_file, 'rb') as rb:
+            serialized = rb.read()
+
+        fun, args, kwargs = cloudpickle.loads(serialized)
+        print(fun)
+        print(args)
+        print(kwargs)
+        print("Invoking the code...")
+        ret = fun(*args, **kwargs)
+
+        serialized = cloudpickle.dumps(ret)
+        with open(out_file, 'wb') as wb:
+            wb.write(serialized)
+
+        # Windows-Linux is retarded.
+    #    if hostname != 'VIDEO-COMPUTER':
+    #        terminate_event.set()
+    #        thread.join()
+        sys.exit(0)
     main()
diff --git a/simple_acl.py b/simple_acl.py
new file mode 100644 (file)
index 0000000..39129ce
--- /dev/null
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+
+from abc import ABC, abstractmethod
+import fnmatch
+import logging
+import re
+from typing import Any, Callable, List, Optional, Set
+
+
+logger = logging.getLogger(__name__)
+
+
+ACL_ORDER_ALLOW_DENY = 1
+ACL_ORDER_DENY_ALLOW = 2
+
+
+class SimpleACL(ABC):
+    """A simple Access Control List interface."""
+
+    def __init__(
+        self,
+        *,
+        order_to_check_allow_deny: int,
+        default_answer: bool
+    ):
+        if order_to_check_allow_deny not in (
+                ACL_ORDER_ALLOW_DENY, ACL_ORDER_DENY_ALLOW
+        ):
+            raise Exception(
+                'order_to_check_allow_deny must be ACL_ORDER_ALLOW_DENY or ' +
+                'ACL_ORDER_DENY_ALLOW')
+        self.order_to_check_allow_deny = order_to_check_allow_deny
+        self.default_answer = default_answer
+
+    def __call__(self, x: Any) -> bool:
+        """Returns True if x is allowed, False otherwise."""
+        if self.order_to_check_allow_deny == ACL_ORDER_ALLOW_DENY:
+            if self.check_allowed(x):
+                return True
+            if self.check_denied(x):
+                return False
+            return self.default_answer
+        elif self.order_to_check_allow_deny == ACL_ORDER_DENY_ALLOW:
+            if self.check_denied(x):
+                return False
+            if self.check_allowed(x):
+                return True
+            return self.default_answer
+        raise Exception('Should never get here.')
+
+    @abstractmethod
+    def check_allowed(self, x: Any) -> bool:
+        """Return True if x is allowed, False otherwise."""
+        pass
+
+    @abstractmethod
+    def check_denied(self, x: Any) -> bool:
+        """Return True if x is denied, False otherwise."""
+        pass
+
+
+class SetBasedACL(SimpleACL):
+    def __init__(self,
+                 *,
+                 allow_set: Optional[Set[Any]] = None,
+                 deny_set: Optional[Set[Any]] = None,
+                 order_to_check_allow_deny: int,
+                 default_answer: bool) -> None:
+        super().__init__(
+            order_to_check_allow_deny=order_to_check_allow_deny,
+            default_answer=default_answer
+        )
+        self.allow_set = allow_set
+        self.deny_set = deny_set
+
+    def check_allowed(self, x: Any) -> bool:
+        if self.allow_set is None:
+            return False
+        return x in self.allow_set
+
+    def check_denied(self, x: Any) -> bool:
+        if self.deny_set is None:
+            return False
+        return x in self.deny_set
+
+
+class PredicateListBasedACL(SimpleACL):
+    def __init__(self,
+                 *,
+                 allow_predicate_list: List[Callable[[Any], bool]] = None,
+                 deny_predicate_list: List[Callable[[Any], bool]] = None,
+                 order_to_check_allow_deny: int,
+                 default_answer: bool) -> None:
+        super().__init__(
+            order_to_check_allow_deny=order_to_check_allow_deny,
+            default_answer=default_answer
+        )
+        self.allow_predicate_list = allow_predicate_list
+        self.deny_predicate_list = deny_predicate_list
+
+    def check_allowed(self, x: Any) -> bool:
+        if self.allow_predicate_list is None:
+            return False
+        return any(predicate(x) for predicate in self.allow_predicate_list)
+
+    def check_denied(self, x: Any) -> bool:
+        if self.deny_predicate_list is None:
+            return False
+        return any(predicate(x) for predicate in self.deny_predicate_list)
+
+
+class StringWildcardBasedACL(PredicateListBasedACL):
+    def __init__(self,
+                 *,
+                 allowed_patterns: Optional[List[str]] = None,
+                 denied_patterns: Optional[List[str]] = None,
+                 order_to_check_allow_deny: int,
+                 default_answer: bool) -> None:
+        allow_predicates = []
+        if allowed_patterns is not None:
+            for pattern in allowed_patterns:
+                allow_predicates.append(
+                    lambda x, pattern=pattern: fnmatch.fnmatch(x, pattern)
+                )
+        deny_predicates = None
+        if denied_patterns is not None:
+            deny_predicates = []
+            for pattern in denied_patterns:
+                deny_predicates.append(
+                    lambda x, pattern=pattern: fnmatch.fnmatch(x, pattern)
+                )
+
+        super().__init__(
+            allow_predicate_list=allow_predicates,
+            deny_predicate_list=deny_predicates,
+            order_to_check_allow_deny=order_to_check_allow_deny,
+            default_answer=default_answer,
+        )
+
+
+class StringREBasedACL(PredicateListBasedACL):
+    def __init__(self,
+                 *,
+                 allowed_regexs: Optional[List[re.Pattern]] = None,
+                 denied_regexs: Optional[List[re.Pattern]] = None,
+                 order_to_check_allow_deny: int,
+                 default_answer: bool) -> None:
+        allow_predicates = None
+        if allowed_regexs is not None:
+            allow_predicates = []
+            for pattern in allowed_regexs:
+                allow_predicates.append(
+                    lambda x, pattern=pattern: pattern.match(x) is not None
+                )
+        deny_predicates = None
+        if denied_regexs is not None:
+            deny_predicates = []
+            for pattern in denied_regexs:
+                deny_predicates.append(
+                    lambda x, pattern=pattern: pattern.match(x) is not None
+                )
+        super().__init__(
+            allow_predicate_list=allow_predicates,
+            deny_predicate_list=deny_predicates,
+            order_to_check_allow_deny=order_to_check_allow_deny,
+            default_answer=default_answer,
+        )
+
+
+class CompoundACL(object):
+    ANY = 1
+    ALL = 2
+
+    def __init__(
+            self,
+            *,
+            subacls: Optional[List[SimpleACL]],
+            match_requirement: int = ALL
+    ) -> None:
+        self.subacls = subacls
+        if match_requirement not in (CompoundACL.ANY, CompoundACL.ALL):
+            raise Exception(
+                'match_requirement must be CompoundACL.ANY or CompoundACL.ALL'
+            )
+        self.match_requirement = match_requirement
+
+    def __call__(self, x: Any) -> bool:
+        if self.match_requirement == CompoundACL.ANY:
+            return any(acl(x) for acl in self.subacls)
+        elif self.match_requirement == CompoundACL.ALL:
+            return all(acl(x) for acl in self.subacls)
+        raise Exception('Should never get here.')
index f1ffee1c63250b4fb0d0be319a8090ce406f5fc0..e4832d43d5b1674988628e5dae43a67cf8ed0565 100644 (file)
@@ -12,7 +12,7 @@ import id_generator
 T = TypeVar('T')
 
 
-def wait_many(futures: List[SmartFuture], *, callback: Callable = None):
+def wait_any(futures: List[SmartFuture], *, callback: Callable = None):
     finished: Mapping[int, bool] = {}
     x = 0
     while True:
index 225584bb1a907314de374c9da159d7afb5c96cb1..16d2f595cf12ee34e3c28cdb9f1c522f8ac0b978 100644 (file)
@@ -8,8 +8,8 @@ from typing import Dict, Optional
 
 import pytz
 
-from thread_utils import background_thread
 import math_utils
+from thread_utils import background_thread
 
 logger = logging.getLogger(__name__)
 
@@ -67,33 +67,34 @@ class StateTracker(ABC):
         """
         self.now = datetime.datetime.now(tz=pytz.timezone("US/Pacific"))
         for update_id in sorted(self.last_reminder_ts.keys()):
-            refresh_secs = self.update_ids_to_update_secs[update_id]
             if force_all_updates_to_run:
                 logger.debug('Forcing all updates to run')
                 self.update(
                     update_id, self.now, self.last_reminder_ts[update_id]
                 )
                 self.last_reminder_ts[update_id] = self.now
+                return
+
+            refresh_secs = self.update_ids_to_update_secs[update_id]
+            last_run = self.last_reminder_ts[update_id]
+            if last_run is None:  # Never run before
+                logger.debug(
+                    f'id {update_id} has never been run; running it now'
+                )
+                self.update(
+                    update_id, self.now, self.last_reminder_ts[update_id]
+                )
+                self.last_reminder_ts[update_id] = self.now
             else:
-                last_run = self.last_reminder_ts[update_id]
-                if last_run is None:  # Never run before
-                    logger.debug(
-                        f'id {update_id} has never been run; running it now'
-                    )
+                delta = self.now - last_run
+                if delta.total_seconds() >= refresh_secs:  # Is overdue?
+                    logger.debug(f'id {update_id} is overdue; running it now')
                     self.update(
-                        update_id, self.now, self.last_reminder_ts[update_id]
+                        update_id,
+                        self.now,
+                        self.last_reminder_ts[update_id],
                     )
                     self.last_reminder_ts[update_id] = self.now
-                else:
-                    delta = self.now - last_run
-                    if delta.total_seconds() >= refresh_secs:  # Is overdue
-                        logger.debug('id {update_id} is overdue; running it now')
-                        self.update(
-                            update_id,
-                            self.now,
-                            self.last_reminder_ts[update_id],
-                        )
-                        self.last_reminder_ts[update_id] = self.now
 
 
 class AutomaticStateTracker(StateTracker):
index 83575ff47ce878a93f5237565e066abac57a0b1a..7ad9c42a1e2af3304e18ba6beba021c35acbb086 100644 (file)
@@ -1,7 +1,9 @@
 #!/usr/bin/env python3
 
+import datetime
 from itertools import zip_longest
 import json
+import logging
 import random
 import re
 import string
@@ -9,6 +11,11 @@ from typing import Any, List, Optional
 import unicodedata
 from uuid import uuid4
 
+import dateparse.dateparse_utils as dp
+
+
+logger = logging.getLogger(__name__)
+
 NUMBER_RE = re.compile(r"^([+\-]?)((\d+)(\.\d+)?([e|E]\d+)?|\.\d+)$")
 
 HEX_NUMBER_RE = re.compile(r"^([+|-]?)0[x|X]([0-9A-Fa-f]+)$")
@@ -247,7 +254,6 @@ def _add_thousands_separator(in_str: str, *, separator_char = ',', places = 3) -
     return ret
 
 
-
 # Full url example:
 # scheme://username:[email protected]:8042/folder/subfolder/file.extension?param=value&param2=value2#hash
 def is_url(in_str: Any, allowed_schemes: Optional[List[str]] = None) -> bool:
@@ -354,13 +360,14 @@ def number_to_suffix_string(num: int) -> Optional[str]:
     d = 0.0
     suffix = None
     for (sfx, size) in NUM_SUFFIXES.items():
-        if num > size:
+        if num >= size:
             d = num / size
             suffix = sfx
             break
     if suffix is not None:
         return f"{d:.1f}{suffix}"
-    return None
+    else:
+        return f'{num:d}'
 
 
 def is_credit_card(in_str: Any, card_type: str = None) -> bool:
@@ -807,6 +814,45 @@ def to_bool(in_str: str) -> bool:
     return in_str.lower() in ("true", "1", "yes", "y", "t")
 
 
+def to_date(in_str: str) -> Optional[datetime.date]:
+    try:
+        d = dp.DateParser()
+        d.parse(in_str)
+        return d.get_date()
+    except dp.ParseException:
+        logger.warning(f'Unable to parse date {in_str}.')
+    return None
+
+
+def valid_date(in_str: str) -> bool:
+    try:
+        d = dp.DateParser()
+        _ = d.parse(in_str)
+        return True
+    except dp.ParseException:
+        logger.warning(f'Unable to parse date {in_str}.')
+    return False
+
+
+def to_datetime(in_str: str) -> Optional[datetime.datetime]:
+    try:
+        d = dp.DateParser()
+        dt = d.parse(in_str)
+        if type(dt) == datetime.datetime:
+            return dt
+    except ValueError:
+        logger.warning(f'Unable to parse datetime {in_str}.')
+    return None
+
+
+def valid_datetime(in_str: str) -> bool:
+    _ = to_datetime(in_str)
+    if _ is not None:
+        return True
+    logger.warning(f'Unable to parse datetime {in_str}.')
+    return False
+
+
 def dedent(in_str: str) -> str:
     """
     Removes tab indentation from multi line strings (inspired by analogous Scala function).
@@ -869,3 +915,29 @@ def sprintf(*args, **kwargs) -> str:
             ret += str(arg)
     ret += end
     return ret
+
+
+def is_are(n: int) -> str:
+    if n == 1:
+        return "is"
+    return "are"
+
+
+def pluralize(n: int) -> str:
+    if n == 1:
+        return ""
+    return "s"
+
+
+def thify(n: int) -> str:
+    digit = str(n)
+    assert is_integer_number(digit)
+    digit = digit[-1:]
+    if digit == "1":
+        return "st"
+    elif digit == "2":
+        return "nd"
+    elif digit == "3":
+        return "rd"
+    else:
+        return "th"
diff --git a/tests/dateparse_utils_test.py b/tests/dateparse_utils_test.py
new file mode 100755 (executable)
index 0000000..ff16e01
--- /dev/null
@@ -0,0 +1,203 @@
+#!/usr/bin/env python3
+
+import datetime
+import unittest
+
+import pytz
+
+import dateparse.dateparse_utils as du
+import unittest_utils as uu
+
+
+class TestDateparseUtils(unittest.TestCase):
+
+    @uu.check_method_for_perf_regressions
+    def test_dateparsing(self):
+        dp = du.DateParser(
+            override_now_for_test_purposes = datetime.datetime(2021, 7, 2)
+        )
+        parsable_expressions = [
+            ('today',
+             datetime.datetime(2021, 7, 2)),
+            ('tomorrow',
+             datetime.datetime(2021, 7, 3)),
+            ('yesterday',
+             datetime.datetime(2021, 7, 1)),
+            ('21:30',
+             datetime.datetime(2021, 7, 2, 21, 30, 0, 0)),
+            ('12:01am',
+             datetime.datetime(2021, 7, 2, 0, 1, 0, 0)),
+            ('12:02p',
+             datetime.datetime(2021, 7, 2, 12, 2, 0, 0)),
+            ('0:03',
+             datetime.datetime(2021, 7, 2, 0, 3, 0, 0)),
+            ('last wednesday',
+             datetime.datetime(2021, 6, 30)),
+            ('this wed',
+             datetime.datetime(2021, 7, 7)),
+            ('next wed',
+             datetime.datetime(2021, 7, 14)),
+            ('this coming tues',
+             datetime.datetime(2021, 7, 6)),
+            ('this past monday',
+             datetime.datetime(2021, 6, 28)),
+            ('4 days ago',
+             datetime.datetime(2021, 6, 28)),
+            ('4 mondays ago',
+             datetime.datetime(2021, 6, 7)),
+            ('4 months ago',
+             datetime.datetime(2021, 3, 2)),
+            ('3 days back',
+             datetime.datetime(2021, 6, 29)),
+            ('13 weeks from now',
+             datetime.datetime(2021, 10, 1)),
+            ('1 year from now',
+             datetime.datetime(2022, 7, 2)),
+            ('4 weeks from now',
+             datetime.datetime(2021, 7, 30)),
+            ('3 saturdays ago',
+             datetime.datetime(2021, 6, 12)),
+            ('4 months from today',
+             datetime.datetime(2021, 11, 2)),
+            ('4 years from yesterday',
+             datetime.datetime(2025, 7, 1)),
+            ('4 weeks from tomorrow',
+             datetime.datetime(2021, 7, 31)),
+            ('april 15, 2005',
+             datetime.datetime(2005, 4, 15)),
+            ('april 14',
+             datetime.datetime(2021, 4, 14)),
+            ('9:30am on last wednesday',
+             datetime.datetime(2021, 6, 30, 9, 30)),
+            ('2005/apr/15',
+             datetime.datetime(2005, 4, 15)),
+            ('2005 apr 15',
+             datetime.datetime(2005, 4, 15)),
+            ('the 1st wednesday in may',
+             datetime.datetime(2021, 5, 5)),
+            ('last sun of june',
+             datetime.datetime(2021, 6, 27)),
+            ('this Easter',
+             datetime.datetime(2021, 4, 4)),
+            ('last christmas',
+             datetime.datetime(2020, 12, 25)),
+            ('last Xmas',
+             datetime.datetime(2020, 12, 25)),
+            ('xmas, 1999',
+             datetime.datetime(1999, 12, 25)),
+            ('next mlk day',
+             datetime.datetime(2022, 1, 17)),
+            ('Halloween, 2020',
+             datetime.datetime(2020, 10, 31)),
+            ('5 work days after independence day',
+             datetime.datetime(2021, 7, 12)),
+            ('50 working days from last wed',
+             datetime.datetime(2021, 9, 10)),
+            ('25 working days before columbus day',
+             datetime.datetime(2021, 9, 3)),
+            ('today +1 week',
+             datetime.datetime(2021, 7, 9)),
+            ('sunday -3 weeks',
+             datetime.datetime(2021, 6, 13)),
+            ('4 weeks before xmas, 1999',
+             datetime.datetime(1999, 11, 27)),
+            ('3 days before new years eve, 2000',
+             datetime.datetime(2000, 12, 28)),
+            ('july 4th',
+             datetime.datetime(2021, 7, 4)),
+            ('the ides of march',
+             datetime.datetime(2021, 3, 15)),
+            ('the nones of april',
+             datetime.datetime(2021, 4, 5)),
+            ('the kalends of may',
+             datetime.datetime(2021, 5, 1)),
+            ('9/11/2001',
+             datetime.datetime(2001, 9, 11)),
+            ('4 sundays before veterans\' day',
+             datetime.datetime(2021, 10, 17)),
+            ('xmas eve',
+             datetime.datetime(2021, 12, 24)),
+            ('this friday at 5pm',
+             datetime.datetime(2021, 7, 9, 17, 0, 0)),
+            ('presidents day',
+             datetime.datetime(2021, 2, 15)),
+            ('memorial day, 1921',
+             datetime.datetime(1921, 5, 30)),
+            ('today -4 wednesdays',
+             datetime.datetime(2021, 6, 9)),
+            ('thanksgiving',
+             datetime.datetime(2021, 11, 25)),
+            ('2 sun in jun',
+             datetime.datetime(2021, 6, 13)),
+            ('easter -40 days',
+             datetime.datetime(2021, 2, 23)),
+            ('easter +39 days',
+             datetime.datetime(2021, 5, 13)),
+            ('1st tuesday in nov, 2024',
+             datetime.datetime(2024, 11, 5)),
+            ('2 days before last xmas at 3:14:15.92a',
+             datetime.datetime(2020, 12, 23, 3, 14, 15, 92)),
+            ('3 weeks after xmas, 1995 at midday',
+             datetime.datetime(1996, 1, 15, 12, 0, 0)),
+            ('4 months before easter, 1992 at midnight',
+             datetime.datetime(1991, 12, 19)),
+            ('5 months before halloween, 1995 at noon',
+             datetime.datetime(1995, 5, 31, 12)),
+            ('4 days before last wednesday',
+             datetime.datetime(2021, 6, 26)),
+            ('44 months after today',
+             datetime.datetime(2025, 3, 2)),
+            ('44 years before today',
+             datetime.datetime(1977, 7, 2)),
+            ('44 weeks ago',
+             datetime.datetime(2020, 8, 28)),
+            ('15 minutes to 3am',
+             datetime.datetime(2021, 7, 2, 2, 45)),
+            ('quarter past 4pm',
+             datetime.datetime(2021, 7, 2, 16, 15)),
+            ('half past 9',
+             datetime.datetime(2021, 7, 2, 9, 30)),
+            ('4 seconds to midnight',
+             datetime.datetime(2021, 7, 1, 23, 59, 56)),
+            ('4 seconds to midnight, tomorrow',
+             datetime.datetime(2021, 7, 2, 23, 59, 56)),
+            ('2021/apr/15T21:30:44.55',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55)),
+            ('2021/apr/15 at 21:30:44.55',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55)),
+            ('2021/4/15 at 21:30:44.55',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55)),
+            ('2021/04/15 at 21:30:44.55',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55)),
+            ('2021/04/15 at 21:30:44.55Z',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55,
+                               tzinfo=pytz.timezone('UTC'))),
+            ('2021/04/15 at 21:30:44.55EST',
+             datetime.datetime(2021, 4, 15, 21, 30, 44, 55,
+                               tzinfo=pytz.timezone('EST'))),
+            ('13 days after last memorial day at 12 seconds before 4pm',
+             datetime.datetime(2020, 6, 7, 15, 59, 48)),
+            ('    2     days     before   yesterday    at   9am      ',
+             datetime.datetime(2021, 6, 29, 9)),
+            ('-3 days before today',
+             datetime.datetime(2021, 7, 5)),
+            ('3 days before yesterday at midnight EST',
+             datetime.datetime(2021, 6, 28, tzinfo=pytz.timezone('EST'))),
+        ]
+
+        for (txt, expected_dt) in parsable_expressions:
+            try:
+                print(f'> {txt}')
+                actual_dt = dp.parse(txt)
+                self.assertIsNotNone(actual_dt)
+                self.assertEqual(
+                    actual_dt,
+                    expected_dt,
+                    f'"{txt}", got "{actual_dt}" while expecting "{expected_dt}"'
+                )
+            except du.ParseException:
+                self.fail(f'Expected "{txt}" to parse successfully.')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/dict_utils_test.py b/tests/dict_utils_test.py
new file mode 100755 (executable)
index 0000000..1bdbb9b
--- /dev/null
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+import unittest
+
+import dict_utils as du
+
+
+class TestDictUtils(unittest.TestCase):
+
+    def test_init_or_inc(self):
+        d = {}
+        du.init_or_inc(d, 'a')
+        du.init_or_inc(d, 'b')
+        du.init_or_inc(d, 'a')
+        du.init_or_inc(d, 'b')
+        du.init_or_inc(d, 'c')
+        du.init_or_inc(d, 'c')
+        du.init_or_inc(d, 'd')
+        du.init_or_inc(d, 'e')
+        du.init_or_inc(d, 'a')
+        du.init_or_inc(d, 'b')
+        e = {
+            'a': 3, 'b': 3, 'c': 2, 'd': 1, 'e': 1
+        }
+        self.assertEqual(d, e)
+
+    def test_shard_coalesce(self):
+        d = {
+            'a': 3, 'b': 3, 'c': 2, 'd': 1, 'e': 1
+        }
+        shards = du.shard(d, 2)
+        merged = du.coalesce(shards)
+        self.assertEqual(d, merged)
+
+    def test_item_with_max_value(self):
+        d = {
+            'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 1
+        }
+        self.assertEqual('a', du.item_with_max_value(d)[0])
+        self.assertEqual(4, du.item_with_max_value(d)[1])
+        self.assertEqual('a', du.key_with_max_value(d))
+        self.assertEqual(4, du.max_value(d))
+
+    def test_item_with_min_value(self):
+        d = {
+            'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 0
+        }
+        self.assertEqual('e', du.item_with_min_value(d)[0])
+        self.assertEqual(0, du.item_with_min_value(d)[1])
+        self.assertEqual('e', du.key_with_min_value(d))
+        self.assertEqual(0, du.min_value(d))
+
+    def test_min_max_key(self):
+        d = {
+            'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 0
+        }
+        self.assertEqual('a', du.min_key(d))
+        self.assertEqual('e', du.max_key(d))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/simple_acl_test.py b/tests/simple_acl_test.py
new file mode 100755 (executable)
index 0000000..7c17415
--- /dev/null
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+import re
+import unittest
+
+import simple_acl as acl
+
+
+class TestSimpleACL(unittest.TestCase):
+
+    def test_set_based_acl(self):
+        even = acl.SetBasedACL(
+            allow_set = set([2, 4, 6, 8, 10]),
+            deny_set = set([1, 3, 5, 7, 9]),
+            order_to_check_allow_deny = acl.ACL_ORDER_ALLOW_DENY,
+            default_answer = False
+        )
+        self.assertTrue(even(2))
+        self.assertFalse(even(3))
+        self.assertFalse(even(-4))
+
+    def test_wildcard_based_acl(self):
+        a_or_b = acl.StringWildcardBasedACL(
+            allowed_patterns = ['a*', 'b*'],
+            order_to_check_allow_deny = acl.ACL_ORDER_ALLOW_DENY,
+            default_answer = False
+        )
+        self.assertTrue(a_or_b('aardvark'))
+        self.assertTrue(a_or_b('bubblegum'))
+        self.assertFalse(a_or_b('charlie'))
+
+    def test_re_based_acl(self):
+        weird = acl.StringREBasedACL(
+            denied_regexs = [
+                re.compile('^a.*a$'),
+                re.compile('^b.*b$')
+            ],
+            order_to_check_allow_deny = acl.ACL_ORDER_ALLOW_DENY,
+            default_answer = True
+        )
+        self.assertTrue(weird('aardvark'))
+        self.assertFalse(weird('anaconda'))
+        self.assertFalse(weird('beelzebub'))
+
+
+if __name__ == '__main__':
+    unittest.main()
index 157de0a8df9044b614e47a2b943aa79803b948c5..0472daaccaf9a525794df24e79c8ae5f923898f0 100755 (executable)
@@ -3,10 +3,15 @@
 import unittest
 
 from ansi import fg, bg, reset
+import bootstrap
 import string_utils as su
 
+import unittest_utils as uu
 
+
[email protected]_all_methods_for_perf_regressions()
 class TestStringUtils(unittest.TestCase):
+
     def test_is_none_or_empty(self):
         self.assertTrue(su.is_none_or_empty(None))
         self.assertTrue(su.is_none_or_empty(""))
@@ -136,5 +141,45 @@ class TestStringUtils(unittest.TestCase):
         self.assertTrue(su.is_url("http://user:[email protected]:81/uri/uri#shard?param=value+s"))
         self.assertTrue(su.is_url("ftp://127.0.0.1/uri/uri"))
 
+    def test_is_email(self):
+        self.assertTrue(su.is_email('[email protected]'))
+        self.assertTrue(su.is_email('[email protected]'))
+        self.assertFalse(su.is_email('@yahoo.com'))
+        self.assertFalse(su.is_email('indubidibly'))
+        self.assertFalse(su.is_email('[email protected]'))
+
+    def test_suffix_string_to_number(self):
+        self.assertEqual(1024, su.suffix_string_to_number('1Kb'))
+        self.assertEqual(1024 * 1024, su.suffix_string_to_number('1Mb'))
+        self.assertEqual(1024, su.suffix_string_to_number('1k'))
+        self.assertEqual(1024, su.suffix_string_to_number('1kb'))
+        self.assertEqual(None, su.suffix_string_to_number('1Jl'))
+        self.assertEqual(None, su.suffix_string_to_number('undeniable'))
+
+    def test_number_to_suffix_string(self):
+        self.assertEqual('1.0Kb', su.number_to_suffix_string(1024))
+        self.assertEqual('1.0Mb', su.number_to_suffix_string(1024 * 1024))
+        self.assertEqual('123', su.number_to_suffix_string(123))
+
+    def test_is_credit_card(self):
+        self.assertTrue(su.is_credit_card('4242424242424242'))
+        self.assertTrue(su.is_credit_card('5555555555554444'))
+        self.assertTrue(su.is_credit_card('378282246310005'))
+        self.assertTrue(su.is_credit_card('6011111111111117'))
+        self.assertTrue(su.is_credit_card('4000000360000006'))
+        self.assertFalse(su.is_credit_card('8000000360110099'))
+        self.assertFalse(su.is_credit_card(''))
+
+    def test_is_camel_case(self):
+        self.assertFalse(su.is_camel_case('thisisatest'))
+        self.assertTrue(su.is_camel_case('thisIsATest'))
+        self.assertFalse(su.is_camel_case('this_is_a_test'))
+
+    def test_is_snake_case(self):
+        self.assertFalse(su.is_snake_case('thisisatest'))
+        self.assertFalse(su.is_snake_case('thisIsATest'))
+        self.assertTrue(su.is_snake_case('this_is_a_test'))
+
+
 if __name__ == '__main__':
-    unittest.main()
+    bootstrap.initialize(unittest.main)()
diff --git a/unittest_utils.py b/unittest_utils.py
new file mode 100644 (file)
index 0000000..99ac81d
--- /dev/null
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+
+"""Helpers for unittests.  Note that when you import this we
+automatically wrap unittest.main() with a call to bootstrap.initialize
+so that we getLogger config, commandline args, logging control,
+etc... this works fine but it's a little hacky so caveat emptor.
+"""
+
+import functools
+import inspect
+import logging
+import pickle
+import random
+import statistics
+import time
+from typing import Callable
+import unittest
+
+import bootstrap
+import config
+
+
+logger = logging.getLogger(__name__)
+cfg = config.add_commandline_args(
+    f'Logging ({__file__})',
+    'Args related to function decorators')
+cfg.add_argument(
+    '--unittests_ignore_perf',
+    action='store_true',
+    default=False,
+    help='Ignore unittest perf regression in @check_method_for_perf_regressions',
+)
+cfg.add_argument(
+    '--unittests_num_perf_samples',
+    type=int,
+    default=20,
+    help='The count of perf timing samples we need to see before blocking slow runs on perf grounds'
+)
+cfg.add_argument(
+    '--unittests_drop_perf_traces',
+    type=str,
+    nargs=1,
+    default=None,
+    help='The identifier (i.e. file!test_fixture) for which we should drop all perf data'
+)
+
+
+# >>> This is the hacky business, FYI. <<<
+unittest.main = bootstrap.initialize(unittest.main)
+
+
+_db = '/home/scott/.python_unittest_performance_db'
+
+
+def check_method_for_perf_regressions(func: Callable) -> Callable:
+    """This is meant to be used on a method in a class that subclasses
+    unittest.TestCase.  When thus decorated it will time the execution
+    of the code in the method, compare it with a database of
+    historical perfmance, and fail the test with a perf-related
+    message if it has become too slow.
+    """
+
+    def load_known_test_performance_characteristics():
+        with open(_db, 'rb') as f:
+            return pickle.load(f)
+
+    def save_known_test_performance_characteristics(perfdb):
+        with open(_db, 'wb') as f:
+            pickle.dump(perfdb, f, pickle.HIGHEST_PROTOCOL)
+
+    @functools.wraps(func)
+    def wrapper_perf_monitor(*args, **kwargs):
+        try:
+            perfdb = load_known_test_performance_characteristics()
+        except Exception as e:
+            logger.exception(e)
+            logger.warning(f'Unable to load perfdb from {_db}')
+            perfdb = {}
+
+        # This is a unique identifier for a test: filepath!function
+        logger.debug(f'Watching {func.__name__}\'s performance...')
+        func_id = f'{func.__globals__["__file__"]}!{func.__name__}'
+        logger.debug(f'Canonical function identifier = {func_id}')
+
+        # cmdline arg to forget perf traces for function
+        drop_id = config.config['unittests_drop_perf_traces']
+        if drop_id is not None:
+            if drop_id in perfdb:
+                perfdb[drop_id] = []
+
+        # Run the wrapped test paying attention to latency.
+        start_time = time.perf_counter()
+        value = func(*args, **kwargs)
+        end_time = time.perf_counter()
+        run_time = end_time - start_time
+        logger.debug(f'{func.__name__} executed in {run_time:f}s.')
+
+        # Check the db; see if it was unexpectedly slow.
+        hist = perfdb.get(func_id, [])
+        if len(hist) < config.config['unittests_num_perf_samples']:
+            hist.append(run_time)
+            logger.debug(
+                f'Still establishing a perf baseline for {func.__name__}'
+            )
+        else:
+            stdev = statistics.stdev(hist)
+            limit = hist[-1] + stdev * 3
+            logger.debug(
+                f'Max acceptable performace for {func.__name__} is {limit:f}s'
+            )
+            if (
+                run_time > limit and
+                not config.config['unittests_ignore_perf']
+            ):
+                msg = f'''{func_id} performance has regressed unacceptably.
+{hist[-1]:f}s is the slowest record in {len(hist)} db perf samples.
+It just ran in {run_time:f}s which is >3 stdevs slower than the slowest sample.
+Here is the current, full db perf timing distribution:
+
+{hist}'''
+                slf = args[0]
+                logger.error(msg)
+                slf.fail(msg)
+            else:
+                hist.append(run_time)
+
+        n = min(config.config['unittests_num_perf_samples'], len(hist))
+        hist = random.sample(hist, n)
+        hist.sort()
+        perfdb[func_id] = hist
+        save_known_test_performance_characteristics(perfdb)
+        return value
+    return wrapper_perf_monitor
+
+
+def check_all_methods_for_perf_regressions(prefix='test_'):
+    def decorate_the_testcase(cls):
+        if issubclass(cls, unittest.TestCase):
+            for name, m in inspect.getmembers(cls, inspect.isfunction):
+                if name.startswith(prefix):
+                    setattr(cls, name, check_method_for_perf_regressions(m))
+                    logger.debug(f'Wrapping {cls.__name__}:{name}.')
+        return cls
+    return decorate_the_testcase