From 69566c003b4f1c3a4905f37d3735d7921502d14a Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Wed, 12 Oct 2022 10:17:06 -0700 Subject: [PATCH] Migration from old pyutilz package name (which, in turn, came from my old python_utilities package). This one is ok to upload to PyPI per: https://github.com/pypa/pypi-support/issues/2201 --- .flake8 | 5 + .gitignore | 3 + LICENSE | 15 + NOTICE | 75 + README.md | 30 + cut_version.sh | 75 + docs/.gitignore | 1 + docs/Makefile | 20 + docs/README | 5 + docs/conf.py | 62 + docs/index.rst | 22 + docs/make.bat | 35 + docs/modules.rst | 7 + docs/new_file_added.sh | 4 + pyproject.template | 52 + pyproject.toml | 52 + release_notes.md | 0 setup.cfg | 5 + src/pyutils/__init__.py | 0 src/pyutils/ansi.py | 2079 ++++++++++++++++ src/pyutils/argparse_utils.py | 282 +++ src/pyutils/bootstrap.py | 414 ++++ src/pyutils/collectionz/__init__.py | 0 src/pyutils/collectionz/bidict.py | 53 + src/pyutils/collectionz/bst.py | 643 +++++ src/pyutils/collectionz/shared_dict.py | 240 ++ src/pyutils/collectionz/trie.py | 346 +++ src/pyutils/compress/__init__.py | 0 src/pyutils/compress/letter_compress.py | 110 + src/pyutils/config.py | 749 ++++++ src/pyutils/datetimez/.gitignore | 7 + src/pyutils/datetimez/__init__.py | 0 src/pyutils/datetimez/constants.py | 22 + src/pyutils/datetimez/dateparse_utils.g4 | 683 ++++++ src/pyutils/datetimez/dateparse_utils.py | 1055 ++++++++ src/pyutils/datetimez/datetime_utils.py | 956 ++++++++ src/pyutils/decorator_utils.py | 839 +++++++ src/pyutils/dict_utils.py | 253 ++ src/pyutils/exec_utils.py | 233 ++ src/pyutils/files/__init__.py | 0 src/pyutils/files/directory_filter.py | 196 ++ src/pyutils/files/file_utils.py | 832 +++++++ src/pyutils/files/lockfile.py | 243 ++ src/pyutils/function_utils.py | 32 + src/pyutils/id_generator.py | 45 + src/pyutils/iter_utils.py | 183 ++ src/pyutils/list_utils.py | 332 +++ src/pyutils/logging_utils.py | 913 +++++++ src/pyutils/math_utils.py | 243 ++ src/pyutils/misc_utils.py | 30 + src/pyutils/parallelize/__init__.py | 0 src/pyutils/parallelize/deferred_operand.py | 163 ++ src/pyutils/parallelize/executors.py | 1540 ++++++++++++ src/pyutils/parallelize/parallelize.py | 107 + src/pyutils/parallelize/smart_future.py | 147 ++ src/pyutils/parallelize/thread_utils.py | 212 ++ src/pyutils/persistent.py | 349 +++ src/pyutils/remote_worker.py | 144 ++ src/pyutils/search/__init__.py | 0 src/pyutils/search/logical_search.py | 470 ++++ src/pyutils/security/__init__.py | 0 src/pyutils/security/acl.py | 309 +++ src/pyutils/state_tracker.py | 276 +++ src/pyutils/stopwatch.py | 40 + src/pyutils/string_utils.py | 2392 +++++++++++++++++++ src/pyutils/text_utils.py | 707 ++++++ src/pyutils/typez/__init__.py | 0 src/pyutils/typez/centcount.py | 220 ++ src/pyutils/typez/histogram.py | 219 ++ src/pyutils/typez/money.py | 216 ++ src/pyutils/typez/rate.py | 92 + src/pyutils/typez/type_utils.py | 46 + src/pyutils/unittest_utils.py | 359 +++ src/pyutils/unscrambler.py | 315 +++ src/pyutils/zookeeper.py | 424 ++++ tests/.coveragerc | 2 + tests/.gitignore | 2 + tests/README | 29 + tests/collectionz/shared_dict_test.py | 79 + tests/color_vars.sh | 104 + tests/compress/letter_compress_test.py | 34 + tests/datetimez/dateparse_utils_test.py | 181 ++ tests/decorator_utils_test.py | 25 + tests/dict_utils_test.py | 56 + tests/exec_utils_test.py | 42 + tests/logging_utils_test.py | 58 + tests/parallelize/parallelize_itest.py | 90 + tests/parallelize/thread_utils_test.py | 67 + tests/run_tests.py | 544 +++++ tests/run_tests_serially.sh | 198 
++ tests/security/acl_test.py | 92 + tests/string_utils_test.py | 195 ++ tests/typez/centcount_test.py | 105 + tests/typez/money_test.py | 106 + tests/typez/rate_test.py | 71 + tests/zookeeper_test.py | 79 + 96 files changed, 23387 insertions(+) create mode 100644 .flake8 create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 NOTICE create mode 100644 README.md create mode 100755 cut_version.sh create mode 100644 docs/.gitignore create mode 100644 docs/Makefile create mode 100644 docs/README create mode 100644 docs/conf.py create mode 100644 docs/index.rst create mode 100644 docs/make.bat create mode 100644 docs/modules.rst create mode 100755 docs/new_file_added.sh create mode 100644 pyproject.template create mode 100644 pyproject.toml create mode 100644 release_notes.md create mode 100644 setup.cfg create mode 100644 src/pyutils/__init__.py create mode 100755 src/pyutils/ansi.py create mode 100644 src/pyutils/argparse_utils.py create mode 100644 src/pyutils/bootstrap.py create mode 100644 src/pyutils/collectionz/__init__.py create mode 100644 src/pyutils/collectionz/bidict.py create mode 100644 src/pyutils/collectionz/bst.py create mode 100644 src/pyutils/collectionz/shared_dict.py create mode 100644 src/pyutils/collectionz/trie.py create mode 100644 src/pyutils/compress/__init__.py create mode 100644 src/pyutils/compress/letter_compress.py create mode 100644 src/pyutils/config.py create mode 100644 src/pyutils/datetimez/.gitignore create mode 100644 src/pyutils/datetimez/__init__.py create mode 100644 src/pyutils/datetimez/constants.py create mode 100644 src/pyutils/datetimez/dateparse_utils.g4 create mode 100755 src/pyutils/datetimez/dateparse_utils.py create mode 100644 src/pyutils/datetimez/datetime_utils.py create mode 100644 src/pyutils/decorator_utils.py create mode 100644 src/pyutils/dict_utils.py create mode 100644 src/pyutils/exec_utils.py create mode 100644 src/pyutils/files/__init__.py create mode 100644 src/pyutils/files/directory_filter.py create mode 100644 src/pyutils/files/file_utils.py create mode 100644 src/pyutils/files/lockfile.py create mode 100644 src/pyutils/function_utils.py create mode 100644 src/pyutils/id_generator.py create mode 100644 src/pyutils/iter_utils.py create mode 100644 src/pyutils/list_utils.py create mode 100644 src/pyutils/logging_utils.py create mode 100644 src/pyutils/math_utils.py create mode 100644 src/pyutils/misc_utils.py create mode 100644 src/pyutils/parallelize/__init__.py create mode 100644 src/pyutils/parallelize/deferred_operand.py create mode 100644 src/pyutils/parallelize/executors.py create mode 100644 src/pyutils/parallelize/parallelize.py create mode 100644 src/pyutils/parallelize/smart_future.py create mode 100644 src/pyutils/parallelize/thread_utils.py create mode 100644 src/pyutils/persistent.py create mode 100755 src/pyutils/remote_worker.py create mode 100644 src/pyutils/search/__init__.py create mode 100644 src/pyutils/search/logical_search.py create mode 100644 src/pyutils/security/__init__.py create mode 100644 src/pyutils/security/acl.py create mode 100644 src/pyutils/state_tracker.py create mode 100644 src/pyutils/stopwatch.py create mode 100644 src/pyutils/string_utils.py create mode 100644 src/pyutils/text_utils.py create mode 100644 src/pyutils/typez/__init__.py create mode 100644 src/pyutils/typez/centcount.py create mode 100644 src/pyutils/typez/histogram.py create mode 100644 src/pyutils/typez/money.py create mode 100644 src/pyutils/typez/rate.py create mode 100644 
src/pyutils/typez/type_utils.py
 create mode 100644 src/pyutils/unittest_utils.py
 create mode 100644 src/pyutils/unscrambler.py
 create mode 100644 src/pyutils/zookeeper.py
 create mode 100644 tests/.coveragerc
 create mode 100644 tests/.gitignore
 create mode 100644 tests/README
 create mode 100755 tests/collectionz/shared_dict_test.py
 create mode 100644 tests/color_vars.sh
 create mode 100755 tests/compress/letter_compress_test.py
 create mode 100755 tests/datetimez/dateparse_utils_test.py
 create mode 100755 tests/decorator_utils_test.py
 create mode 100755 tests/dict_utils_test.py
 create mode 100755 tests/exec_utils_test.py
 create mode 100755 tests/logging_utils_test.py
 create mode 100755 tests/parallelize/parallelize_itest.py
 create mode 100755 tests/parallelize/thread_utils_test.py
 create mode 100755 tests/run_tests.py
 create mode 100755 tests/run_tests_serially.sh
 create mode 100755 tests/security/acl_test.py
 create mode 100755 tests/string_utils_test.py
 create mode 100755 tests/typez/centcount_test.py
 create mode 100755 tests/typez/money_test.py
 create mode 100755 tests/typez/rate_test.py
 create mode 100755 tests/zookeeper_test.py

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..47f3c62
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,5 @@
+[flake8]
+ignore = E203, E251, E266, E501, W503, F403, E262, W504, W501
+max-line-length = 100
+max-complexity = 26
+select = B,C,E,F,W,T4,B9
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9d23a28
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+**/__pycache__/
+**/pyutils.egg-info
+dist/*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..754f2f4
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,15 @@
+
+Except where otherwise noted in the source code and described in the
+NOTICE file, all code is © Copyright 2021-2022 Scott Gasch.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you
+may not use this project except in compliance with the License. You
+may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..153f09c
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,75 @@
+This is Scott's pyutils module. See README.md for details.
+
+Some code in this library came from other sources. As required by
+clause 4 of the Apache 2.0 License and clause 3 of the PSF License,
+this NOTICE file describes changes Scott Gasch made to any preexisting
+code regardless of its original License. All such original code was
+used in a manner compliant with its original License and is enumerated
+in this file. This file also contains URLs pointing at the original
+source of all forked code.
+
+  1. As noted in string_utils.py, that file is a fork of work by
+     Davide Zanotti. Davide's original code is here:
+
+     https://github.com/daveoncode/python-string-utils/tree/master/string_utils
+
+     Davide's code was released under the MIT license and the original license
+     text is preserved in string_utils.py.
+
+     Scott's modifications include:
+
+     Added these routines: strip_escape_sequences,
+     suffix_string_to_number, number_to_suffix_string, extract_ip_v4,
+     extract_ip_v6, extract_mac_address, extract_ip, to_bool,
+     to_date, valid_date, to_datetime, valid_datetime, squeeze,
+     indent, dedent, sprintf, strip_ansi_sequences, SprintfStdout,
+     capitalize_first_letter, it_they, is_are, pluralize,
+     make_contractions, thify, ngrams, ngrams_presplit, bigrams,
+     trigrams, shuffle_columns_into_list, shuffle_columns_into_dict,
+     interpolate_using_dict, to_ascii, to_base64, is_base64, from_base64,
+     chunk, to_bitstring, is_bitstring, from_bitstring, ip_v4_sort_key,
+     path_ancestors_before_descendants_sort_key, replace_all, and
+     replace_nth.
+
+     Added type annotations everywhere,
+
+     Wrote doctests everywhere,
+
+     Added sphinx style pydocs,
+
+     Wrote a supplemental unittest (tests/string_utils_test.py),
+
+     Added logging.
+
+  2. As noted in shared_dict.py, that file is a fork of work by
+     LuizaLabs and available here:
+
+     https://github.com/luizalabs/shared-memory-dict/blob/main/shared_memory_dict/dict.py
+
+     The original work was released under the MIT license and the original license
+     text is preserved in shared_dict.py.
+
+     Scott's modifications include:
+
+     Added a unittest (tests/shared_dict_test.py),
+
+     Added type hints,
+
+     Changed the locking scope,
+
+     Minor cleanup and style tweaks,
+
+     Added sphinx style pydocs.
+
+  3. The timeout decorator in decorator_utils.py is based on original
+     work published in ActiveState code recipes and covered by the PSF
+     license. It is from here:
+
+     https://code.activestate.com/recipes/307871-timing-out-function/
+
+     Scott's modifications include:
+
+     Added docs + comments including a doctest unittest,
+
+     Minor cleanup and style tweaks,
+
+     Added type hints.
+
+Thank you to everyone who makes their code available for reuse by
+others and contributes to the open source ecosystem. Scott is
+especially grateful to the authors of the projects above. Thank you.
+
+Any code not mentioned in this NOTICE file is work by Scott Gasch, is
+copyrighted by him, and is released under the Apache 2.0 license
+described in the LICENSE file.
+
+If you make modifications to such code, please comply with the Apache
+2.0 license by retaining the LICENSE and copyright notices in your work
+and by adding your own NOTICEs about the changes you make. See
+https://www.apache.org/licenses/LICENSE-2.0 for details.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9f7cac9
--- /dev/null
+++ b/README.md
@@ -0,0 +1,30 @@
+# pyutils
+
+This is a collection of Python utilities that I wrote and find useful,
+ranging from collections that try to emulate Pythonic patterns
+(pyutils.collectionz), to a "smart" natural language date parser
+(pyutils.datetimez.dateparse_utils), to filesystem helpers
+(pyutils.files.file_utils), to a "simple" parallelization framework
+(pyutils.parallelize.parallelize).
+
+Code is under src/*. Most code includes doctests.
+
+Tests are under tests/*. To run all tests:
+
+    cd tests/
+    ./run_tests.py --all [--coverage]
+
+See the README under tests/ for more options / information.
+
+This package generates Sphinx docs, which are available at:
+
+    https://wannabe.guru.org/pydocs/pyutils/pyutils.html
+
+For a long time this was just a local library on my machine that
+my tools imported, but I've decided to release it on PyPI. I hope
+you find it useful. LICENSE and NOTICE describe how to reuse it and
+where everything came from.
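+
+By way of a quick, illustrative taste (each module's doctests and
+Sphinx pydocs are the authoritative examples), here is one of the
+string_utils routines enumerated in the NOTICE file:
+
+    >>> from pyutils import string_utils
+    >>> string_utils.to_bool("yes")
+    True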
+
+Drop me a line if you are using this, find a bug, or have a question.
+
+ --Scott Gasch (scott.gasch@gmail.com)
+
diff --git a/cut_version.sh b/cut_version.sh
new file mode 100755
index 0000000..0f477da
--- /dev/null
+++ b/cut_version.sh
@@ -0,0 +1,75 @@
+#!/usr/bin/env bash
+
+set -e
+
+# Ask a yes or no question
+function ask_y_n() {
+    local prompt default reply
+    if [ "${2:-}" = "Y" ]; then
+        prompt="Y/n"
+        default=Y
+    elif [ "${2:-}" = "N" ]; then
+        prompt="y/N"
+        default=N
+    else
+        prompt="y/n"
+        default=
+    fi
+    while true; do
+        echo -ne "$1 [$prompt] "
+        read -n 1 -t 5 reply
+        echo
+        if [ -z "$reply" ]; then
+            reply=$default
+        fi
+        case "$reply" in
+            Y*|y*) return 0 ;;
+            N*|n*) return 1 ;;
+        esac
+    done
+}
+
+# Wait for a keypress before proceeding.
+function pause() {
+    read -n 1 -s -r -p "Press any key to continue..."
+    echo
+}
+
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 <version>"
+    exit 1
+fi
+
+VERSION=$1
+if ! ask_y_n "About to cut and upload $VERSION, ok?" "N"; then
+    echo "Ok, exiting instead."
+    exit 0
+fi
+echo
+
+echo "Ok, here's the commit message of changes since the last version..."
+LAST_TAG=$(git tag | tail -1)
+git log $LAST_TAG..HEAD
+pause
+echo
+echo
+
+echo "Ok, running scottutilz tests including test_some_dependencies.py..."
+cd ../scottutilz/tests
+./run_tests.py --all --coverage
+cd ../../pyutils
+pause
+echo
+echo
+
+git tag -a "${VERSION}" -m "cut_version.sh ${VERSION}"
+CHANGES=$(git log --pretty="- %s" $VERSION...$LAST_TAG)
+printf "# 🎁 Release notes (\`$VERSION\`)\n\n## Changes\n$CHANGES\n\n## Metadata\n\`\`\`\nThis version -------- $VERSION\nPrevious version ---- $LAST_TAG\nTotal commits ------- $(echo "$CHANGES" | wc -l)\n\`\`\`\n" >> release_notes.md
+
+cat ./pyproject.template | sed s/##VERSION##/$VERSION/g > ./pyproject.toml
+git commit -a -m "Cut version ${VERSION}" -m "${CHANGES}"
+git push
+
+python -m build
+echo "To upload, run: \"twine upload --verbose --repository testpypi dist/*\""
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 0000000..2340036
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1 @@
+_build/**
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d4bb2cb
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS  ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR   = .
+BUILDDIR    = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/README b/docs/README
new file mode 100644
index 0000000..85560df
--- /dev/null
+++ b/docs/README
@@ -0,0 +1,5 @@
+
+This is a directory containing the instructions to Sphinx (pip install sphinx)
+for generating HTML-based source code documentation for the pyutils library.
+
+I regenerate the docs from a .git/hooks/post-push hook.
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..7b5287f
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,62 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here.
If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('/home/scott/lib/release/pyutils')) +sys.path.insert( + 0, os.path.abspath('/home/scott/py39-venv/lib/python3.9/site-packages/') +) +sys.path.insert(0, os.path.abspath('/usr/local/lib/python3.9/site-packages/')) + +# -- Project information ----------------------------------------------------- + +project = "pyutils" +copyright = '2021-2022, Scott Gasch' +author = 'Scott Gasch' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', +] + +autodoc_typehints = "both" + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..d866481 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,22 @@ +# sphinx-apidoc -o . .. ../*secret* ../type/people* -f + +.. Scott's Python Utils documentation master file, created by + sphinx-quickstart on Tue May 24 19:36:45 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to Scott's Python Utils's documentation! +================================================ + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + modules + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..153be5e --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/modules.rst b/docs/modules.rst new file mode 100644 index 0000000..ee6dcc8 --- /dev/null +++ b/docs/modules.rst @@ -0,0 +1,7 @@ +pyutils +======= + +.. toctree:: + :maxdepth: 4 + + pyutils diff --git a/docs/new_file_added.sh b/docs/new_file_added.sh new file mode 100755 index 0000000..8dc12cf --- /dev/null +++ b/docs/new_file_added.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +/bin/rm -f ./python_modules.* modules.rst +sphinx-apidoc -o . ../src/pyutils diff --git a/pyproject.template b/pyproject.template new file mode 100644 index 0000000..b155476 --- /dev/null +++ b/pyproject.template @@ -0,0 +1,52 @@ +[project] +name = "pyutils" +version = "##VERSION##" +authors = [ + { name="Scott Gasch", email="scott.gasch@gmail.com" }, +] +description = "Python Utilities" +readme = "README.md" +license = { file="LICENSE" } +requires-python = ">=3.7" +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +dependencies = [ + "bitstring", + "cloudpickle", + "holidays", + "numpy", + "overrides", + "python-dateutil", + "pytz", +] + +[project.urls] +"Homepage" = "https://wannabe.guru.org/pydocs/pyutils/pyutils.html" +## "Bug Tracker" = "https://github.com/pypa/sampleproject/issues" + +[project.optional-dependencies] +dev = [ + "black", + "coverage", + "flake8", + "pylint", + "sphinx" +] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = [ + "pyutils", "pyutils.collectionz", "pyutils.compress", + "pyutils.datetimez", "pyutils.files", "pyutils.parallelize", + "pyutils.search", "pyutils.security", "pyutils.typez" +] + +[tool.setuptools.package-dir] +pyutils = "src/pyutils" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1f771f0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,52 @@ +[project] +name = "pyutils" +version = "0.0.1a1" +authors = [ + { name="Scott Gasch", email="scott.gasch@gmail.com" }, +] +description = "Python Utilities" +readme = "README.md" +license = { file="LICENSE" } +requires-python = ">=3.7" +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +dependencies = [ + "bitstring", + "cloudpickle", + "holidays", + "numpy", + "overrides", + "python-dateutil", + "pytz", +] + +[project.urls] +"Homepage" = "https://wannabe.guru.org/pydocs/pyutils/pyutils.html" +## "Bug Tracker" = "https://github.com/pypa/sampleproject/issues" + +[project.optional-dependencies] +dev = [ + "black", + "coverage", + "flake8", + "pylint", + "sphinx" +] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = [ + "pyutils", "pyutils.collectionz", "pyutils.compress", + "pyutils.datetimez", "pyutils.files", "pyutils.parallelize", + "pyutils.search", "pyutils.security", "pyutils.typez" +] + +[tool.setuptools.package-dir] +pyutils = "src/pyutils" diff --git a/release_notes.md b/release_notes.md new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..a974967 --- 
/dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[options] +packages = + pyutils +package_dir = + = src diff --git a/src/pyutils/__init__.py b/src/pyutils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/ansi.py b/src/pyutils/ansi.py new file mode 100755 index 0000000..1d45d3b --- /dev/null +++ b/src/pyutils/ansi.py @@ -0,0 +1,2079 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A bunch of color names mapped into RGB tuples and some methods for +setting the text color, background, etc... using ANSI escape +sequences. +""" + +import contextlib +import difflib +import io +import logging +import re +import sys +from abc import abstractmethod +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple + +from overrides import overrides + +from pyutils import logging_utils, string_utils + +logger = logging.getLogger(__name__) + +# https://en.wikipedia.org/wiki/ANSI_escape_code + + +COLOR_NAMES_TO_RGB: Dict[str, Tuple[int, int, int]] = { + "abbey": (0x4C, 0x4F, 0x56), + "acadia": (0x1B, 0x14, 0x04), + "acapulco": (0x7C, 0xB0, 0xA1), + "aero blue": (0xC9, 0xFF, 0xE5), + "affair": (0x71, 0x46, 0x93), + "akaroa": (0xD4, 0xC4, 0xA8), + "alabaster": (0xFA, 0xFA, 0xFA), + "albescent white": (0xF5, 0xE9, 0xD3), + "algae green": (0x93, 0xDF, 0xB8), + "alice blue": (0xF0, 0xF8, 0xFF), + "alizarin crimson": (0xE3, 0x26, 0x36), + "allports": (0x00, 0x76, 0xA3), + "almond frost": (0x90, 0x7B, 0x71), + "almond": (0xEE, 0xD9, 0xC4), + "alpine": (0xAF, 0x8F, 0x2C), + "alto": (0xDB, 0xDB, 0xDB), + "aluminium": (0xA9, 0xAC, 0xB6), + "amaranth": (0xE5, 0x2B, 0x50), + "amazon": (0x3B, 0x7A, 0x57), + "amber": (0xFF, 0xBF, 0x00), + "americano": (0x87, 0x75, 0x6E), + "amethyst smoke": (0xA3, 0x97, 0xB4), + "amethyst": (0x99, 0x66, 0xCC), + "amour": (0xF9, 0xEA, 0xF3), + "amulet": (0x7B, 0x9F, 0x80), + "anakiwa": (0x9D, 0xE5, 0xFF), + "antique brass": (0xC8, 0x8A, 0x65), + "antique bronze": (0x70, 0x4A, 0x07), + "antique white": (0xFA, 0xEB, 0xD7), + "anzac": (0xE0, 0xB6, 0x46), + "apache": (0xDF, 0xBE, 0x6F), + "apple blossom": (0xAF, 0x4D, 0x43), + "apple green": (0xE2, 0xF3, 0xEC), + "apple": (0x4F, 0xA8, 0x3D), + "apricot peach": (0xFB, 0xCE, 0xB1), + "apricot white": (0xFF, 0xFE, 0xEC), + "apricot": (0xEB, 0x93, 0x73), + "aqua deep": (0x01, 0x4B, 0x43), + "aqua forest": (0x5F, 0xA7, 0x77), + "aqua haze": (0xED, 0xF5, 0xF5), + "aqua island": (0xA1, 0xDA, 0xD7), + "aqua spring": (0xEA, 0xF9, 0xF5), + "aqua squeeze": (0xE8, 0xF5, 0xF2), + "aqua": (0x00, 0xFF, 0xFF), + "aquamarine blue": (0x71, 0xD9, 0xE2), + "aquamarine": (0x7F, 0xFF, 0xD4), + "arapawa": (0x11, 0x0C, 0x6C), + "armadillo": (0x43, 0x3E, 0x37), + "arrowtown": (0x94, 0x87, 0x71), + "ash": (0xC6, 0xC3, 0xB5), + "asparagus": (0x7B, 0xA0, 0x5B), + "asphalt": (0x13, 0x0A, 0x06), + "astra": (0xFA, 0xEA, 0xB9), + "astral": (0x32, 0x7D, 0xA0), + "astronaut blue": (0x01, 0x3E, 0x62), + "astronaut": (0x28, 0x3A, 0x77), + "athens gray": (0xEE, 0xF0, 0xF3), + "aths special": (0xEC, 0xEB, 0xCE), + "atlantis": (0x97, 0xCD, 0x2D), + "atoll": (0x0A, 0x6F, 0x75), + "atomic tangerine": (0xFF, 0x99, 0x66), + "au chico": (0x97, 0x60, 0x5D), + "aubergine": (0x3B, 0x09, 0x10), + "australian mint": (0xF5, 0xFF, 0xBE), + "avocado": (0x88, 0x8D, 0x65), + "axolotl": (0x4E, 0x66, 0x49), + "azalea": (0xF7, 0xC8, 0xDA), + "aztec": (0x0D, 0x1C, 0x19), + "azure radiance": (0x00, 0x7F, 0xFF), + "azure": (0xF0, 0xFF, 0xFF), + "baby blue": (0xE0, 0xFF, 0xFF), + "backup.house": (175, 95, 0), + "bahama blue": (0x02, 0x63, 
0x95), + "bahia": (0xA5, 0xCB, 0x0C), + "baja white": (0xFF, 0xF8, 0xD1), + "bali hai": (0x85, 0x9F, 0xAF), + "baltic sea": (0x2A, 0x26, 0x30), + "bamboo": (0xDA, 0x63, 0x04), + "banana mania": (0xFB, 0xE7, 0xB2), + "bandicoot": (0x85, 0x84, 0x70), + "barberry": (0xDE, 0xD7, 0x17), + "barley corn": (0xA6, 0x8B, 0x5B), + "barley white": (0xFF, 0xF4, 0xCE), + "barossa": (0x44, 0x01, 0x2D), + "bastille": (0x29, 0x21, 0x30), + "battleship gray": (0x82, 0x8F, 0x72), + "bay leaf": (0x7D, 0xA9, 0x8D), + "bay of many": (0x27, 0x3A, 0x81), + "bazaar": (0x98, 0x77, 0x7B), + "bean ": (0x3D, 0x0C, 0x02), + "beauty bush": (0xEE, 0xC1, 0xBE), + "beaver": (0x92, 0x6F, 0x5B), + "beeswax": (0xFE, 0xF2, 0xC7), + "beige": (0xF5, 0xF5, 0xDC), + "bermuda gray": (0x6B, 0x8B, 0xA2), + "bermuda": (0x7D, 0xD8, 0xC6), + "beryl green": (0xDE, 0xE5, 0xC0), + "bianca": (0xFC, 0xFB, 0xF3), + "big stone": (0x16, 0x2A, 0x40), + "bilbao": (0x32, 0x7C, 0x14), + "biloba flower": (0xB2, 0xA1, 0xEA), + "birch": (0x37, 0x30, 0x21), + "bird flower": (0xD4, 0xCD, 0x16), + "biscay": (0x1B, 0x31, 0x62), + "bismark": (0x49, 0x71, 0x83), + "bison hide": (0xC1, 0xB7, 0xA4), + "bisque": (0xFF, 0xE4, 0xC4), + "bistre": (0x3D, 0x2B, 0x1F), + "bitter lemon": (0xCA, 0xE0, 0x0D), + "bitter": (0x86, 0x89, 0x74), + "bittersweet": (0xFE, 0x6F, 0x5E), + "bizarre": (0xEE, 0xDE, 0xDA), + "black bean": (0x08, 0x19, 0x10), + "black forest": (0x0B, 0x13, 0x04), + "black haze": (0xF6, 0xF7, 0xF7), + "black marlin": (0x3E, 0x2C, 0x1C), + "black olive": (0x24, 0x2E, 0x16), + "black pearl": (0x04, 0x13, 0x22), + "black rock": (0x0D, 0x03, 0x32), + "black rose": (0x67, 0x03, 0x2D), + "black russian": (0x0A, 0x00, 0x1C), + "black squeeze": (0xF2, 0xFA, 0xFA), + "black white": (0xFF, 0xFE, 0xF6), + "black": (0x00, 0x00, 0x00), + "blackberry": (0x4D, 0x01, 0x35), + "blackcurrant": (0x32, 0x29, 0x3A), + "blanched almond": (0xFF, 0xEB, 0xCD), + "blaze orange": (0xFF, 0x66, 0x00), + "bleach white": (0xFE, 0xF3, 0xD8), + "bleached cedar": (0x2C, 0x21, 0x33), + "blizzard blue": (0xA3, 0xE3, 0xED), + "blossom": (0xDC, 0xB4, 0xBC), + "blue bayoux": (0x49, 0x66, 0x79), + "blue bell": (0x99, 0x99, 0xCC), + "blue chalk": (0xF1, 0xE9, 0xFF), + "blue charcoal": (0x01, 0x0D, 0x1A), + "blue chill": (0x0C, 0x89, 0x90), + "blue diamond": (0x38, 0x04, 0x74), + "blue dianne": (0x20, 0x48, 0x52), + "blue gem": (0x2C, 0x0E, 0x8C), + "blue haze": (0xBF, 0xBE, 0xD8), + "blue lagoon": (0x01, 0x79, 0x87), + "blue marguerite": (0x76, 0x66, 0xC6), + "blue ribbon": (0x00, 0x66, 0xFF), + "blue romance": (0xD2, 0xF6, 0xDE), + "blue smoke": (0x74, 0x88, 0x81), + "blue stone": (0x01, 0x61, 0x62), + "blue violet": (0x8A, 0x2B, 0xE2), + "blue whale": (0x04, 0x2E, 0x4C), + "blue zodiac": (0x13, 0x26, 0x4D), + "blue": (0x00, 0x00, 0xFF), + "blumine": (0x18, 0x58, 0x7A), + "blush pink": (0xFF, 0x6F, 0xFF), + "blush": (0xB4, 0x46, 0x68), + "bombay": (0xAF, 0xB1, 0xB8), + "bon jour": (0xE5, 0xE0, 0xE1), + "bondi blue": (0x00, 0x95, 0xB6), + "bone": (0xE4, 0xD1, 0xC0), + "bordeaux": (0x5C, 0x01, 0x20), + "bossanova": (0x4E, 0x2A, 0x5A), + "boston blue": (0x3B, 0x91, 0xB4), + "botticelli": (0xC7, 0xDD, 0xE5), + "bottle green": (0x09, 0x36, 0x24), + "boulder": (0x7A, 0x7A, 0x7A), + "bouquet": (0xAE, 0x80, 0x9E), + "bourbon": (0xBA, 0x6F, 0x1E), + "bracken": (0x4A, 0x2A, 0x04), + "brandy punch": (0xCD, 0x84, 0x29), + "brandy rose": (0xBB, 0x89, 0x83), + "brandy": (0xDE, 0xC1, 0x96), + "breaker bay": (0x5D, 0xA1, 0x9F), + "brick red": (0xC6, 0x2D, 0x42), + "bridal heath": (0xFF, 0xFA, 0xF4), + 
"bridesmaid": (0xFE, 0xF0, 0xEC), + "bright gray": (0x3C, 0x41, 0x51), + "bright green": (0x66, 0xFF, 0x00), + "bright red": (0xB1, 0x00, 0x00), + "bright sun": (0xFE, 0xD3, 0x3C), + "bright turquoise": (0x08, 0xE8, 0xDE), + "brilliant rose": (0xF6, 0x53, 0xA6), + "brink pink": (0xFB, 0x60, 0x7F), + "bronco": (0xAB, 0xA1, 0x96), + "bronze olive": (0x4E, 0x42, 0x0C), + "bronze": (0x3F, 0x21, 0x09), + "bronzetone": (0x4D, 0x40, 0x0F), + "broom": (0xFF, 0xEC, 0x13), + "brown bramble": (0x59, 0x28, 0x04), + "brown derby": (0x49, 0x26, 0x15), + "brown pod": (0x40, 0x18, 0x01), + "brown rust": (0xAF, 0x59, 0x3E), + "brown tumbleweed": (0x37, 0x29, 0x0E), + "brown": (0x96, 0x4B, 0x00), + "bubbles": (0xE7, 0xFE, 0xFF), + "buccaneer": (0x62, 0x2F, 0x30), + "bud": (0xA8, 0xAE, 0x9C), + "buddha gold": (0xC1, 0xA0, 0x04), + "buff": (0xF0, 0xDC, 0x82), + "bulgarian rose": (0x48, 0x06, 0x07), + "bull shot": (0x86, 0x4D, 0x1E), + "bunker": (0x0D, 0x11, 0x17), + "bunting": (0x15, 0x1F, 0x4C), + "burgundy": (0x90, 0x00, 0x20), + "burlywood": (0xDE, 0xB8, 0x87), + "burnham": (0x00, 0x2E, 0x20), + "burning orange": (0xFF, 0x70, 0x34), + "burning sand": (0xD9, 0x93, 0x76), + "burnt maroon": (0x42, 0x03, 0x03), + "burnt orange": (0xCC, 0x55, 0x00), + "burnt sienna": (0xE9, 0x74, 0x51), + "burnt umber": (0x8A, 0x33, 0x24), + "bush": (0x0D, 0x2E, 0x1C), + "buttercup": (0xF3, 0xAD, 0x16), + "buttered rum": (0xA1, 0x75, 0x0D), + "butterfly bush": (0x62, 0x4E, 0x9A), + "buttermilk": (0xFF, 0xF1, 0xB5), + "buttery white": (0xFF, 0xFC, 0xEA), + "cab sav": (0x4D, 0x0A, 0x18), + "cabaret": (0xD9, 0x49, 0x72), + "cabbage pont": (0x3F, 0x4C, 0x3A), + "cactus": (0x58, 0x71, 0x56), + "cadet blue": (0x5F, 0x9E, 0xA0), + "cadillac": (0xB0, 0x4C, 0x6A), + "cafe royale": (0x6F, 0x44, 0x0C), + "calico": (0xE0, 0xC0, 0x95), + "california": (0xFE, 0x9D, 0x04), + "calypso": (0x31, 0x72, 0x8D), + "camarone": (0x00, 0x58, 0x1A), + "camelot": (0x89, 0x34, 0x56), + "cameo": (0xD9, 0xB9, 0x9B), + "camouflage green": (0x78, 0x86, 0x6B), + "camouflage": (0x3C, 0x39, 0x10), + "can can": (0xD5, 0x91, 0xA4), + "canary": (0xF3, 0xFB, 0x62), + "candlelight": (0xFC, 0xD9, 0x17), + "candy corn": (0xFB, 0xEC, 0x5D), + "cannon black": (0x25, 0x17, 0x06), + "cannon pink": (0x89, 0x43, 0x67), + "cape cod": (0x3C, 0x44, 0x43), + "cape honey": (0xFE, 0xE5, 0xAC), + "cape palliser": (0xA2, 0x66, 0x45), + "caper": (0xDC, 0xED, 0xB4), + "caramel": (0xFF, 0xDD, 0xAF), + "cararra": (0xEE, 0xEE, 0xE8), + "cardin green": (0x01, 0x36, 0x1C), + "cardinal pink": (0x8C, 0x05, 0x5E), + "cardinal": (0xC4, 0x1E, 0x3A), + "careys pink": (0xD2, 0x9E, 0xAA), + "caribbean green": (0x00, 0xCC, 0x99), + "carissma": (0xEA, 0x88, 0xA8), + "carla": (0xF3, 0xFF, 0xD8), + "carmine": (0x96, 0x00, 0x18), + "carnaby tan": (0x5C, 0x2E, 0x01), + "carnation pink": (0xFF, 0xA6, 0xC9), + "carnation": (0xF9, 0x5A, 0x61), + "carousel pink": (0xF9, 0xE0, 0xED), + "carrot orange": (0xED, 0x91, 0x21), + "casablanca": (0xF8, 0xB8, 0x53), + "casal": (0x2F, 0x61, 0x68), + "cascade": (0x8B, 0xA9, 0xA5), + "cashmere": (0xE6, 0xBE, 0xA5), + "casper": (0xAD, 0xBE, 0xD1), + "castro": (0x52, 0x00, 0x1F), + "catalina blue": (0x06, 0x2A, 0x78), + "catskill white": (0xEE, 0xF6, 0xF7), + "cavern pink": (0xE3, 0xBE, 0xBE), + "cedar wood finish": (0x71, 0x1A, 0x00), + "cedar": (0x3E, 0x1C, 0x14), + "celadon": (0xAC, 0xE1, 0xAF), + "celery": (0xB8, 0xC2, 0x5D), + "celeste": (0xD1, 0xD2, 0xCA), + "cello": (0x1E, 0x38, 0x5B), + "celtic": (0x16, 0x32, 0x22), + "cement": (0x8D, 0x76, 0x62), + "ceramic": 
(0xFC, 0xFF, 0xF9), + "cerise red": (0xDE, 0x31, 0x63), + "cerise": (0xDA, 0x32, 0x87), + "cerulean blue": (0x2A, 0x52, 0xBE), + "cerulean": (0x02, 0xA4, 0xD3), + "chablis": (0xFF, 0xF4, 0xF3), + "chalet green": (0x51, 0x6E, 0x3D), + "chalky": (0xEE, 0xD7, 0x94), + "chambray": (0x35, 0x4E, 0x8C), + "chamois": (0xED, 0xDC, 0xB1), + "champagne": (0xFA, 0xEC, 0xCC), + "chantilly": (0xF8, 0xC3, 0xDF), + "charade": (0x29, 0x29, 0x37), + "chardon": (0xFF, 0xF3, 0xF1), + "chardonnay": (0xFF, 0xCD, 0x8C), + "charlotte": (0xBA, 0xEE, 0xF9), + "charm": (0xD4, 0x74, 0x94), + "chartreuse yellow": (0xDF, 0xFF, 0x00), + "chartreuse": (0x7F, 0xFF, 0x00), + "chateau green": (0x40, 0xA8, 0x60), + "chatelle": (0xBD, 0xB3, 0xC7), + "chathams blue": (0x17, 0x55, 0x79), + "cheetah.house": (95, 0x00, 0x00), + "chelsea cucumber": (0x83, 0xAA, 0x5D), + "chelsea gem": (0x9E, 0x53, 0x02), + "chenin": (0xDF, 0xCD, 0x6F), + "cherokee": (0xFC, 0xDA, 0x98), + "cherry pie": (0x2A, 0x03, 0x59), + "cherrywood": (0x65, 0x1A, 0x14), + "cherub": (0xF8, 0xD9, 0xE9), + "chestnut rose": (0xCD, 0x5C, 0x5C), + "chestnut": (0xB9, 0x4E, 0x48), + "chetwode blue": (0x85, 0x81, 0xD9), + "chicago": (0x5D, 0x5C, 0x58), + "chiffon": (0xF1, 0xFF, 0xC8), + "chilean fire": (0xF7, 0x77, 0x03), + "chilean heath": (0xFF, 0xFD, 0xE6), + "china ivory": (0xFC, 0xFF, 0xE7), + "chino": (0xCE, 0xC7, 0xA7), + "chinook": (0xA8, 0xE3, 0xBD), + "chocolate": (0x37, 0x02, 0x02), + "christalle": (0x33, 0x03, 0x6B), + "christi": (0x67, 0xA7, 0x12), + "christine": (0xE7, 0x73, 0x0A), + "chrome white": (0xE8, 0xF1, 0xD4), + "cinder": (0x0E, 0x0E, 0x18), + "cinderella": (0xFD, 0xE1, 0xDC), + "cinnabar": (0xE3, 0x42, 0x34), + "cinnamon": (0x7B, 0x3F, 0x00), + "cioccolato": (0x55, 0x28, 0x0C), + "citrine white": (0xFA, 0xF7, 0xD6), + "citron": (0x9E, 0xA9, 0x1F), + "citrus": (0xA1, 0xC5, 0x0A), + "clairvoyant": (0x48, 0x06, 0x56), + "clam shell": (0xD4, 0xB6, 0xAF), + "claret": (0x7F, 0x17, 0x34), + "classic rose": (0xFB, 0xCC, 0xE7), + "clay ash": (0xBD, 0xC8, 0xB3), + "clay creek": (0x8A, 0x83, 0x60), + "clear day": (0xE9, 0xFF, 0xFD), + "clementine": (0xE9, 0x6E, 0x00), + "clinker": (0x37, 0x1D, 0x09), + "cloud burst": (0x20, 0x2E, 0x54), + "cloud": (0xC7, 0xC4, 0xBF), + "cloudy": (0xAC, 0xA5, 0x9F), + "clover": (0x38, 0x49, 0x10), + "cobalt": (0x00, 0x47, 0xAB), + "cocoa bean": (0x48, 0x1C, 0x1C), + "cocoa brown": (0x30, 0x1F, 0x1E), + "coconut cream": (0xF8, 0xF7, 0xDC), + "cod gray": (0x0B, 0x0B, 0x0B), + "coffee bean": (0x2A, 0x14, 0x0E), + "coffee": (0x70, 0x65, 0x55), + "cognac": (0x9F, 0x38, 0x1D), + "cola": (0x3F, 0x25, 0x00), + "cold purple": (0xAB, 0xA0, 0xD9), + "cold turkey": (0xCE, 0xBA, 0xBA), + "colonial white": (0xFF, 0xED, 0xBC), + "comet": (0x5C, 0x5D, 0x75), + "como": (0x51, 0x7C, 0x66), + "conch": (0xC9, 0xD9, 0xD2), + "concord": (0x7C, 0x7B, 0x7A), + "concrete": (0xF2, 0xF2, 0xF2), + "confetti": (0xE9, 0xD7, 0x5A), + "congo brown": (0x59, 0x37, 0x37), + "congress blue": (0x02, 0x47, 0x8E), + "conifer": (0xAC, 0xDD, 0x4D), + "contessa": (0xC6, 0x72, 0x6B), + "copper canyon": (0x7E, 0x3A, 0x15), + "copper rose": (0x99, 0x66, 0x66), + "copper rust": (0x94, 0x47, 0x47), + "copper": (0xB8, 0x73, 0x33), + "copperfield": (0xDA, 0x8A, 0x67), + "coral red": (0xFF, 0x40, 0x40), + "coral reef": (0xC7, 0xBC, 0xA2), + "coral tree": (0xA8, 0x6B, 0x6B), + "coral": (0xFF, 0x7F, 0x50), + "corduroy": (0x60, 0x6E, 0x68), + "coriander": (0xC4, 0xD0, 0xB0), + "cork": (0x40, 0x29, 0x1D), + "corn field": (0xF8, 0xFA, 0xCD), + "corn harvest": (0x8B, 0x6B, 0x0B), 
+ "corn silk": (0xFF, 0xF8, 0xDC), + "corn": (0xE7, 0xBF, 0x05), + "cornflower blue": (0x64, 0x95, 0xED), + "cornflower lilac": (0xFF, 0xB0, 0xAC), + "cornflower": (0x93, 0xCC, 0xEA), + "corvette": (0xFA, 0xD3, 0xA2), + "cosmic": (0x76, 0x39, 0x5D), + "cosmos": (0xFF, 0xD8, 0xD9), + "costa del sol": (0x61, 0x5D, 0x30), + "cotton candy": (0xFF, 0xB7, 0xD5), + "cotton seed": (0xC2, 0xBD, 0xB6), + "county green": (0x01, 0x37, 0x1A), + "cowboy": (0x4D, 0x28, 0x2D), + "crail": (0xB9, 0x51, 0x40), + "cranberry": (0xDB, 0x50, 0x79), + "crater brown": (0x46, 0x24, 0x25), + "cream brulee": (0xFF, 0xE5, 0xA0), + "cream can": (0xF5, 0xC8, 0x5C), + "cream": (0xFF, 0xFD, 0xD0), + "creole": (0x1E, 0x0F, 0x04), + "crete": (0x73, 0x78, 0x29), + "crimson": (0xDC, 0x14, 0x3C), + "crocodile": (0x73, 0x6D, 0x58), + "crown of thorns": (0x77, 0x1F, 0x1F), + "crowshead": (0x1C, 0x12, 0x08), + "cruise": (0xB5, 0xEC, 0xDF), + "crusoe": (0x00, 0x48, 0x16), + "crusta": (0xFD, 0x7B, 0x33), + "cumin": (0x92, 0x43, 0x21), + "cumulus": (0xFD, 0xFF, 0xD5), + "cupid": (0xFB, 0xBE, 0xDA), + "curious blue": (0x25, 0x96, 0xD1), + "cutty sark": (0x50, 0x76, 0x72), + "cyan": (0x00, 0xFF, 0xFF), + "cyprus": (0x00, 0x3E, 0x40), + "daintree": (0x01, 0x27, 0x31), + "dairy cream": (0xF9, 0xE4, 0xBC), + "daisy bush": (0x4F, 0x23, 0x98), + "dallas": (0x6E, 0x4B, 0x26), + "dandelion": (0xFE, 0xD8, 0x5D), + "danube": (0x60, 0x93, 0xD1), + "dark blue": (0x00, 0x00, 0x8B), + "dark burgundy": (0x77, 0x0F, 0x05), + "dark cyan": (0x00, 0x8B, 0x8B), + "dark ebony": (0x3C, 0x20, 0x05), + "dark fern": (0x0A, 0x48, 0x0D), + "dark goldenrod": (0xB8, 0x86, 0x0B), + "dark gray": (0xA9, 0xA9, 0xA9), + "dark green": (0x18, 0x2D, 0x09), + "dark magenta": (0xAF, 0x00, 0xAF), + "dark olive green": (0x55, 0x6B, 0x2F), + "dark orange": (0xFF, 0x8C, 0x00), + "dark orchid": (0x99, 0x32, 0xCC), + "dark purple": (0x36, 0x00, 0x79), + "dark red": (0x64, 0x00, 0x00), + "dark salmon": (0xE9, 0x96, 0x7A), + "dark sea green": (0x8F, 0xBC, 0x8F), + "dark slate gray": (0x2F, 0x4F, 0x4F), + "dark tan": (0x66, 0x10, 0x10), + "dark turquoise": (0x00, 0xCE, 0xD1), + "dark violet": (0x94, 0x00, 0xD3), + "dawn pink": (0xF3, 0xE9, 0xE5), + "dawn": (0xA6, 0xA2, 0x9A), + "de york": (0x7A, 0xC4, 0x88), + "deco": (0xD2, 0xDA, 0x97), + "deep blue": (0x22, 0x08, 0x78), + "deep blush": (0xE4, 0x76, 0x98), + "deep bronze": (0x4A, 0x30, 0x04), + "deep cerulean": (0x00, 0x7B, 0xA7), + "deep cove": (0x05, 0x10, 0x40), + "deep fir": (0x00, 0x29, 0x00), + "deep forest green": (0x18, 0x2D, 0x09), + "deep koamaru": (0x1B, 0x12, 0x7B), + "deep oak": (0x41, 0x20, 0x10), + "deep pink": (0xFF, 0x14, 0x93), + "deep sapphire": (0x08, 0x25, 0x67), + "deep sea green": (0x09, 0x58, 0x59), + "deep sea": (0x01, 0x82, 0x6B), + "deep sky blue": (0x00, 0xBF, 0xFF), + "deep teal": (0x00, 0x35, 0x32), + "del rio": (0xB0, 0x9A, 0x95), + "dell": (0x39, 0x64, 0x13), + "delta": (0xA4, 0xA4, 0x9D), + "deluge": (0x75, 0x63, 0xA8), + "denim": (0x15, 0x60, 0xBD), + "derby": (0xFF, 0xEE, 0xD8), + "desert sand": (0xED, 0xC9, 0xAF), + "desert storm": (0xF8, 0xF8, 0xF7), + "desert": (0xAE, 0x60, 0x20), + "dew": (0xEA, 0xFF, 0xFE), + "di serria": (0xDB, 0x99, 0x5E), + "diesel": (0x13, 0x00, 0x00), + "dim gray": (0x69, 0x69, 0x69), + "dingley": (0x5D, 0x77, 0x47), + "disco": (0x87, 0x15, 0x50), + "dixie": (0xE2, 0x94, 0x18), + "dodger blue": (0x1E, 0x90, 0xFF), + "dolly": (0xF9, 0xFF, 0x8B), + "dolphin": (0x64, 0x60, 0x77), + "domino": (0x8E, 0x77, 0x5E), + "don juan": (0x5D, 0x4C, 0x51), + "donkey brown": (0xA6, 
0x92, 0x79), + "dorado": (0x6B, 0x57, 0x55), + "double colonial white": (0xEE, 0xE3, 0xAD), + "double pearl lusta": (0xFC, 0xF4, 0xD0), + "double spanish white": (0xE6, 0xD7, 0xB9), + "dove gray": (0x6D, 0x6C, 0x6C), + "downriver": (0x09, 0x22, 0x56), + "downy": (0x6F, 0xD0, 0xC5), + "driftwood": (0xAF, 0x87, 0x51), + "drover": (0xFD, 0xF7, 0xAD), + "dull lavender": (0xA8, 0x99, 0xE6), + "dune": (0x38, 0x35, 0x33), + "dust storm": (0xE5, 0xCC, 0xC9), + "dusty gray": (0xA8, 0x98, 0x9B), + "eagle": (0xB6, 0xBA, 0xA4), + "earls green": (0xC9, 0xB9, 0x3B), + "early dawn": (0xFF, 0xF9, 0xE6), + "east bay": (0x41, 0x4C, 0x7D), + "east side": (0xAC, 0x91, 0xCE), + "eastern blue": (0x1E, 0x9A, 0xB0), + "ebb": (0xE9, 0xE3, 0xE3), + "ebony clay": (0x26, 0x28, 0x3B), + "ebony": (0x0C, 0x0B, 0x1D), + "eclipse": (0x31, 0x1C, 0x17), + "ecru white": (0xF5, 0xF3, 0xE5), + "ecstasy": (0xFA, 0x78, 0x14), + "eden": (0x10, 0x58, 0x52), + "edgewater": (0xC8, 0xE3, 0xD7), + "edward": (0xA2, 0xAE, 0xAB), + "egg sour": (0xFF, 0xF4, 0xDD), + "egg white": (0xFF, 0xEF, 0xC1), + "eggplant": (0x61, 0x40, 0x51), + "el paso": (0x1E, 0x17, 0x08), + "el salva": (0x8F, 0x3E, 0x33), + "electric lime": (0xCC, 0xFF, 0x00), + "electric violet": (0x8B, 0x00, 0xFF), + "elephant": (0x12, 0x34, 0x47), + "elf green": (0x08, 0x83, 0x70), + "elm": (0x1C, 0x7C, 0x7D), + "emerald": (0x50, 0xC8, 0x78), + "eminence": (0x6C, 0x30, 0x82), + "emperor": (0x51, 0x46, 0x49), + "empress": (0x81, 0x73, 0x77), + "endeavour": (0x00, 0x56, 0xA7), + "energy yellow": (0xF8, 0xDD, 0x5C), + "english holly": (0x02, 0x2D, 0x15), + "english walnut": (0x3E, 0x2B, 0x23), + "envy": (0x8B, 0xA6, 0x90), + "equator": (0xE1, 0xBC, 0x64), + "espresso": (0x61, 0x27, 0x18), + "eternity": (0x21, 0x1A, 0x0E), + "eucalyptus": (0x27, 0x8A, 0x5B), + "eunry": (0xCF, 0xA3, 0x9D), + "evening sea": (0x02, 0x4E, 0x46), + "everglade": (0x1C, 0x40, 0x2E), + "faded jade": (0x42, 0x79, 0x77), + "fair pink": (0xFF, 0xEF, 0xEC), + "falcon": (0x7F, 0x62, 0x6D), + "fall green": (0xEC, 0xEB, 0xBD), + "falu red": (0x80, 0x18, 0x18), + "fantasy": (0xFA, 0xF3, 0xF0), + "fedora": (0x79, 0x6A, 0x78), + "feijoa": (0x9F, 0xDD, 0x8C), + "fern frond": (0x65, 0x72, 0x20), + "fern green": (0x4F, 0x79, 0x42), + "fern": (0x63, 0xB7, 0x6C), + "ferra": (0x70, 0x4F, 0x50), + "festival": (0xFB, 0xE9, 0x6C), + "feta": (0xF0, 0xFC, 0xEA), + "fiery orange": (0xB3, 0x52, 0x13), + "finch": (0x62, 0x66, 0x49), + "finlandia": (0x55, 0x6D, 0x56), + "finn": (0x69, 0x2D, 0x54), + "fiord": (0x40, 0x51, 0x69), + "fire brick": (0xB2, 0x22, 0x22), + "fire bush": (0xE8, 0x99, 0x28), + "fire": (0xAA, 0x42, 0x03), + "firefly": (0x0E, 0x2A, 0x30), + "flame pea": (0xDA, 0x5B, 0x38), + "flamenco": (0xFF, 0x7D, 0x07), + "flamingo": (0xF2, 0x55, 0x2A), + "flax smoke": (0x7B, 0x82, 0x65), + "flax": (0xEE, 0xDC, 0x82), + "flesh": (0xFF, 0xCB, 0xA4), + "flint": (0x6F, 0x6A, 0x61), + "flirt": (0xA2, 0x00, 0x6D), + "floral white": (0xFF, 0xFA, 0xF0), + "flush mahogany": (0xCA, 0x34, 0x35), + "flush orange": (0xFF, 0x7F, 0x00), + "foam": (0xD8, 0xFC, 0xFA), + "fog": (0xD7, 0xD0, 0xFF), + "foggy gray": (0xCB, 0xCA, 0xB6), + "forest green": (0x22, 0x8B, 0x22), + "forget me not": (0xFF, 0xF1, 0xEE), + "fountain blue": (0x56, 0xB4, 0xBE), + "frangipani": (0xFF, 0xDE, 0xB3), + "french gray": (0xBD, 0xBD, 0xC6), + "french lilac": (0xEC, 0xC7, 0xEE), + "french pass": (0xBD, 0xED, 0xFD), + "french rose": (0xF6, 0x4A, 0x8A), + "fresh eggplant": (0x99, 0x00, 0x66), + "friar gray": (0x80, 0x7E, 0x79), + "fringy flower": (0xB1, 0xE2, 
0xC1), + "froly": (0xF5, 0x75, 0x84), + "frost": (0xED, 0xF5, 0xDD), + "frosted mint": (0xDB, 0xFF, 0xF8), + "frostee": (0xE4, 0xF6, 0xE7), + "fruit salad": (0x4F, 0x9D, 0x5D), + "fuchsia blue": (0x7A, 0x58, 0xC1), + "fuchsia pink": (0xC1, 0x54, 0xC1), + "fuchsia": (0xFF, 0x00, 0xFF), + "fuego": (0xBE, 0xDE, 0x0D), + "fuel yellow": (0xEC, 0xA9, 0x27), + "fun blue": (0x19, 0x59, 0xA8), + "fun green": (0x01, 0x6D, 0x39), + "fuscous gray": (0x54, 0x53, 0x4D), + "fuzzy wuzzy brown": (0xC4, 0x56, 0x55), + "gable green": (0x16, 0x35, 0x31), + "gainsboro": (0xDC, 0xDC, 0xDC), + "gallery": (0xEF, 0xEF, 0xEF), + "galliano": (0xDC, 0xB2, 0x0C), + "gamboge": (0xE4, 0x9B, 0x0F), + "geebung": (0xD1, 0x8F, 0x1B), + "genoa": (0x15, 0x73, 0x6B), + "geraldine": (0xFB, 0x89, 0x89), + "geyser": (0xD4, 0xDF, 0xE2), + "ghost white": (0xF8, 0xF8, 0xFF), + "ghost": (0xC7, 0xC9, 0xD5), + "gigas": (0x52, 0x3C, 0x94), + "gimblet": (0xB8, 0xB5, 0x6A), + "gin fizz": (0xFF, 0xF9, 0xE2), + "gin": (0xE8, 0xF2, 0xEB), + "givry": (0xF8, 0xE4, 0xBF), + "glacier": (0x80, 0xB3, 0xC4), + "glade green": (0x61, 0x84, 0x5F), + "go ben": (0x72, 0x6D, 0x4E), + "goblin": (0x3D, 0x7D, 0x52), + "gold drop": (0xF1, 0x82, 0x00), + "gold sand": (0xE6, 0xBE, 0x8A), + "gold tips": (0xDE, 0xBA, 0x13), + "gold": (0xFF, 0xD7, 0x00), + "golden bell": (0xE2, 0x89, 0x13), + "golden dream": (0xF0, 0xD5, 0x2D), + "golden fizz": (0xF5, 0xFB, 0x3D), + "golden glow": (0xFD, 0xE2, 0x95), + "golden grass": (0xDA, 0xA5, 0x20), + "golden sand": (0xF0, 0xDB, 0x7D), + "golden tainoi": (0xFF, 0xCC, 0x5C), + "goldenrod": (0xFC, 0xD6, 0x67), + "gondola": (0x26, 0x14, 0x14), + "gordons green": (0x0B, 0x11, 0x07), + "gorse": (0xFF, 0xF1, 0x4F), + "gossamer": (0x06, 0x9B, 0x81), + "gossip": (0xD2, 0xF8, 0xB0), + "gothic": (0x6D, 0x92, 0xA1), + "governor bay": (0x2F, 0x3C, 0xB3), + "grain brown": (0xE4, 0xD5, 0xB7), + "grandis": (0xFF, 0xD3, 0x8C), + "granite green": (0x8D, 0x89, 0x74), + "granny apple": (0xD5, 0xF6, 0xE3), + "granny smith apple": (0x9D, 0xE0, 0x93), + "granny smith": (0x84, 0xA0, 0xA0), + "grape": (0x38, 0x1A, 0x51), + "graphite": (0x25, 0x16, 0x07), + "gravel": (0x4A, 0x44, 0x4B), + "gray asparagus": (0x46, 0x59, 0x45), + "gray chateau": (0xA2, 0xAA, 0xB3), + "gray nickel": (0xC3, 0xC3, 0xBD), + "gray nurse": (0xE7, 0xEC, 0xE6), + "gray olive": (0xA9, 0xA4, 0x91), + "gray suit": (0xC1, 0xBE, 0xCD), + "gray": (0x80, 0x80, 0x80), + "green haze": (0x01, 0xA3, 0x68), + "green house": (0x24, 0x50, 0x0F), + "green kelp": (0x25, 0x31, 0x1C), + "green leaf": (0x43, 0x6A, 0x0D), + "green mist": (0xCB, 0xD3, 0xB0), + "green pea": (0x1D, 0x61, 0x42), + "green smoke": (0xA4, 0xAF, 0x6E), + "green spring": (0xB8, 0xC1, 0xB1), + "green vogue": (0x03, 0x2B, 0x52), + "green waterloo": (0x10, 0x14, 0x05), + "green white": (0xE8, 0xEB, 0xE0), + "green yellow": (0xAD, 0xFF, 0x2F), + "green": (0x00, 0xFF, 0x00), + "grenadier": (0xD5, 0x46, 0x00), + "guardsman red": (0xBA, 0x01, 0x01), + "gulf blue": (0x05, 0x16, 0x57), + "gulf stream": (0x80, 0xB3, 0xAE), + "gull gray": (0x9D, 0xAC, 0xB7), + "gum leaf": (0xB6, 0xD3, 0xBF), + "gumbo": (0x7C, 0xA1, 0xA6), + "gun powder": (0x41, 0x42, 0x57), + "gunsmoke": (0x82, 0x86, 0x85), + "gurkha": (0x9A, 0x95, 0x77), + "hacienda": (0x98, 0x81, 0x1B), + "hairy heath": (0x6B, 0x2A, 0x14), + "haiti": (0x1B, 0x10, 0x35), + "half and half": (0xFF, 0xFE, 0xE1), + "half baked": (0x85, 0xC4, 0xCC), + "half colonial white": (0xFD, 0xF6, 0xD3), + "half dutch white": (0xFE, 0xF7, 0xDE), + "half spanish white": (0xFE, 0xF4, 0xDB), + 
"hampton": (0xE5, 0xD8, 0xAF), + "harlequin": (0x3F, 0xFF, 0x00), + "harp": (0xE6, 0xF2, 0xEA), + "harvest gold": (0xE0, 0xB9, 0x74), + "havelock blue": (0x55, 0x90, 0xD9), + "hawaiian tan": (0x9D, 0x56, 0x16), + "hawkes blue": (0xD4, 0xE2, 0xFC), + "heath": (0x54, 0x10, 0x12), + "heather": (0xB7, 0xC3, 0xD0), + "heathered gray": (0xB6, 0xB0, 0x95), + "heavy metal": (0x2B, 0x32, 0x28), + "heliotrope": (0xDF, 0x73, 0xFF), + "hemlock": (0x5E, 0x5D, 0x3B), + "hemp": (0x90, 0x78, 0x74), + "hibiscus": (0xB6, 0x31, 0x6C), + "highland": (0x6F, 0x8E, 0x63), + "hillary": (0xAC, 0xA5, 0x86), + "himalaya": (0x6A, 0x5D, 0x1B), + "hint of green": (0xE6, 0xFF, 0xE9), + "hint of red": (0xFB, 0xF9, 0xF9), + "hint of yellow": (0xFA, 0xFD, 0xE4), + "hippie blue": (0x58, 0x9A, 0xAF), + "hippie green": (0x53, 0x82, 0x4B), + "hippie pink": (0xAE, 0x45, 0x60), + "hit gray": (0xA1, 0xAD, 0xB5), + "hit pink": (0xFF, 0xAB, 0x81), + "hokey pokey": (0xC8, 0xA5, 0x28), + "hoki": (0x65, 0x86, 0x9F), + "holly": (0x01, 0x1D, 0x13), + "hollywood cerise": (0xF4, 0x00, 0xA1), + "honey flower": (0x4F, 0x1C, 0x70), + "honeydew": (0xF0, 0xFF, 0xF0), + "honeysuckle": (0xED, 0xFC, 0x84), + "hopbush": (0xD0, 0x6D, 0xA1), + "horizon": (0x5A, 0x87, 0xA0), + "horses neck": (0x60, 0x49, 0x13), + "hot cinnamon": (0xD2, 0x69, 0x1E), + "hot pink": (0xFF, 0x69, 0xB4), + "hot toddy": (0xB3, 0x80, 0x07), + "humming bird": (0xCF, 0xF9, 0xF3), + "hunter green": (0x16, 0x1D, 0x10), + "hurricane": (0x87, 0x7C, 0x7B), + "husk": (0xB7, 0xA4, 0x58), + "ice cold": (0xB1, 0xF4, 0xE7), + "iceberg": (0xDA, 0xF4, 0xF0), + "illusion": (0xF6, 0xA4, 0xC9), + "inch worm": (0xB0, 0xE3, 0x13), + "indian khaki": (0xC3, 0xB0, 0x91), + "indian red": (0xCD, 0x5C, 0x5C), + "indian tan": (0x4D, 0x1E, 0x01), + "indigo": (0x4F, 0x69, 0xC6), + "indochine": (0xC2, 0x6B, 0x03), + "international orange": (0xFF, 0x4F, 0x00), + "irish coffee": (0x5F, 0x3D, 0x26), + "iroko": (0x43, 0x31, 0x20), + "iron": (0xD4, 0xD7, 0xD9), + "ironside gray": (0x67, 0x66, 0x62), + "ironstone": (0x86, 0x48, 0x3C), + "island spice": (0xFF, 0xFC, 0xEE), + "ivory": (0xFF, 0xFF, 0xF0), + "jacaranda": (0x2E, 0x03, 0x29), + "jacarta": (0x3A, 0x2A, 0x6A), + "jacko bean": (0x2E, 0x19, 0x05), + "jacksons purple": (0x20, 0x20, 0x8D), + "jade": (0x00, 0xA8, 0x6B), + "jaffa": (0xEF, 0x86, 0x3F), + "jagged ice": (0xC2, 0xE8, 0xE5), + "jagger": (0x35, 0x0E, 0x57), + "jaguar": (0x08, 0x01, 0x10), + "jambalaya": (0x5B, 0x30, 0x13), + "janna": (0xF4, 0xEB, 0xD3), + "japanese laurel": (0x0A, 0x69, 0x06), + "japanese maple": (0x78, 0x01, 0x09), + "japonica": (0xD8, 0x7C, 0x63), + "java": (0x1F, 0xC2, 0xC2), + "jazzberry jam": (0xA5, 0x0B, 0x5E), + "jelly bean": (0x29, 0x7B, 0x9A), + "jet stream": (0xB5, 0xD2, 0xCE), + "jewel": (0x12, 0x6B, 0x40), + "jon": (0x3B, 0x1F, 0x1F), + "jonquil": (0xEE, 0xFF, 0x9A), + "jordy blue": (0x8A, 0xB9, 0xF1), + "judge gray": (0x54, 0x43, 0x33), + "jumbo": (0x7C, 0x7B, 0x82), + "jungle green": (0x29, 0xAB, 0x87), + "jungle mist": (0xB4, 0xCF, 0xD3), + "juniper": (0x6D, 0x92, 0x92), + "just right": (0xEC, 0xCD, 0xB9), + "kabul": (0x5E, 0x48, 0x3E), + "kaitoke green": (0x00, 0x46, 0x20), + "kangaroo": (0xC6, 0xC8, 0xBD), + "karaka": (0x1E, 0x16, 0x09), + "karry": (0xFF, 0xEA, 0xD4), + "kashmir blue": (0x50, 0x70, 0x96), + "kelp": (0x45, 0x49, 0x36), + "kenyan copper": (0x7C, 0x1C, 0x05), + "keppel": (0x3A, 0xB0, 0x9E), + "key lime pie": (0xBF, 0xC9, 0x21), + "khaki": (0xF0, 0xE6, 0x8C), + "kidnapper": (0xE1, 0xEA, 0xD4), + "kilamanjaro": (0x24, 0x0C, 0x02), + "killarney": 
(0x3A, 0x6A, 0x47), + "kimberly": (0x73, 0x6C, 0x9F), + "kingfisher daisy": (0x3E, 0x04, 0x80), + "kiosk.house": (90, 95, 0), + "klein blue": (0x00, 0x2F, 0xA7), + "kobi": (0xE7, 0x9F, 0xC4), + "kokoda": (0x6E, 0x6D, 0x57), + "korma": (0x8F, 0x4B, 0x0E), + "koromiko": (0xFF, 0xBD, 0x5F), + "kournikova": (0xFF, 0xE7, 0x72), + "kumera": (0x88, 0x62, 0x21), + "la palma": (0x36, 0x87, 0x16), + "la rioja": (0xB3, 0xC1, 0x10), + "las palmas": (0xC6, 0xE6, 0x10), + "laser lemon": (0xFF, 0xFF, 0x66), + "laser": (0xC8, 0xB5, 0x68), + "laurel": (0x74, 0x93, 0x78), + "lavender blush": (0xFF, 0xF0, 0xF5), + "lavender gray": (0xBD, 0xBB, 0xD7), + "lavender magenta": (0xEE, 0x82, 0xEE), + "lavender pink": (0xFB, 0xAE, 0xD2), + "lavender purple": (0x96, 0x7B, 0xB6), + "lavender rose": (0xFB, 0xA0, 0xE3), + "lavender": (0xB5, 0x7E, 0xDC), + "lawn green": (0x7C, 0xFC, 0x00), + "leather": (0x96, 0x70, 0x59), + "lemon chiffon": (0xFF, 0xFA, 0xCD), + "lemon ginger": (0xAC, 0x9E, 0x22), + "lemon grass": (0x9B, 0x9E, 0x8F), + "lemon": (0xFD, 0xE9, 0x10), + "light apricot": (0xFD, 0xD5, 0xB1), + "light blue": (0xAD, 0xD8, 0xE6), + "light coral": (0xF0, 0x80, 0x80), + "light cyan": (0xE0, 0xFF, 0xFF), + "light goldenrod": (0xFA, 0xFA, 0xD2), + "light gray": (0x26, 0x23, 0x35), + "light green": (0x90, 0xEE, 0x90), + "light orchid": (0xE2, 0x9C, 0xD2), + "light pink": (0xDD, 0xB6, 0xC1), + "light salmon": (0xDD, 0xA0, 0x7A), + "light sea green": (0x20, 0xB2, 0xAA), + "light slate gray": (0x77, 0x88, 0x99), + "light steel blue": (0xB0, 0xC4, 0xDE), + "light wisteria": (0xC9, 0xA0, 0xDC), + "light yellow": (0xFF, 0xFF, 0xE0), + "lightning yellow": (0xFC, 0xC0, 0x1E), + "lilac bush": (0x98, 0x74, 0xD3), + "lilac": (0xC8, 0xA2, 0xC8), + "lily white": (0xE7, 0xF8, 0xFF), + "lily": (0xC8, 0xAA, 0xBF), + "lima": (0x76, 0xBD, 0x17), + "lime": (0xBF, 0xFF, 0x00), + "limeade": (0x6F, 0x9D, 0x02), + "limed ash": (0x74, 0x7D, 0x63), + "limed oak": (0xAC, 0x8A, 0x56), + "limed spruce": (0x39, 0x48, 0x51), + "linen": (0xFA, 0xF0, 0xE6), + "link water": (0xD9, 0xE4, 0xF5), + "lipstick": (0xAB, 0x05, 0x63), + "lisbon brown": (0x42, 0x39, 0x21), + "livid brown": (0x4D, 0x28, 0x2E), + "loafer": (0xEE, 0xF4, 0xDE), + "loblolly": (0xBD, 0xC9, 0xCE), + "lochinvar": (0x2C, 0x8C, 0x84), + "lochmara": (0x00, 0x7E, 0xC7), + "locust": (0xA8, 0xAF, 0x8E), + "log cabin": (0x24, 0x2A, 0x1D), + "logan": (0xAA, 0xA9, 0xCD), + "lola": (0xDF, 0xCF, 0xDB), + "london hue": (0xBE, 0xA6, 0xC3), + "lonestar": (0x6D, 0x01, 0x01), + "lotus": (0x86, 0x3C, 0x3C), + "loulou": (0x46, 0x0B, 0x41), + "lucky point": (0x1A, 0x1A, 0x68), + "lucky": (0xAF, 0x9F, 0x1C), + "lunar green": (0x3C, 0x49, 0x3A), + "luxor gold": (0xA7, 0x88, 0x2C), + "lynch": (0x69, 0x7E, 0x9A), + "mabel": (0xD9, 0xF7, 0xFF), + "macaroni and cheese": (0xFF, 0xB9, 0x7B), + "madang": (0xB7, 0xF0, 0xBE), + "madison": (0x09, 0x25, 0x5D), + "madras": (0x3F, 0x30, 0x02), + "magenta": (0xFF, 0x00, 0xFF), + "magic mint": (0xAA, 0xF0, 0xD1), + "magnolia": (0xF8, 0xF4, 0xFF), + "mahogany": (0x4E, 0x06, 0x06), + "mai tai": (0xB0, 0x66, 0x08), + "maize": (0xF5, 0xD5, 0xA0), + "makara": (0x89, 0x7D, 0x6D), + "mako": (0x44, 0x49, 0x54), + "malachite": (0x0B, 0xDA, 0x51), + "malibu": (0x7D, 0xC8, 0xF7), + "mallard": (0x23, 0x34, 0x18), + "malta": (0xBD, 0xB2, 0xA1), + "mamba": (0x8E, 0x81, 0x90), + "manatee": (0x8D, 0x90, 0xA1), + "mandalay": (0xAD, 0x78, 0x1B), + "mandy": (0xE2, 0x54, 0x65), + "mandys pink": (0xF2, 0xC3, 0xB2), + "mango tango": (0xE7, 0x72, 0x00), + "manhattan": (0xF5, 0xC9, 0x99), + 
"mantis": (0x74, 0xC3, 0x65), + "mantle": (0x8B, 0x9C, 0x90), + "manz": (0xEE, 0xEF, 0x78), + "mardi gras": (0x35, 0x00, 0x36), + "marigold yellow": (0xFB, 0xE8, 0x70), + "marigold": (0xB9, 0x8D, 0x28), + "mariner": (0x28, 0x6A, 0xCD), + "maroon flush": (0xC3, 0x21, 0x48), + "maroon oak": (0x52, 0x0C, 0x17), + "maroon": (0x80, 0x00, 0x00), + "marshland": (0x0B, 0x0F, 0x08), + "martini": (0xAF, 0xA0, 0x9E), + "martinique": (0x36, 0x30, 0x50), + "marzipan": (0xF8, 0xDB, 0x9D), + "masala": (0x40, 0x3B, 0x38), + "matisse": (0x1B, 0x65, 0x9D), + "matrix": (0xB0, 0x5D, 0x54), + "matterhorn": (0x4E, 0x3B, 0x41), + "mauve": (0xE0, 0xB0, 0xFF), + "mauvelous": (0xF0, 0x91, 0xA9), + "maverick": (0xD8, 0xC2, 0xD5), + "medium aquamarine": (0x66, 0xCD, 0xAA), + "medium blue": (0x00, 0x00, 0xCD), + "medium carmine": (0xAF, 0x40, 0x35), + "medium orchid": (0xBA, 0x55, 0xD3), + "medium purple": (0x93, 0x70, 0xDB), + "medium red violet": (0xBB, 0x33, 0x85), + "medium sea green": (0x3C, 0xB3, 0x71), + "medium slate blue": (0x7B, 0x68, 0xEE), + "medium spring green": (0x00, 0xFA, 0x9A), + "medium turquoise": (0x48, 0xD1, 0xCC), + "medium violet red": (0xC7, 0x15, 0x85), + "meerkat.cabin": (95, 0x00, 95), + "melanie": (0xE4, 0xC2, 0xD5), + "melanzane": (0x30, 0x05, 0x29), + "melon": (0xFE, 0xBA, 0xAD), + "melrose": (0xC7, 0xC1, 0xFF), + "mercury": (0xE5, 0xE5, 0xE5), + "merino": (0xF6, 0xF0, 0xE6), + "merlin": (0x41, 0x3C, 0x37), + "merlot": (0x83, 0x19, 0x23), + "metallic bronze": (0x49, 0x37, 0x1B), + "metallic copper": (0x71, 0x29, 0x1D), + "meteor": (0xD0, 0x7D, 0x12), + "meteorite": (0x3C, 0x1F, 0x76), + "mexican red": (0xA7, 0x25, 0x25), + "mid gray": (0x5F, 0x5F, 0x6E), + "midnight blue": (0x00, 0x33, 0x66), + "midnight moss": (0x04, 0x10, 0x04), + "midnight": (0x01, 0x16, 0x35), + "mikado": (0x2D, 0x25, 0x10), + "milan": (0xFA, 0xFF, 0xA4), + "milano red": (0xB8, 0x11, 0x04), + "milk punch": (0xFF, 0xF6, 0xD4), + "millbrook": (0x59, 0x44, 0x33), + "mimosa": (0xF8, 0xFD, 0xD3), + "mindaro": (0xE3, 0xF9, 0x88), + "mine shaft": (0x32, 0x32, 0x32), + "mineral green": (0x3F, 0x5D, 0x53), + "ming": (0x36, 0x74, 0x7D), + "minsk": (0x3F, 0x30, 0x7F), + "mint cream": (0xF5, 0xFF, 0xF1), + "mint green": (0x98, 0xFF, 0x98), + "mint julep": (0xF1, 0xEE, 0xC1), + "mint tulip": (0xC4, 0xF4, 0xEB), + "mirage": (0x16, 0x19, 0x28), + "mischka": (0xD1, 0xD2, 0xDD), + "mist gray": (0xC4, 0xC4, 0xBC), + "misty rose": (0xFF, 0xE4, 0xE1), + "mobster": (0x7F, 0x75, 0x89), + "moccaccino": (0x6E, 0x1D, 0x14), + "moccasin": (0xFF, 0xE4, 0xB5), + "mocha": (0x78, 0x2D, 0x19), + "mojo": (0xC0, 0x47, 0x37), + "mona lisa": (0xFF, 0xA1, 0x94), + "monarch": (0x8B, 0x07, 0x23), + "mondo": (0x4A, 0x3C, 0x30), + "mongoose": (0xB5, 0xA2, 0x7F), + "monsoon": (0x8A, 0x83, 0x89), + "monte carlo": (0x83, 0xD0, 0xC6), + "monza": (0xC7, 0x03, 0x1E), + "moody blue": (0x7F, 0x76, 0xD3), + "moon glow": (0xFC, 0xFE, 0xDA), + "moon mist": (0xDC, 0xDD, 0xCC), + "moon raker": (0xD6, 0xCE, 0xF6), + "morning glory": (0x9E, 0xDE, 0xE0), + "morocco brown": (0x44, 0x1D, 0x00), + "mortar": (0x50, 0x43, 0x51), + "mosque": (0x03, 0x6A, 0x6E), + "moss green": (0xAD, 0xDF, 0xAD), + "mountain meadow": (0x1A, 0xB3, 0x85), + "mountain mist": (0x95, 0x93, 0x96), + "mountbatten pink": (0x99, 0x7A, 0x8D), + "muddy waters": (0xB7, 0x8E, 0x5C), + "muesli": (0xAA, 0x8B, 0x5B), + "mulberry wood": (0x5C, 0x05, 0x36), + "mulberry": (0xC5, 0x4B, 0x8C), + "mule fawn": (0x8C, 0x47, 0x2F), + "mulled wine": (0x4E, 0x45, 0x62), + "mustard": (0xFF, 0xDB, 0x58), + "my pink": 
(0xD6, 0x91, 0x88), + "my sin": (0xFF, 0xB3, 0x1F), + "mystic": (0xE2, 0xEB, 0xED), + "nandor": (0x4B, 0x5D, 0x52), + "napa": (0xAC, 0xA4, 0x94), + "narvik": (0xED, 0xF9, 0xF1), + "natural gray": (0x8B, 0x86, 0x80), + "navajo white": (0xFF, 0xDE, 0xAD), + "navy blue": (0x00, 0x00, 0x80), + "navy": (0x00, 0x00, 0x80), + "nebula": (0xCB, 0xDB, 0xD6), + "negroni": (0xFF, 0xE2, 0xC5), + "neon carrot": (0xFF, 0x99, 0x33), + "nepal": (0x8E, 0xAB, 0xC1), + "neptune": (0x7C, 0xB7, 0xBB), + "nero": (0x14, 0x06, 0x00), + "nevada": (0x64, 0x6E, 0x75), + "new orleans": (0xF3, 0xD6, 0x9D), + "new york pink": (0xD7, 0x83, 0x7F), + "niagara": (0x06, 0xA1, 0x89), + "night rider": (0x1F, 0x12, 0x0F), + "night shadz": (0xAA, 0x37, 0x5A), + "nile blue": (0x19, 0x37, 0x51), + "nobel": (0xB7, 0xB1, 0xB1), + "nomad": (0xBA, 0xB1, 0xA2), + "norway": (0xA8, 0xBD, 0x9F), + "nugget": (0xC5, 0x99, 0x22), + "nutmeg wood finish": (0x68, 0x36, 0x00), + "nutmeg": (0x81, 0x42, 0x2C), + "oasis": (0xFE, 0xEF, 0xCE), + "observatory": (0x02, 0x86, 0x6F), + "ocean green": (0x41, 0xAA, 0x78), + "ochre": (0xCC, 0x77, 0x22), + "off green": (0xE6, 0xF8, 0xF3), + "off yellow": (0xFE, 0xF9, 0xE3), + "oil": (0x28, 0x1E, 0x15), + "old brick": (0x90, 0x1E, 0x1E), + "old copper": (0x72, 0x4A, 0x2F), + "old gold": (0xCF, 0xB5, 0x3B), + "old lace": (0xFD, 0xF5, 0xE6), + "old lavender": (0x79, 0x68, 0x78), + "old rose": (0xC0, 0x80, 0x81), + "olive drab": (0x6B, 0x8E, 0x23), + "olive green": (0xB5, 0xB3, 0x5C), + "olive haze": (0x8B, 0x84, 0x70), + "olive": (0x80, 0x80, 0x00), + "olivetone": (0x71, 0x6E, 0x10), + "olivine": (0x9A, 0xB9, 0x73), + "onahau": (0xCD, 0xF4, 0xFF), + "onion": (0x2F, 0x27, 0x0E), + "opal": (0xA9, 0xC6, 0xC2), + "opium": (0x8E, 0x6F, 0x70), + "oracle": (0x37, 0x74, 0x75), + "orange peel": (0xFF, 0xA0, 0x00), + "orange red": (0xFF, 0x45, 0x00), + "orange roughy": (0xC4, 0x57, 0x19), + "orange white": (0xFE, 0xFC, 0xED), + "orange": (0xFF, 0x68, 0x1F), + "orchid white": (0xFF, 0xFD, 0xF3), + "orchid": (0xDA, 0x70, 0xD6), + "oregon": (0x9B, 0x47, 0x03), + "orient": (0x01, 0x5E, 0x85), + "oriental pink": (0xC6, 0x91, 0x91), + "orinoco": (0xF3, 0xFB, 0xD4), + "oslo gray": (0x87, 0x8D, 0x91), + "ottoman": (0xE9, 0xF8, 0xED), + "outer space": (0x2D, 0x38, 0x3A), + "outrageous orange": (0xFF, 0x60, 0x37), + "oxford blue": (0x38, 0x45, 0x55), + "oxley": (0x77, 0x9E, 0x86), + "oyster bay": (0xDA, 0xFA, 0xFF), + "oyster pink": (0xE9, 0xCE, 0xCD), + "paarl": (0xA6, 0x55, 0x29), + "pablo": (0x77, 0x6F, 0x61), + "pacific blue": (0x00, 0x9D, 0xC4), + "pacifika": (0x77, 0x81, 0x20), + "paco": (0x41, 0x1F, 0x10), + "padua": (0xAD, 0xE6, 0xC4), + "pale canary": (0xFF, 0xFF, 0x99), + "pale goldenrod": (0xEE, 0xE8, 0xAA), + "pale green": (0x98, 0xFB, 0x98), + "pale leaf": (0xC0, 0xD3, 0xB9), + "pale oyster": (0x98, 0x8D, 0x77), + "pale prim": (0xFD, 0xFE, 0xB8), + "pale rose": (0xFF, 0xE1, 0xF2), + "pale sky": (0x6E, 0x77, 0x83), + "pale slate": (0xC3, 0xBF, 0xC1), + "pale turquoise": (0xAF, 0xEE, 0xEE), + "pale violet red": (0xDB, 0x70, 0x93), + "palm green": (0x09, 0x23, 0x0F), + "palm leaf": (0x19, 0x33, 0x0E), + "pampas": (0xF4, 0xF2, 0xEE), + "panache": (0xEA, 0xF6, 0xEE), + "pancho": (0xED, 0xCD, 0xAB), + "papaya whip": (0xFF, 0xEF, 0xD5), + "paprika": (0x8D, 0x02, 0x26), + "paradiso": (0x31, 0x7D, 0x82), + "parchment": (0xF1, 0xE9, 0xD2), + "paris daisy": (0xFF, 0xF4, 0x6E), + "paris m": (0x26, 0x05, 0x6A), + "paris white": (0xCA, 0xDC, 0xD4), + "parsley": (0x13, 0x4F, 0x19), + "pastel green": (0x77, 0xDD, 0x77), + "pastel 
pink": (0xFF, 0xD1, 0xDC), + "patina": (0x63, 0x9A, 0x8F), + "pattens blue": (0xDE, 0xF5, 0xFF), + "paua": (0x26, 0x03, 0x68), + "pavlova": (0xD7, 0xC4, 0x98), + "peach cream": (0xFF, 0xF0, 0xDB), + "peach orange": (0xFF, 0xCC, 0x99), + "peach puff": (0xFF, 0xDA, 0xB9), + "peach schnapps": (0xFF, 0xDC, 0xD6), + "peach yellow": (0xFA, 0xDF, 0xAD), + "peach": (0xFF, 0xE5, 0xB4), + "peanut": (0x78, 0x2F, 0x16), + "pear": (0xD1, 0xE2, 0x31), + "pearl bush": (0xE8, 0xE0, 0xD5), + "pearl lusta": (0xFC, 0xF4, 0xDC), + "peat": (0x71, 0x6B, 0x56), + "pelorous": (0x3E, 0xAB, 0xBF), + "peppermint": (0xE3, 0xF5, 0xE1), + "perano": (0xA9, 0xBE, 0xF2), + "perfume": (0xD0, 0xBE, 0xF8), + "periglacial blue": (0xE1, 0xE6, 0xD6), + "periwinkle gray": (0xC3, 0xCD, 0xE6), + "periwinkle": (0xCC, 0xCC, 0xFF), + "persian blue": (0x1C, 0x39, 0xBB), + "persian green": (0x00, 0xA6, 0x93), + "persian indigo": (0x32, 0x12, 0x7A), + "persian pink": (0xF7, 0x7F, 0xBE), + "persian plum": (0x70, 0x1C, 0x1C), + "persian red": (0xCC, 0x33, 0x33), + "persian rose": (0xFE, 0x28, 0xA2), + "persimmon": (0xFF, 0x6B, 0x53), + "peru tan": (0x7F, 0x3A, 0x02), + "peru": (0xCD, 0x85, 0x3F), + "pesto": (0x7C, 0x76, 0x31), + "petite orchid": (0xDB, 0x96, 0x90), + "pewter": (0x96, 0xA8, 0xA1), + "pharlap": (0xA3, 0x80, 0x7B), + "picasso": (0xFF, 0xF3, 0x9D), + "pickled bean": (0x6E, 0x48, 0x26), + "pickled bluewood": (0x31, 0x44, 0x59), + "picton blue": (0x45, 0xB1, 0xE8), + "pig pink": (0xFD, 0xD7, 0xE4), + "pigeon post": (0xAF, 0xBD, 0xD9), + "pigment indigo": (0x4B, 0x00, 0x82), + "pine cone": (0x6D, 0x5E, 0x54), + "pine glade": (0xC7, 0xCD, 0x90), + "pine green": (0x01, 0x79, 0x6F), + "pine tree": (0x17, 0x1F, 0x04), + "pink flamingo": (0xFF, 0x66, 0xFF), + "pink flare": (0xE1, 0xC0, 0xC8), + "pink lace": (0xFF, 0xDD, 0xF4), + "pink lady": (0xFF, 0xF1, 0xD8), + "pink salmon": (0xFF, 0x91, 0xA4), + "pink swan": (0xBE, 0xB5, 0xB7), + "pink": (0xFF, 0xC0, 0xCB), + "piper": (0xC9, 0x63, 0x23), + "pipi": (0xFE, 0xF4, 0xCC), + "pippin": (0xFF, 0xE1, 0xDF), + "pirate gold": (0xBA, 0x7F, 0x03), + "pistachio": (0x9D, 0xC2, 0x09), + "pixie green": (0xC0, 0xD8, 0xB6), + "pizazz": (0xFF, 0x90, 0x00), + "pizza": (0xC9, 0x94, 0x15), + "plantation": (0x27, 0x50, 0x4B), + "plum": (0x84, 0x31, 0x79), + "pohutukawa": (0x8F, 0x02, 0x1C), + "polar": (0xE5, 0xF9, 0xF6), + "polo blue": (0x8D, 0xA8, 0xCC), + "pomegranate": (0xF3, 0x47, 0x23), + "pompadour": (0x66, 0x00, 0x45), + "porcelain": (0xEF, 0xF2, 0xF3), + "porsche": (0xEA, 0xAE, 0x69), + "port gore": (0x25, 0x1F, 0x4F), + "portafino": (0xFF, 0xFF, 0xB4), + "portage": (0x8B, 0x9F, 0xEE), + "portica": (0xF9, 0xE6, 0x63), + "pot pourri": (0xF5, 0xE7, 0xE2), + "potters clay": (0x8C, 0x57, 0x38), + "powder ash": (0xBC, 0xC9, 0xC2), + "powder blue": (0xB0, 0xE0, 0xE6), + "prairie sand": (0x9A, 0x38, 0x20), + "prelude": (0xD0, 0xC0, 0xE5), + "prim": (0xF0, 0xE2, 0xEC), + "primrose": (0xED, 0xEA, 0x99), + "provincial pink": (0xFE, 0xF5, 0xF1), + "prussian blue": (0x00, 0x31, 0x53), + "puce": (0xCC, 0x88, 0x99), + "pueblo": (0x7D, 0x2C, 0x14), + "puerto rico": (0x3F, 0xC1, 0xAA), + "pumice": (0xC2, 0xCA, 0xC4), + "pumpkin skin": (0xB1, 0x61, 0x0B), + "pumpkin": (0xFF, 0x75, 0x18), + "punch": (0xDC, 0x43, 0x33), + "punga": (0x4D, 0x3D, 0x14), + "purple heart": (0x65, 0x2D, 0xC1), + "purple mountain's majesty": (0x96, 0x78, 0xB6), + "purple pizzazz": (0xFF, 0x00, 0xCC), + "purple": (0x66, 0x00, 0x99), + "putty": (0xE7, 0xCD, 0x8C), + "quarter pearl lusta": (0xFF, 0xFD, 0xF4), + "quarter spanish white": 
(0xF7, 0xF2, 0xE1), + "quicksand": (0xBD, 0x97, 0x8E), + "quill gray": (0xD6, 0xD6, 0xD1), + "quincy": (0x62, 0x3F, 0x2D), + "racing green": (0x0C, 0x19, 0x11), + "radical red": (0xFF, 0x35, 0x5E), + "raffia": (0xEA, 0xDA, 0xB8), + "rainee": (0xB9, 0xC8, 0xAC), + "rajah": (0xF7, 0xB6, 0x68), + "rangitoto": (0x2E, 0x32, 0x22), + "rangoon green": (0x1C, 0x1E, 0x13), + "raven": (0x72, 0x7B, 0x89), + "raw sienna": (0xD2, 0x7D, 0x46), + "raw umber": (0x73, 0x4A, 0x12), + "razzle dazzle rose": (0xFF, 0x33, 0xCC), + "razzmatazz": (0xE3, 0x0B, 0x5C), + "rebecca purple": (0x66, 0x33, 0x99), + "rebel": (0x3C, 0x12, 0x06), + "red beech": (0x7B, 0x38, 0x01), + "red berry": (0x8E, 0x00, 0x00), + "red damask": (0xDA, 0x6A, 0x41), + "red devil": (0x86, 0x01, 0x11), + "red orange": (0xFF, 0x3F, 0x34), + "red oxide": (0x6E, 0x09, 0x02), + "red ribbon": (0xED, 0x0A, 0x3F), + "red robin": (0x80, 0x34, 0x1F), + "red stage": (0xD0, 0x5F, 0x04), + "red violet": (0xC7, 0x15, 0x85), + "red": (0xFF, 0x00, 0x00), + "redwood": (0x5D, 0x1E, 0x0F), + "reef gold": (0x9F, 0x82, 0x1C), + "reef": (0xC9, 0xFF, 0xA2), + "regal blue": (0x01, 0x3F, 0x6A), + "regent gray": (0x86, 0x94, 0x9F), + "regent st blue": (0xAA, 0xD6, 0xE6), + "remy": (0xFE, 0xEB, 0xF3), + "reno sand": (0xA8, 0x65, 0x15), + "resolution blue": (0x00, 0x23, 0x87), + "revolver": (0x2C, 0x16, 0x32), + "rhino": (0x2E, 0x3F, 0x62), + "rice cake": (0xFF, 0xFE, 0xF0), + "rice flower": (0xEE, 0xFF, 0xE2), + "rich gold": (0xA8, 0x53, 0x07), + "rio grande": (0xBB, 0xD0, 0x09), + "ripe lemon": (0xF4, 0xD8, 0x1C), + "ripe plum": (0x41, 0x00, 0x56), + "riptide": (0x8B, 0xE6, 0xD8), + "river bed": (0x43, 0x4C, 0x59), + "rob roy": (0xEA, 0xC6, 0x74), + "robin's egg blue": (0x00, 0xCC, 0xCC), + "rock blue": (0x9E, 0xB1, 0xCD), + "rock spray": (0xBA, 0x45, 0x0C), + "rock": (0x4D, 0x38, 0x33), + "rodeo dust": (0xC9, 0xB2, 0x9B), + "rolling stone": (0x74, 0x7D, 0x83), + "roman coffee": (0x79, 0x5D, 0x4C), + "roman": (0xDE, 0x63, 0x60), + "romance": (0xFF, 0xFE, 0xFD), + "romantic": (0xFF, 0xD2, 0xB7), + "ronchi": (0xEC, 0xC5, 0x4E), + "roof terracotta": (0xA6, 0x2F, 0x20), + "rope": (0x8E, 0x4D, 0x1E), + "rose bud cherry": (0x80, 0x0B, 0x47), + "rose bud": (0xFB, 0xB2, 0xA3), + "rose fog": (0xE7, 0xBC, 0xB4), + "rose of sharon": (0xBF, 0x55, 0x00), + "rose white": (0xFF, 0xF6, 0xF5), + "rose": (0xFF, 0x00, 0x7F), + "rosewood": (0x65, 0x00, 0x0B), + "rosy blue": (0xBC, 0x8F, 0x8F), + "roti": (0xC6, 0xA8, 0x4B), + "rouge": (0xA2, 0x3B, 0x6C), + "royal blue": (0x41, 0x69, 0xE1), + "royal heath": (0xAB, 0x34, 0x72), + "royal purple": (0x6B, 0x3F, 0xA0), + "rpi": (208, 95, 0), + "rum swizzle": (0xF9, 0xF8, 0xE4), + "rum": (0x79, 0x69, 0x89), + "russet": (0x80, 0x46, 0x1B), + "russett": (0x75, 0x5A, 0x57), + "rust": (0xB7, 0x41, 0x0E), + "rustic red": (0x48, 0x04, 0x04), + "rusty nail": (0x86, 0x56, 0x0A), + "saddle brown": (0x58, 0x34, 0x01), + "saddle": (0x4C, 0x30, 0x24), + "saffron mango": (0xF9, 0xBF, 0x58), + "saffron": (0xF4, 0xC4, 0x30), + "sage": (0x9E, 0xA5, 0x87), + "sahara sand": (0xF1, 0xE7, 0x88), + "sahara": (0xB7, 0xA2, 0x14), + "sail": (0xB8, 0xE0, 0xF9), + "salem": (0x09, 0x7F, 0x4B), + "salmon": (0xFF, 0x8C, 0x69), + "salomie": (0xFE, 0xDB, 0x8D), + "salt box": (0x68, 0x5E, 0x6E), + "saltpan": (0xF1, 0xF7, 0xF2), + "sambuca": (0x3A, 0x20, 0x10), + "san felix": (0x0B, 0x62, 0x07), + "san juan": (0x30, 0x4B, 0x6A), + "san marino": (0x45, 0x6C, 0xAC), + "sand dune": (0x82, 0x6F, 0x65), + "sandal": (0xAA, 0x8D, 0x6F), + "sandrift": (0xAB, 0x91, 0x7A), + 
"sandstone": (0x79, 0x6D, 0x62), + "sandwisp": (0xF5, 0xE7, 0xA2), + "sandy beach": (0xFF, 0xEA, 0xC8), + "sandy brown": (0xF4, 0xA4, 0x60), + "sangria": (0x92, 0x00, 0x0A), + "sanguine brown": (0x8D, 0x3D, 0x38), + "santa fe": (0xB1, 0x6D, 0x52), + "santas gray": (0x9F, 0xA0, 0xB1), + "sapling": (0xDE, 0xD4, 0xA4), + "sapphire": (0x2F, 0x51, 0x9E), + "saratoga": (0x55, 0x5B, 0x10), + "satin linen": (0xE6, 0xE4, 0xD4), + "sauvignon": (0xFF, 0xF5, 0xF3), + "sazerac": (0xFF, 0xF4, 0xE0), + "scampi": (0x67, 0x5F, 0xA6), + "scandal": (0xCF, 0xFA, 0xF4), + "scarlet gum": (0x43, 0x15, 0x60), + "scarlet": (0xFF, 0x24, 0x00), + "scarlett": (0x95, 0x00, 0x15), + "scarpa flow": (0x58, 0x55, 0x62), + "schist": (0xA9, 0xB4, 0x97), + "school bus yellow": (0xFF, 0xD8, 0x00), + "schooner": (0x8B, 0x84, 0x7E), + "science blue": (0x00, 0x66, 0xCC), + "scooter": (0x2E, 0xBF, 0xD4), + "scorpion": (0x69, 0x5F, 0x62), + "scotch mist": (0xFF, 0xFB, 0xDC), + "screamin' green": (0x66, 0xFF, 0x66), + "sea buckthorn": (0xFB, 0xA1, 0x29), + "sea green": (0x2E, 0x8B, 0x57), + "sea mist": (0xC5, 0xDB, 0xCA), + "sea nymph": (0x78, 0xA3, 0x9C), + "sea pink": (0xED, 0x98, 0x9E), + "seagull": (0x80, 0xCC, 0xEA), + "seance": (0x73, 0x1E, 0x8F), + "seashell peach": (0xFF, 0xF5, 0xEE), + "seashell": (0xF1, 0xF1, 0xF1), + "seaweed": (0x1B, 0x2F, 0x11), + "selago": (0xF0, 0xEE, 0xFD), + "selective yellow": (0xFF, 0xBA, 0x00), + "sepia black": (0x2B, 0x02, 0x02), + "sepia skin": (0x9E, 0x5B, 0x40), + "sepia": (0x70, 0x42, 0x14), + "serenade": (0xFF, 0xF4, 0xE8), + "shadow green": (0x9A, 0xC2, 0xB8), + "shadow": (0x83, 0x70, 0x50), + "shady lady": (0xAA, 0xA5, 0xA9), + "shakespeare": (0x4E, 0xAB, 0xD1), + "shalimar": (0xFB, 0xFF, 0xBA), + "shamrock": (0x33, 0xCC, 0x99), + "shark": (0x25, 0x27, 0x2C), + "sherpa blue": (0x00, 0x49, 0x50), + "sherwood green": (0x02, 0x40, 0x2C), + "shilo": (0xE8, 0xB9, 0xB3), + "shingle fawn": (0x6B, 0x4E, 0x31), + "ship cove": (0x78, 0x8B, 0xBA), + "ship gray": (0x3E, 0x3A, 0x44), + "shiraz": (0xB2, 0x09, 0x31), + "shocking pink": (0xFC, 0x0F, 0xC0), + "shocking": (0xE2, 0x92, 0xC0), + "shuttle gray": (0x5F, 0x66, 0x72), + "siam": (0x64, 0x6A, 0x54), + "sidecar": (0xF3, 0xE7, 0xBB), + "sienna": (0xA0, 0x52, 0x2D), + "silk": (0xBD, 0xB1, 0xA8), + "silver chalice": (0xAC, 0xAC, 0xAC), + "silver rust": (0xC9, 0xC0, 0xBB), + "silver sand": (0xBF, 0xC1, 0xC2), + "silver tree": (0x66, 0xB5, 0x8F), + "silver": (0xC0, 0xC0, 0xC0), + "sinbad": (0x9F, 0xD7, 0xD3), + "siren": (0x7A, 0x01, 0x3A), + "sirocco": (0x71, 0x80, 0x80), + "sisal": (0xD3, 0xCB, 0xBA), + "skeptic": (0xCA, 0xE6, 0xDA), + "sky blue": (0x76, 0xD7, 0xEA), + "slate blue": (0x6A, 0x5A, 0xCD), + "slate gray": (0x70, 0x80, 0x90), + "smalt blue": (0x51, 0x80, 0x8F), + "smalt": (0x00, 0x33, 0x99), + "smoky": (0x60, 0x5B, 0x73), + "snow drift": (0xF7, 0xFA, 0xF7), + "snow flurry": (0xE4, 0xFF, 0xD1), + "snow": (0xFF, 0xFA, 0xFA), + "snowy mint": (0xD6, 0xFF, 0xDB), + "snuff": (0xE2, 0xD8, 0xED), + "soapstone": (0xFF, 0xFB, 0xF9), + "soft amber": (0xD1, 0xC6, 0xB4), + "soft peach": (0xF5, 0xED, 0xEF), + "solid pink": (0x89, 0x38, 0x43), + "solitaire": (0xFE, 0xF8, 0xE2), + "solitude": (0xEA, 0xF6, 0xFF), + "sorbus": (0xFD, 0x7C, 0x07), + "sorrell brown": (0xCE, 0xB9, 0x8F), + "soya bean": (0x6A, 0x60, 0x51), + "spanish green": (0x81, 0x98, 0x85), + "spectra": (0x2F, 0x5A, 0x57), + "spice": (0x6A, 0x44, 0x2E), + "spicy mix": (0x88, 0x53, 0x42), + "spicy mustard": (0x74, 0x64, 0x0D), + "spicy pink": (0x81, 0x6E, 0x71), + "spindle": (0xB6, 0xD1, 
0xEA), + "spray": (0x79, 0xDE, 0xEC), + "spring green": (0x00, 0xFF, 0x7F), + "spring leaves": (0x57, 0x83, 0x63), + "spring rain": (0xAC, 0xCB, 0xB1), + "spring sun": (0xF6, 0xFF, 0xDC), + "spring wood": (0xF8, 0xF6, 0xF1), + "sprout": (0xC1, 0xD7, 0xB0), + "spun pearl": (0xAA, 0xAB, 0xB7), + "squirrel": (0x8F, 0x81, 0x76), + "st tropaz": (0x2D, 0x56, 0x9B), + "stack": (0x8A, 0x8F, 0x8A), + "star dust": (0x9F, 0x9F, 0x9C), + "stark white": (0xE5, 0xD7, 0xBD), + "starship": (0xEC, 0xF2, 0x45), + "steel blue": (0x46, 0x82, 0xB4), + "steel gray": (0x26, 0x23, 0x35), + "stiletto": (0x9C, 0x33, 0x36), + "stonewall": (0x92, 0x85, 0x73), + "storm dust": (0x64, 0x64, 0x63), + "storm gray": (0x71, 0x74, 0x86), + "stratos": (0x00, 0x07, 0x41), + "straw": (0xD4, 0xBF, 0x8D), + "strikemaster": (0x95, 0x63, 0x87), + "stromboli": (0x32, 0x5D, 0x52), + "studio": (0x71, 0x4A, 0xB2), + "submarine": (0xBA, 0xC7, 0xC9), + "sugar cane": (0xF9, 0xFF, 0xF6), + "sulu": (0xC1, 0xF0, 0x7C), + "summer green": (0x96, 0xBB, 0xAB), + "sun": (0xFB, 0xAC, 0x13), + "sundance": (0xC9, 0xB3, 0x5B), + "sundown": (0xFF, 0xB1, 0xB3), + "sunflower": (0xE4, 0xD4, 0x22), + "sunglo": (0xE1, 0x68, 0x65), + "sunglow": (0xFF, 0xCC, 0x33), + "sunset orange": (0xFE, 0x4C, 0x40), + "sunshade": (0xFF, 0x9E, 0x2C), + "supernova": (0xFF, 0xC9, 0x01), + "surf crest": (0xCF, 0xE5, 0xD2), + "surf": (0xBB, 0xD7, 0xC1), + "surfie green": (0x0C, 0x7A, 0x79), + "sushi": (0x87, 0xAB, 0x39), + "suva gray": (0x88, 0x83, 0x87), + "swamp green": (0xAC, 0xB7, 0x8E), + "swamp": (0x00, 0x1B, 0x1C), + "swans down": (0xDC, 0xF0, 0xEA), + "sweet corn": (0xFB, 0xEA, 0x8C), + "sweet pink": (0xFD, 0x9F, 0xA2), + "swirl": (0xD3, 0xCD, 0xC5), + "swiss coffee": (0xDD, 0xD6, 0xD5), + "sycamore": (0x90, 0x8D, 0x39), + "tabasco": (0xA0, 0x27, 0x12), + "tacao": (0xED, 0xB3, 0x81), + "tacha": (0xD6, 0xC5, 0x62), + "tahiti gold": (0xE9, 0x7C, 0x07), + "tahuna sands": (0xEE, 0xF0, 0xC8), + "tall poppy": (0xB3, 0x2D, 0x29), + "tallow": (0xA8, 0xA5, 0x89), + "tamarillo": (0x99, 0x16, 0x13), + "tamarind": (0x34, 0x15, 0x15), + "tan hide": (0xFA, 0x9D, 0x5A), + "tan": (0xD2, 0xB4, 0x8C), + "tana": (0xD9, 0xDC, 0xC1), + "tangaroa": (0x03, 0x16, 0x3C), + "tangerine": (0xF2, 0x85, 0x00), + "tango": (0xED, 0x7A, 0x1C), + "tapa": (0x7B, 0x78, 0x74), + "tapestry": (0xB0, 0x5E, 0x81), + "tara": (0xE1, 0xF6, 0xE8), + "tarawera": (0x07, 0x3A, 0x50), + "tasman": (0xCF, 0xDC, 0xCF), + "taupe gray": (0xB3, 0xAF, 0x95), + "taupe": (0x48, 0x3C, 0x32), + "tawny port": (0x69, 0x25, 0x45), + "te papa green": (0x1E, 0x43, 0x3C), + "tea green": (0xD0, 0xF0, 0xC0), + "tea": (0xC1, 0xBA, 0xB0), + "teak": (0xB1, 0x94, 0x61), + "teal blue": (0x04, 0x42, 0x59), + "teal": (0x00, 0x80, 0x80), + "temptress": (0x3B, 0x00, 0x0B), + "tenn": (0xCD, 0x57, 0x00), + "tequila": (0xFF, 0xE6, 0xC7), + "terracotta": (0xE2, 0x72, 0x5B), + "texas rose": (0xFF, 0xB5, 0x55), + "texas": (0xF8, 0xF9, 0x9C), + "thatch green": (0x40, 0x3D, 0x19), + "thatch": (0xB6, 0x9D, 0x98), + "thistle green": (0xCC, 0xCA, 0xA8), + "thistle": (0xD8, 0xBF, 0xD8), + "thunder": (0x33, 0x29, 0x2F), + "thunderbird": (0xC0, 0x2B, 0x18), + "tia maria": (0xC1, 0x44, 0x0E), + "tiara": (0xC3, 0xD1, 0xD1), + "tiber": (0x06, 0x35, 0x37), + "tickle me pink": (0xFC, 0x80, 0xA5), + "tidal": (0xF1, 0xFF, 0xAD), + "tide": (0xBF, 0xB8, 0xB0), + "timber green": (0x16, 0x32, 0x2C), + "timberwolf": (0xD9, 0xD6, 0xCF), + "titan white": (0xF0, 0xEE, 0xFF), + "toast": (0x9A, 0x6E, 0x61), + "tobacco brown": (0x71, 0x5D, 0x47), + "toledo": (0x3A, 0x00, 
0x20), + "tolopea": (0x1B, 0x02, 0x45), + "tom thumb": (0x3F, 0x58, 0x3B), + "tomato": (0xFF, 0x63, 0x47), + "tonys pink": (0xE7, 0x9F, 0x8C), + "topaz": (0x7C, 0x77, 0x8A), + "torch red": (0xFD, 0x0E, 0x35), + "torea bay": (0x0F, 0x2D, 0x9E), + "tory blue": (0x14, 0x50, 0xAA), + "tosca": (0x8D, 0x3F, 0x3F), + "totem pole": (0x99, 0x1B, 0x07), + "tower gray": (0xA9, 0xBD, 0xBF), + "tradewind": (0x5F, 0xB3, 0xAC), + "tranquil": (0xE6, 0xFF, 0xFF), + "travertine": (0xFF, 0xFD, 0xE8), + "tree poppy": (0xFC, 0x9C, 0x1D), + "treehouse": (0x3B, 0x28, 0x20), + "trendy green": (0x7C, 0x88, 0x1A), + "trendy pink": (0x8C, 0x64, 0x95), + "trinidad": (0xE6, 0x4E, 0x03), + "tropical blue": (0xC3, 0xDD, 0xF9), + "tropical rain forest": (0x00, 0x75, 0x5E), + "trout": (0x4A, 0x4E, 0x5A), + "true v": (0x8A, 0x73, 0xD6), + "tuatara": (0x36, 0x35, 0x34), + "tuft bush": (0xFF, 0xDD, 0xCD), + "tulip tree": (0xEA, 0xB3, 0x3B), + "tumbleweed": (0xDE, 0xA6, 0x81), + "tuna": (0x35, 0x35, 0x42), + "tundora": (0x4A, 0x42, 0x44), + "turbo": (0xFA, 0xE6, 0x00), + "turkish rose": (0xB5, 0x72, 0x81), + "turmeric": (0xCA, 0xBB, 0x48), + "turquoise blue": (0x6C, 0xDA, 0xE7), + "turquoise": (0x30, 0xD5, 0xC8), + "turtle green": (0x2A, 0x38, 0x0B), + "tuscany": (0xBD, 0x5E, 0x2E), + "tusk": (0xEE, 0xF3, 0xC3), + "tussock": (0xC5, 0x99, 0x4B), + "tutu": (0xFF, 0xF1, 0xF9), + "twilight blue": (0xEE, 0xFD, 0xFF), + "twilight": (0xE4, 0xCF, 0xDE), + "twine": (0xC2, 0x95, 0x5D), + "tyrian purple": (0x66, 0x02, 0x3C), + "ultramarine": (0x12, 0x0A, 0x8F), + "valencia": (0xD8, 0x44, 0x37), + "valentino": (0x35, 0x0E, 0x42), + "valhalla": (0x2B, 0x19, 0x4F), + "van cleef": (0x49, 0x17, 0x0C), + "vanilla ice": (0xF3, 0xD9, 0xDF), + "vanilla": (0xD1, 0xBE, 0xA8), + "varden": (0xFF, 0xF6, 0xDF), + "venetian red": (0x72, 0x01, 0x0F), + "venice blue": (0x05, 0x59, 0x89), + "venus": (0x92, 0x85, 0x90), + "verdigris": (0x5D, 0x5E, 0x37), + "verdun green": (0x49, 0x54, 0x00), + "vermilion": (0xFF, 0x4D, 0x00), + "vesuvius": (0xB1, 0x4A, 0x0B), + "victoria": (0x53, 0x44, 0x91), + "vida loca": (0x54, 0x90, 0x19), + "viking": (0x64, 0xCC, 0xDB), + "vin rouge": (0x98, 0x3D, 0x61), + "viola": (0xCB, 0x8F, 0xA9), + "violent violet": (0x29, 0x0C, 0x5E), + "violet eggplant": (0x99, 0x11, 0x99), + "violet red": (0xF7, 0x46, 0x8A), + "violet": (0x24, 0x0A, 0x40), + "viridian green": (0x67, 0x89, 0x75), + "viridian": (0x40, 0x82, 0x6D), + "vis vis": (0xFF, 0xEF, 0xA1), + "vista blue": (0x8F, 0xD6, 0xB4), + "vista white": (0xFC, 0xF8, 0xF7), + "vivid tangerine": (0xFF, 0x99, 0x80), + "vivid violet": (0x80, 0x37, 0x90), + "voodoo": (0x53, 0x34, 0x55), + "vulcan": (0x10, 0x12, 0x1D), + "wafer": (0xDE, 0xCB, 0xC6), + "waikawa gray": (0x5A, 0x6E, 0x9C), + "waiouru": (0x36, 0x3C, 0x0D), + "walnut": (0x77, 0x3F, 0x1A), + "wannabe.house": (0x00, 0x00, 95), + "wasabi": (0x78, 0x8A, 0x25), + "water leaf": (0xA1, 0xE9, 0xDE), + "watercourse": (0x05, 0x6F, 0x57), + "waterloo ": (0x7B, 0x7C, 0x94), + "wattle": (0xDC, 0xD7, 0x47), + "watusi": (0xFF, 0xDD, 0xCF), + "wax flower": (0xFF, 0xC0, 0xA8), + "we peep": (0xF7, 0xDB, 0xE6), + "web orange": (0xFF, 0xA5, 0x00), + "wedgewood": (0x4E, 0x7F, 0x9E), + "well read": (0xB4, 0x33, 0x32), + "west coast": (0x62, 0x51, 0x19), + "west side": (0xFF, 0x91, 0x0F), + "westar": (0xDC, 0xD9, 0xD2), + "wewak": (0xF1, 0x9B, 0xAB), + "wheat": (0xF5, 0xDE, 0xB3), + "wheatfield": (0xF3, 0xED, 0xCF), + "whiskey": (0xD5, 0x9A, 0x6F), + "whisper": (0xF7, 0xF5, 0xFA), + "white ice": (0xDD, 0xF9, 0xF1), + "white lilac": (0xF8, 0xF7, 
0xFC), + "white linen": (0xF8, 0xF0, 0xE8), + "white pointer": (0xFE, 0xF8, 0xFF), + "white rock": (0xEA, 0xE8, 0xD4), + "white smoke": (0xF5, 0xF5, 0xF5), + "white": (0xFF, 0xFF, 0xFF), + "wild blue yonder": (0x7A, 0x89, 0xB8), + "wild rice": (0xEC, 0xE0, 0x90), + "wild sand": (0xF4, 0xF4, 0xF4), + "wild strawberry": (0xFF, 0x33, 0x99), + "wild watermelon": (0xFD, 0x5B, 0x78), + "wild willow": (0xB9, 0xC4, 0x6A), + "william": (0x3A, 0x68, 0x6C), + "willow brook": (0xDF, 0xEC, 0xDA), + "willow grove": (0x65, 0x74, 0x5D), + "windsor": (0x3C, 0x08, 0x78), + "wine berry": (0x59, 0x1D, 0x35), + "winter hazel": (0xD5, 0xD1, 0x95), + "wisp pink": (0xFE, 0xF4, 0xF8), + "wisteria": (0x97, 0x71, 0xB5), + "wistful": (0xA4, 0xA6, 0xD3), + "witch haze": (0xFF, 0xFC, 0x99), + "wood bark": (0x26, 0x11, 0x05), + "woodland": (0x4D, 0x53, 0x28), + "woodrush": (0x30, 0x2A, 0x0F), + "woodsmoke": (0x0C, 0x0D, 0x0F), + "woody brown": (0x48, 0x31, 0x31), + "xanadu": (0x73, 0x86, 0x78), + "yellow green": (0xC5, 0xE1, 0x7A), + "yellow metal": (0x71, 0x63, 0x38), + "yellow orange": (0xFF, 0xAE, 0x42), + "yellow sea": (0xFE, 0xA9, 0x04), + "yellow": (0xFF, 0xFF, 0x00), + "your pink": (0xFF, 0xC3, 0xC0), + "yukon gold": (0x7B, 0x66, 0x08), + "yuma": (0xCE, 0xC2, 0x91), + "zambezi": (0x68, 0x55, 0x58), + "zanah": (0xDA, 0xEC, 0xD6), + "zest": (0xE5, 0x84, 0x1B), + "zeus": (0x29, 0x23, 0x19), + "ziggurat": (0xBF, 0xDB, 0xE2), + "zinnwaldite": (0xEB, 0xC2, 0xAF), + "zircon": (0xF4, 0xF8, 0xFF), + "zombie": (0xE4, 0xD6, 0x9B), + "zorba": (0xA5, 0x9B, 0x91), + "zuccini": (0x04, 0x40, 0x22), + "zumthor": (0xED, 0xF6, 0xFF), +} + + +def clear() -> str: + """Clear screen ANSI escape sequence""" + return "\x1b[H\x1b[2J" + + +def clear_screen() -> str: + """Clear screen ANSI escape sequence""" + return "\x1b[H\x1b[2J" + + +def clear_line() -> str: + """Clear the current line ANSI escape sequence""" + return "\x1b[2K\r" + + +def reset() -> str: + """Reset text attributes to 'normal'""" + return "\x1b[0m" + + +def normal() -> str: + """Reset text attributes to 'normal'""" + return "\x1b[0m" + + +def bold() -> str: + """Set text to bold""" + return "\x1b[1m" + + +def italic() -> str: + """Set text to italic""" + return "\x1b[3m" + + +def italics() -> str: + """Set text to italic""" + return italic() + + +def underline() -> str: + """Set text to underline""" + return "\x1b[4m" + + +def strikethrough() -> str: + """Set text to strikethrough""" + return "\x1b[9m" + + +def strike_through() -> str: + """Set text to strikethrough""" + return strikethrough() + + +def is_16color(num: int) -> bool: + """Is num a valid 16 color number?""" + return num in (255, 128) + + +def is_216color(num: int) -> bool: + """Is num a valid 216 color number?""" + return num in set([0, 95, 135, 175, 223, 255]) + + +def _simple_color_number(red: int, green: int, blue: int) -> int: + """Construct a simple color number""" + r = red > 0 + g = green > 0 + b = blue > 0 + return b << 2 | g << 1 | r + + +def fg_16color(red: int, green: int, blue: int) -> str: + """Set foreground color using 16 color mode""" + code = _simple_color_number(red, green, blue) + 30 + bright_count = 0 + if red > 128: + bright_count += 1 + if green > 128: + bright_count += 1 + if blue > 128: + bright_count += 1 + if bright_count > 1: + code += 60 + return f"\x1b[{code}m" + + +def bg_16color(red: int, green: int, blue: int) -> str: + """Set background using 16 color mode""" + code = _simple_color_number(red, green, blue) + 40 + bright_count = 0 + if red > 128: + bright_count += 1 + if green > 128: + bright_count += 1 + if blue > 128: + bright_count += 1 + if bright_count > 1: + code += 60 + return f"\x1b[{code}m"
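The 16-color mapping above is easy to sanity-check by hand: _simple_color_number packs the nonzero red/green/blue channels into a 3-bit code, fg_16color adds the foreground base of 30, and a color with two or more bright channels (> 128) is bumped into the bright range by adding 60. A quick sketch, using only the helpers just defined:

    # Sketch: exercising the 16-color helpers by hand.
    assert _simple_color_number(255, 0, 0) == 1       # only the red bit is set
    assert fg_16color(255, 0, 0) == "\x1b[31m"        # 30 + 1; one bright channel, no bump
    assert fg_16color(255, 255, 0) == "\x1b[93m"      # 30 + 3 + 60; two bright channels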
+ + +def _pixel_to_216color(n: int) -> int: + if n >= 255: + return 5 + if n >= 233: + return 4 + if n >= 175: + return 3 + if n >= 135: + return 2 + if n >= 95: + return 1 + return 0 + + +def fg_216color(red: int, green: int, blue: int) -> str: + """Set foreground using 216 color mode""" + r = _pixel_to_216color(red) + g = _pixel_to_216color(green) + b = _pixel_to_216color(blue) + code = 16 + r * 36 + g * 6 + b + return f"\x1b[38;5;{code}m" + + +def bg_216color(red: int, green: int, blue: int) -> str: + """Set background using 216 color mode""" + r = _pixel_to_216color(red) + g = _pixel_to_216color(green) + b = _pixel_to_216color(blue) + code = 16 + r * 36 + g * 6 + b + return f"\x1b[48;5;{code}m" + + +def fg_24bit(red: int, green: int, blue: int) -> str: + """Set foreground using 24bit color mode""" + return f"\x1b[38;2;{red};{green};{blue}m" + + +def bg_24bit(red: int, green: int, blue: int) -> str: + """Set background using 24bit color mode""" + return f"\x1b[48;2;{red};{green};{blue}m"
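The 216-color helpers quantize each channel onto the xterm 6x6x6 color cube and compute the palette index as 16 + 36r + 6g + b. A worked sketch (it assumes "blue" maps to (0x00, 0x00, 0xFF), which is consistent with the fg('blue') doctest below):

    # Sketch: pure blue lands on cube coordinate (0, 0, 5).
    assert _pixel_to_216color(0xFF) == 5
    assert fg_216color(0, 0, 255) == "\x1b[38;5;21m"  # 16 + 36*0 + 6*0 + 5 == 21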
+ + +def _find_color_by_name(name: str) -> Tuple[int, int, int]: + rgb = COLOR_NAMES_TO_RGB.get(name.lower(), None) + if rgb is None: + name = guess_name(name) + rgb = COLOR_NAMES_TO_RGB.get(name.lower(), None) + assert rgb is not None + return rgb + + +@logging_utils.squelch_repeated_log_messages(1) +def fg( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, + *, + force_16color: bool = False, + force_216color: bool = False, +) -> str: + """Return the ANSI escape sequence to change the foreground color + being printed. Target colors may be indicated by name or by R/G/B + components. The result uses the 16- or 216-color scheme if + force_16color or force_216color is passed (respectively); otherwise + the code picks whichever strategy fits the given values best. + + Args: + name: the name of the color to set + red: the color to set's red component value + green: the color to set's green component value + blue: the color to set's blue component value + force_16color: force fg to use 16 color mode + force_216color: force fg to use 216 color mode + + Returns: + A string containing the ANSI escape sequence that sets the desired foreground color + + >>> from pyutils import string_utils as su + >>> su.to_base64(fg('blue')) + b'G1szODs1OzIxbQ==\\n' + """ + if name is not None and name == 'reset': + return '\033[39m' + + if name is not None and string_utils.is_full_string(name): + rgb = _find_color_by_name(name) + return fg( + None, + rgb[0], + rgb[1], + rgb[2], + force_16color=force_16color, + force_216color=force_216color, + ) + + if red is None: + red = 0 + if green is None: + green = 0 + if blue is None: + blue = 0 + if (is_16color(red) and is_16color(green) and is_16color(blue)) or force_16color: + logger.debug("Using 16-color strategy") + return fg_16color(red, green, blue) + if ( + is_216color(red) and is_216color(green) and is_216color(blue) + ) or force_216color: + logger.debug("Using 216-color strategy") + return fg_216color(red, green, blue) + logger.debug("Using 24-bit color strategy") + return fg_24bit(red, green, blue) + + +def reset_fg(): + return '\033[39m' + + +def _rgb_to_yiq(rgb: Tuple[int, int, int]) -> int: + return (rgb[0] * 299 + rgb[1] * 587 + rgb[2] * 114) // 1000 + + +def _contrast(rgb: Tuple[int, int, int]) -> Tuple[int, int, int]: + if _rgb_to_yiq(rgb) < 128: + return (0xFF, 0xFF, 0xFF) + return (0, 0, 0) + + +def pick_contrasting_color( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, +) -> Tuple[int, int, int]: + """Returns a red, green, blue tuple representing a color that + contrasts with a given background color, which may be specified + either by name or by its red, green, blue components. + + Args: + name: the name of the color to contrast + red: the color to contrast's red component value + green: the color to contrast's green component value + blue: the color to contrast's blue component value + + Returns: + An RGB tuple containing a contrasting color + + >>> pick_contrasting_color(None, 20, 20, 20) + (255, 255, 255) + + >>> pick_contrasting_color("white") + (0, 0, 0) + + """ + if name is not None and string_utils.is_full_string(name): + rgb = _find_color_by_name(name) + else: + r = red if red is not None else 0 + g = green if green is not None else 0 + b = blue if blue is not None else 0 + rgb = (r, g, b) + assert rgb is not None + return _contrast(rgb) + + +def guess_name(name: str) -> str: + """Try to guess what color the user is talking about""" + best_guess = None + max_ratio = None + for possibility in COLOR_NAMES_TO_RGB: + r = difflib.SequenceMatcher(None, name, possibility).ratio() + if max_ratio is None or r > max_ratio: + max_ratio = r + best_guess = possibility + assert best_guess is not None + logger.debug("Best guess at color name is %s", best_guess) + return best_guess
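_rgb_to_yiq above computes the integer luma (Y) term of the YIQ color space, (299R + 587G + 114B) / 1000, and _contrast simply picks white text over dark backgrounds (Y < 128) and black text over light ones. A worked sketch using only those helpers:

    # Sketch: green reads far brighter than blue at the same intensity.
    assert _rgb_to_yiq((0, 255, 0)) == 149               # 587 * 255 // 1000
    assert _rgb_to_yiq((0, 0, 255)) == 29                # 114 * 255 // 1000
    assert _contrast((0, 0, 255)) == (0xFF, 0xFF, 0xFF)  # white on a blue background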
+ + +@logging_utils.squelch_repeated_log_messages(1) +def bg( + name: Optional[str] = "", + red: Optional[int] = None, + green: Optional[int] = None, + blue: Optional[int] = None, + *, + force_16color: bool = False, + force_216color: bool = False, +) -> str: + """Returns an ANSI escape sequence for changing the current background + color. + + Args: + name: the name of the color to set + red: the color to set's red component value + green: the color to set's green component value + blue: the color to set's blue component value + force_16color: force bg to use 16 color mode + force_216color: force bg to use 216 color mode + + >>> from pyutils import string_utils as su + >>> su.to_base64(bg("red")) # b'\x1b[48;5;196m' + b'G1s0ODs1OzE5Nm0=\\n' + """ + if name is not None and name == 'reset': + return '\033[49m' + + if name is not None and string_utils.is_full_string(name): + rgb = _find_color_by_name(name) + return bg( + None, + rgb[0], + rgb[1], + rgb[2], + force_16color=force_16color, + force_216color=force_216color, + ) + if red is None: + red = 0 + if green is None: + green = 0 + if blue is None: + blue = 0 + if (is_16color(red) and is_16color(green) and is_16color(blue)) or force_16color: + logger.debug("Using 16-color strategy") + return bg_16color(red, green, blue) + if ( + is_216color(red) and is_216color(green) and is_216color(blue) + ) or force_216color: + logger.debug("Using 216-color strategy") + return bg_216color(red, green, blue) + logger.debug("Using 24-bit color strategy") + return bg_24bit(red, green, blue) + + +def reset_bg(): + return '\033[49m' + + +class StdoutInterceptor(io.TextIOBase, contextlib.AbstractContextManager): + """An interceptor for data written to stdout. Use as a context manager.""" + + def __init__(self): + super().__init__() + self.saved_stdout: Optional[io.TextIO] = None + self.buf = '' + + @abstractmethod + def write(self, s: str): + pass + + def __enter__(self): + self.saved_stdout = sys.stdout + sys.stdout = self + return self + + def __exit__(self, *args) -> Literal[False]: + sys.stdout = self.saved_stdout + print(self.buf) + return False + + +class ProgrammableColorizer(StdoutInterceptor): + """A colorizing interceptor; construct it with (re.Pattern, callable) + pairs and each callable is invoked to transform (usually colorize) + whatever its pattern matches in the intercepted output. + + """ + + def __init__( + self, + patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]], + ): + super().__init__() + self.patterns = list(patterns) + + @overrides + def write(self, s: str): + for pattern in self.patterns: + s = pattern[0].sub(pattern[1], s) + self.buf += s + + +if __name__ == '__main__': + + def main() -> None: + import doctest + + doctest.testmod() + + name = " ".join(sys.argv[1:]) + for possibility in COLOR_NAMES_TO_RGB: + if name in possibility: + f = fg(possibility) + b = bg(possibility) + _ = pick_contrasting_color(possibility) + xf = fg(None, _[0], _[1], _[2]) + xb = bg(None, _[0], _[1], _[2]) + print( + f'{f}{xb}{possibility}{reset()}\t\t\t' + f'{b}{xf}{possibility}{reset()}' + ) + + main() diff --git a/src/pyutils/argparse_utils.py b/src/pyutils/argparse_utils.py new file mode 100644 index 0000000..3b466b0 --- /dev/null +++ b/src/pyutils/argparse_utils.py @@ -0,0 +1,282 @@ +#!/usr/bin/python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Helpers for commandline argument parsing.""" + +import argparse +import datetime +import logging +import os +from typing import Any + +from overrides import overrides + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +logger = logging.getLogger(__name__) + + +class ActionNoYes(argparse.Action): + """An argparse Action that allows for commandline arguments like this:: + + cfg.add_argument( + '--enable_the_thing', + action=ActionNoYes, + default=False, + help='Should we enable the thing?'
+ ) + + This creates the following cmdline arguments:: + + --enable_the_thing + --no_enable_the_thing + + These arguments can be used to indicate the inclusion or exclusion of + binary exclusive behaviors. + """ + + def __init__(self, option_strings, dest, default=None, required=False, help=None): + if default is None: + msg = 'You must provide a default with Yes/No action' + logger.critical(msg) + raise ValueError(msg) + if len(option_strings) != 1: + msg = 'Only single argument is allowed with NoYes action' + logger.critical(msg) + raise ValueError(msg) + opt = option_strings[0] + if not opt.startswith('--'): + msg = 'Yes/No arguments must be prefixed with --' + logger.critical(msg) + raise ValueError(msg) + + opt = opt[2:] + opts = ['--' + opt, '--no_' + opt] + super().__init__( + opts, + dest, + nargs=0, + const=None, + default=default, + required=required, + help=help, + ) + + @overrides + def __call__(self, parser, namespace, values, option_strings=None): + if option_strings.startswith('--no-') or option_strings.startswith('--no_'): + setattr(namespace, self.dest, False) + else: + setattr(namespace, self.dest, True) + + +def valid_bool(v: Any) -> bool: + """ + If the string is a valid bool, return its value. + + >>> valid_bool(True) + True + + >>> valid_bool("true") + True + + >>> valid_bool("yes") + True + + >>> valid_bool("on") + True + + >>> valid_bool("1") + True + + >>> valid_bool(12345) + Traceback (most recent call last): + ... + argparse.ArgumentTypeError: 12345 + + """ + if isinstance(v, bool): + return v + from pyutils.string_utils import to_bool + + try: + return to_bool(v) + except Exception as e: + raise argparse.ArgumentTypeError(v) from e + + +def valid_ip(ip: str) -> str: + """ + If the string is a valid IPv4 address, return it. Otherwise raise + an ArgumentTypeError. + + >>> valid_ip("1.2.3.4") + '1.2.3.4' + + >>> valid_ip("localhost") + Traceback (most recent call last): + ... + argparse.ArgumentTypeError: localhost is an invalid IP address + + """ + from pyutils.string_utils import extract_ip_v4 + + s = extract_ip_v4(ip.strip()) + if s is not None: + return s + msg = f"{ip} is an invalid IP address" + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_mac(mac: str) -> str: + """ + If the string is a valid MAC address, return it. Otherwise raise + an ArgumentTypeError. + + >>> valid_mac('12:23:3A:4F:55:66') + '12:23:3A:4F:55:66' + + >>> valid_mac('12-23-3A-4F-55-66') + '12-23-3A-4F-55-66' + + >>> valid_mac('big') + Traceback (most recent call last): + ... + argparse.ArgumentTypeError: big is an invalid MAC address + + """ + from pyutils.string_utils import extract_mac_address + + s = extract_mac_address(mac) + if s is not None: + return s + msg = f"{mac} is an invalid MAC address" + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_percentage(num: str) -> float: + """ + If the string is a valid percentage, return it. Otherwise raise + an ArgumentTypeError. + + >>> valid_percentage("15%") + 15.0 + + >>> valid_percentage('40') + 40.0 + + >>> valid_percentage('115') + Traceback (most recent call last): + ... + argparse.ArgumentTypeError: 115 is an invalid percentage; expected 0 <= n <= 100.0 + + """ + num = num.strip('%') + n = float(num) + if 0.0 <= n <= 100.0: + return n + msg = f"{num} is an invalid percentage; expected 0 <= n <= 100.0" + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_filename(filename: str) -> str: + """ + If the string is a valid filename, return it. 
Otherwise raise + an ArgumentTypeError. + + >>> valid_filename('/tmp') + '/tmp' + + >>> valid_filename('wfwefwefwefwefwefwefwefwef') + Traceback (most recent call last): + ... + argparse.ArgumentTypeError: wfwefwefwefwefwefwefwefwef was not found and is therefore invalid. + + """ + s = filename.strip() + if os.path.exists(s): + return s + msg = f"{filename} was not found and is therefore invalid." + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_date(txt: str) -> datetime.date: + """If the string is a valid date, return it. Otherwise raise + an ArgumentTypeError. + + >>> valid_date('6/5/2021') + datetime.date(2021, 6, 5) + + # Note: dates like 'next wednesday' work fine, they are just + # hard to test for without knowing when the testcase will be + # executed... + >>> valid_date('next wednesday') # doctest: +ELLIPSIS + -ANYTHING- + """ + from pyutils.string_utils import to_date + + date = to_date(txt) + if date is not None: + return date + msg = f'Cannot parse argument as a date: {txt}' + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_datetime(txt: str) -> datetime.datetime: + """If the string is a valid datetime, return it. Otherwise raise + an ArgumentTypeError. + + >>> valid_datetime('6/5/2021 3:01:02') + datetime.datetime(2021, 6, 5, 3, 1, 2) + + # Again, these types of expressions work fine but are + # difficult to test with doctests because the answer is + # relative to the time the doctest is executed. + >>> valid_datetime('next christmas at 4:15am') # doctest: +ELLIPSIS + -ANYTHING- + """ + from pyutils.string_utils import to_datetime + + dt = to_datetime(txt) + if dt is not None: + return dt + msg = f'Cannot parse argument as datetime: {txt}' + logger.error(msg) + raise argparse.ArgumentTypeError(msg) + + +def valid_duration(txt: str) -> datetime.timedelta: + """If the string is a valid time duration, return a + datetime.timedelta representing the period of time. Otherwise + maybe raise an ArgumentTypeError or potentially just treat the + time window as zero in length. + + >>> valid_duration('3m') + datetime.timedelta(seconds=180) + + >>> valid_duration('your mom') + datetime.timedelta(0) + + """ + from pyutils.datetimez.datetime_utils import parse_duration + + try: + secs = parse_duration(txt) + return datetime.timedelta(seconds=secs) + except Exception as e: + logger.exception(e) + raise argparse.ArgumentTypeError(e) from e + + +if __name__ == '__main__': + import doctest + + doctest.ELLIPSIS_MARKER = '-ANYTHING-' + doctest.testmod() diff --git a/src/pyutils/bootstrap.py b/src/pyutils/bootstrap.py new file mode 100644 index 0000000..5a59615 --- /dev/null +++ b/src/pyutils/bootstrap.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""This is a module for wrapping around python programs and doing some +minor setup and tear down work for them. With it, you will get: + +* The ability to break into pdb on unhandled exceptions, +* automatic support for :file:`config.py` (argument parsing) +* automatic logging support for :file:`logging.py`, +* the ability to enable code profiling, +* the ability to enable module import auditing, +* optional memory profiling for your program, +* ability to set random seed via commandline, +* automatic program timing and reporting, +* more verbose error handling and reporting, + +Most of these are enabled and/or configured via commandline flags +(see below). 
+ +""" + +import functools +import importlib +import importlib.abc +import logging +import os +import sys +from inspect import stack + +from pyutils import config, logging_utils +from pyutils.argparse_utils import ActionNoYes + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + + +logger = logging.getLogger(__name__) + +cfg = config.add_commandline_args( + f'Bootstrap ({__file__})', + 'Args related to python program bootstrapper and Swiss army knife', +) +cfg.add_argument( + '--debug_unhandled_exceptions', + action=ActionNoYes, + default=False, + help='Break into pdb on top level unhandled exceptions.', +) +cfg.add_argument( + '--show_random_seed', + action=ActionNoYes, + default=False, + help='Should we display (and log.debug) the global random seed?', +) +cfg.add_argument( + '--set_random_seed', + type=int, + nargs=1, + default=None, + metavar='SEED_INT', + help='Override the global random seed with a particular number.', +) +cfg.add_argument( + '--dump_all_objects', + action=ActionNoYes, + default=False, + help='Should we dump the Python import tree before main?', +) +cfg.add_argument( + '--audit_import_events', + action=ActionNoYes, + default=False, + help='Should we audit all import events?', +) +cfg.add_argument( + '--run_profiler', + action=ActionNoYes, + default=False, + help='Should we run cProfile on this code?', +) +cfg.add_argument( + '--trace_memory', + action=ActionNoYes, + default=False, + help='Should we record/report on memory utilization?', +) + +ORIGINAL_EXCEPTION_HOOK = sys.excepthook + + +def handle_uncaught_exception(exc_type, exc_value, exc_tb): + """ + Top-level exception handler for exceptions that make it past any exception + handlers in the python code being run. Logs the error and stacktrace then + maybe attaches a debugger. + + """ + msg = f'Unhandled top level exception {exc_type}' + logger.exception(msg) + print(msg, file=sys.stderr) + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_tb) + return + else: + import io + import traceback + + tb_output = io.StringIO() + traceback.print_tb(exc_tb, None, tb_output) + print(tb_output.getvalue(), file=sys.stderr) + logger.error(tb_output.getvalue()) + tb_output.close() + + # stdin or stderr is redirected, just do the normal thing + if not sys.stderr.isatty() or not sys.stdin.isatty(): + ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb) + + else: # a terminal is attached and stderr isn't redirected, maybe debug. + if config.config['debug_unhandled_exceptions']: + logger.info("Invoking the debugger...") + import pdb + + pdb.pm() + else: + ORIGINAL_EXCEPTION_HOOK(exc_type, exc_value, exc_tb) + + +class ImportInterceptor(importlib.abc.MetaPathFinder): + """An interceptor that always allows module load events but dumps a + record into the log and onto stdout when modules are loaded and + produces an audit of who imported what at the end of the run. It + can't see any load events that happen before it, though, so move + bootstrap up in your __main__'s import list just temporarily to + get a good view. 
+ + """ + + def __init__(self): + from pyutils.collectionz.trie import Trie + + self.module_by_filename_cache = {} + self.repopulate_modules_by_filename() + self.tree = Trie() + self.tree_node_by_module = {} + + def repopulate_modules_by_filename(self): + self.module_by_filename_cache.clear() + for ( + _, + mod, + ) in sys.modules.copy().items(): # copy here because modules is volatile + if hasattr(mod, '__file__'): + fname = getattr(mod, '__file__') + else: + fname = 'unknown' + self.module_by_filename_cache[fname] = mod + + @staticmethod + def should_ignore_filename(filename: str) -> bool: + return 'importlib' in filename or 'six.py' in filename + + def find_module(self, fullname, path): + raise Exception( + "This method has been deprecated since Python 3.4, please upgrade." + ) + + def find_spec(self, loaded_module, path=None, _=None): + s = stack() + for x in range(3, len(s)): + filename = s[x].filename + if ImportInterceptor.should_ignore_filename(filename): + continue + + loading_function = s[x].function + if filename in self.module_by_filename_cache: + loading_module = self.module_by_filename_cache[filename] + else: + self.repopulate_modules_by_filename() + loading_module = self.module_by_filename_cache.get(filename, 'unknown') + + path = self.tree_node_by_module.get(loading_module, []) + path.extend([loaded_module]) + self.tree.insert(path) + self.tree_node_by_module[loading_module] = path + + msg = f'*** Import {loaded_module} from {filename}:{s[x].lineno} in {loading_module}::{loading_function}' + logger.debug(msg) + print(msg) + return + msg = f'*** Import {loaded_module} from ?????' + logger.debug(msg) + print(msg) + + def invalidate_caches(self): + pass + + def find_importer(self, module: str): + if module in self.tree_node_by_module: + node = self.tree_node_by_module[module] + return node + return [] + + +# Audit import events? Note: this runs early in the lifetime of the +# process (assuming that import bootstrap happens early); config has +# (probably) not yet been loaded or parsed the commandline. Also, +# some things have probably already been imported while we weren't +# watching so this information may be incomplete. +# +# Also note: move bootstrap up in the global import list to catch +# more import events and have a more complete record. 
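The flag-sniffing loop below installs the interceptor very early, before config has loaded or parsed the commandline. The same effect can be had by hand when experimenting outside of bootstrap (a sketch; the trailing import is an arbitrary example):

    # Sketch: install the interceptor manually, then watch an import get audited.
    import sys
    interceptor = ImportInterceptor()
    sys.meta_path.insert(0, interceptor)
    import json  # noqa: F401 -- now logged, printed, and recorded in the interceptor's trie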
+IMPORT_INTERCEPTOR = None +for arg in sys.argv: + if arg == '--audit_import_events': + IMPORT_INTERCEPTOR = ImportInterceptor() + sys.meta_path.insert(0, IMPORT_INTERCEPTOR) + + +def dump_all_objects() -> None: + """Helper code to dump all known python objects.""" + + messages = {} + all_modules = sys.modules + for obj in object.__subclasses__(): + if not hasattr(obj, '__name__'): + continue + klass = obj.__name__ + if not hasattr(obj, '__module__'): + continue + class_mod_name = obj.__module__ + if class_mod_name in all_modules: + mod = all_modules[class_mod_name] + if not hasattr(mod, '__name__'): + mod_name = class_mod_name + else: + mod_name = mod.__name__ + if hasattr(mod, '__file__'): + mod_file = mod.__file__ + else: + mod_file = 'unknown' + if IMPORT_INTERCEPTOR is not None: + import_path = IMPORT_INTERCEPTOR.find_importer(mod_name) + else: + import_path = 'unknown' + msg = f'{class_mod_name}::{klass} ({mod_file})' + if import_path != 'unknown' and len(import_path) > 0: + msg += f' imported by {import_path}' + messages[f'{class_mod_name}::{klass}'] = msg + for x in sorted(messages.keys()): + logger.debug(messages[x]) + print(messages[x]) + + +def initialize(entry_point): + """ + Remember to initialize config, initialize logging, set/log a random + seed, etc... before running main. If you use this decorator around + your main, like this:: + + from pyutils import bootstrap + + @bootstrap.initialize + def main(): + whatever + + if __name__ == '__main__': + main() + + You get: + + * The ability to break into pdb on unhandled exceptions, + * automatic support for :file:`config.py` (argument parsing) + * automatic logging support for :file:`logging.py`, + * the ability to enable code profiling, + * the ability to enable module import auditing, + * optional memory profiling for your program, + * ability to set random seed via commandline, + * automatic program timing and reporting, + * more verbose error handling and reporting, + + Most of these are enabled and/or configured via commandline flags + (see below). + """ + + @functools.wraps(entry_point) + def initialize_wrapper(*args, **kwargs): + # Hook top level unhandled exceptions, maybe invoke debugger. + if sys.excepthook == sys.__excepthook__: + sys.excepthook = handle_uncaught_exception + + # Try to figure out the name of the program entry point. Then + # parse configuration (based on cmdline flags, environment vars + # etc...) + entry_filename = None + entry_descr = None + try: + entry_filename = entry_point.__code__.co_filename + entry_descr = entry_point.__code__.__repr__() + except Exception: + if ( + '__globals__' in entry_point.__dict__ + and '__file__' in entry_point.__globals__ + ): + entry_filename = entry_point.__globals__['__file__'] + entry_descr = entry_filename + config.parse(entry_filename) + + if config.config['trace_memory']: + import tracemalloc + + tracemalloc.start() + + # Initialize logging... and log some remembered messages from + # config module. + logging_utils.initialize_logging(logging.getLogger()) + config.late_logging() + + # Maybe log some info about the python interpreter itself. 
+ logger.debug( + 'Platform: %s, maxint=0x%x, byteorder=%s', + sys.platform, + sys.maxsize, + sys.byteorder, + ) + logger.debug('Python interpreter version: %s', sys.version) + logger.debug('Python implementation: %s', sys.implementation) + logger.debug('Python C API version: %s', sys.api_version) + if __debug__: + logger.debug('Python interpreter running in __debug__ mode.') + else: + logger.debug('Python interpreter running in optimized mode.') + logger.debug('Python path: %s', sys.path) + + # Allow programs that don't bother to override the random seed + # to be replayed via the commandline. + import random + + random_seed = config.config['set_random_seed'] + if random_seed is not None: + random_seed = random_seed[0] + else: + random_seed = int.from_bytes(os.urandom(4), 'little') + + if config.config['show_random_seed']: + msg = f'Global random seed is: {random_seed}' + logger.debug(msg) + print(msg) + random.seed(random_seed) + + # Do it, invoke the user's code. Pay attention to how long it takes. + logger.debug('Starting %s (program entry point)', entry_descr) + ret = None + from pyutils import stopwatch + + if config.config['run_profiler']: + import cProfile + from pstats import SortKey + + with stopwatch.Timer() as t: + cProfile.runctx( + "ret = entry_point(*args, **kwargs)", + globals(), + locals(), + None, + SortKey.CUMULATIVE, + ) + else: + with stopwatch.Timer() as t: + ret = entry_point(*args, **kwargs) + + logger.debug('%s (program entry point) returned %s.', entry_descr, ret) + + if config.config['trace_memory']: + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.statistics('lineno') + print() + print("--trace_memory's top 10 memory using files:") + for stat in top_stats[:10]: + print(stat) + + if config.config['dump_all_objects']: + dump_all_objects() + + if config.config['audit_import_events']: + if IMPORT_INTERCEPTOR is not None: + print(IMPORT_INTERCEPTOR.tree) + + walltime = t() + (utime, stime, cutime, cstime, elapsed_time) = os.times() + logger.debug( + '\n' + 'user: %.4fs\n' + 'system: %.4fs\n' + 'child user: %.4fs\n' + 'child system: %.4fs\n' + 'machine uptime: %.4fs\n' + 'walltime: %.4fs', + utime, + stime, + cutime, + cstime, + elapsed_time, + walltime, + ) + + # If it doesn't return cleanly, call attention to the return value. + if ret is not None and ret != 0: + logger.error('Exit %s', ret) + else: + logger.debug('Exit %s', ret) + sys.exit(ret) + + return initialize_wrapper diff --git a/src/pyutils/collectionz/__init__.py b/src/pyutils/collectionz/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/collectionz/bidict.py b/src/pyutils/collectionz/bidict.py new file mode 100644 index 0000000..000fdb3 --- /dev/null +++ b/src/pyutils/collectionz/bidict.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A bidirectional dictionary.""" + + +class BiDict(dict): + def __init__(self, *args, **kwargs): + """ + A class that stores both a Mapping between keys and values and + also the inverse mapping between values and their keys to + allow for efficient lookups in either direction. Because it + is possible to have several keys with the same value, using + the inverse map returns a sequence of keys. 
+ + >>> d = BiDict() + >>> d['a'] = 1 + >>> d['b'] = 2 + >>> d['c'] = 2 + >>> d['a'] + 1 + >>> d.inverse[1] + ['a'] + >>> d.inverse[2] + ['b', 'c'] + >>> len(d) + 3 + >>> del d['c'] + >>> len(d) + 2 + >>> d.inverse[2] + ['b'] + + """ + super().__init__(*args, **kwargs) + self.inverse = {} + for key, value in self.items(): + self.inverse.setdefault(value, []).append(key) + + def __setitem__(self, key, value): + if key in self: + old_value = self[key] + self.inverse[old_value].remove(key) + super().__setitem__(key, value) + self.inverse.setdefault(value, []).append(key) + + def __delitem__(self, key): + value = self[key] + self.inverse.setdefault(value, []).remove(key) + if value in self.inverse and not self.inverse[value]: + del self.inverse[value] + super().__delitem__(key) diff --git a/src/pyutils/collectionz/bst.py b/src/pyutils/collectionz/bst.py new file mode 100644 index 0000000..52f722c --- /dev/null +++ b/src/pyutils/collectionz/bst.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A binary search tree.""" + +from typing import Any, Generator, List, Optional + + +class Node(object): + def __init__(self, value: Any) -> None: + """ + Note: value can be anything as long as it is comparable. + Check out @functools.total_ordering. + """ + self.left: Optional[Node] = None + self.right: Optional[Node] = None + self.value = value + + +class BinarySearchTree(object): + def __init__(self): + self.root = None + self.count = 0 + self.traverse = None + + def get_root(self) -> Optional[Node]: + return self.root + + def insert(self, value: Any): + """ + Insert something into the tree. + + >>> t = BinarySearchTree() + >>> t.insert(10) + >>> t.insert(20) + >>> t.insert(5) + >>> len(t) + 3 + + >>> t.get_root().value + 10 + + """ + if self.root is None: + self.root = Node(value) + self.count = 1 + else: + self._insert(value, self.root) + + def _insert(self, value: Any, node: Node): + """Insertion helper""" + if value < node.value: + if node.left is not None: + self._insert(value, node.left) + else: + node.left = Node(value) + self.count += 1 + else: + if node.right is not None: + self._insert(value, node.right) + else: + node.right = Node(value) + self.count += 1 + + def __getitem__(self, value: Any) -> Optional[Node]: + """ + Find an item in the tree and return its Node. Returns + None if the item is not in the tree. + + >>> t = BinarySearchTree() + >>> t[99] + + >>> t.insert(10) + >>> t.insert(20) + >>> t.insert(5) + >>> t[10].value + 10 + + >>> t[99] + + """ + if self.root is not None: + return self._find(value, self.root) + return None + + def _find(self, value: Any, node: Node) -> Optional[Node]: + """Find helper""" + if value == node.value: + return node + elif value < node.value and node.left is not None: + return self._find(value, node.left) + elif value > node.value and node.right is not None: + return self._find(value, node.right) + return None + + def _parent_path( + self, current: Optional[Node], target: Node + ) -> List[Optional[Node]]: + if current is None: + return [None] + ret: List[Optional[Node]] = [current] + if target.value == current.value: + return ret + elif target.value < current.value: + ret.extend(self._parent_path(current.left, target)) + return ret + else: + assert target.value > current.value + ret.extend(self._parent_path(current.right, target)) + return ret + + def parent_path(self, node: Node) -> List[Optional[Node]]: + """Return a list of nodes representing the path from + the tree's root to the node argument. 
If the node does + not exist in the tree for some reason, the last element + on the path will be None but the path will indicate the + ancestor path of that node were it inserted. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(4) + >>> t.insert(88) + >>> t + 50 + ├──25 + │ ├──12 + │ │ └──4 + │ └──33 + └──75 + └──88 + + >>> n = t[4] + >>> for x in t.parent_path(n): + ... print(x.value) + 50 + 25 + 12 + 4 + + >>> del t[4] + >>> for x in t.parent_path(n): + ... if x is not None: + ... print(x.value) + ... else: + ... print(x) + 50 + 25 + 12 + None + + """ + return self._parent_path(self.root, node) + + def __delitem__(self, value: Any) -> bool: + """ + Delete an item from the tree and preserve the BST property. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(85) + >>> t + 50 + ├──25 + │ └──22 + │ └──13 + └──75 + ├──66 + └──85 + + >>> for value in t.iterate_inorder(): + ... print(value) + 13 + 22 + 25 + 50 + 66 + 75 + 85 + + >>> del t[22] # Note: bool result is discarded + + >>> for value in t.iterate_inorder(): + ... print(value) + 13 + 25 + 50 + 66 + 75 + 85 + + >>> t.__delitem__(13) + True + >>> for value in t.iterate_inorder(): + ... print(value) + 25 + 50 + 66 + 75 + 85 + + >>> t.__delitem__(75) + True + >>> for value in t.iterate_inorder(): + ... print(value) + 25 + 50 + 66 + 85 + >>> t + 50 + ├──25 + └──85 + └──66 + + >>> t.__delitem__(99) + False + + """ + if self.root is not None: + ret = self._delete(value, None, self.root) + if ret: + self.count -= 1 + if self.count == 0: + self.root = None + return ret + return False + + def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool: + """Delete helper""" + if node.value == value: + # Deleting a leaf node + if node.left is None and node.right is None: + if parent is not None: + if parent.left == node: + parent.left = None + else: + assert parent.right == node + parent.right = None + return True + + # Node only has a right. + elif node.left is None: + assert node.right is not None + if parent is not None: + if parent.left == node: + parent.left = node.right + else: + assert parent.right == node + parent.right = node.right + return True + + # Node only has a left. + elif node.right is None: + assert node.left is not None + if parent is not None: + if parent.left == node: + parent.left = node.left + else: + assert parent.right == node + parent.right = node.left + return True + + # Node has both a left and right. + else: + assert node.left is not None and node.right is not None + descendent = node.right + while descendent.left is not None: + descendent = descendent.left + node.value = descendent.value + return self._delete(node.value, node, node.right) + elif value < node.value and node.left is not None: + return self._delete(value, node, node.left) + elif value > node.value and node.right is not None: + return self._delete(value, node, node.right) + return False + + def __len__(self): + """ + Returns the count of items in the tree. 
+
+        >>> t = BinarySearchTree()
+        >>> len(t)
+        0
+        >>> t.insert(50)
+        >>> len(t)
+        1
+        >>> t.__delitem__(50)
+        True
+        >>> len(t)
+        0
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(85)
+        >>> len(t)
+        6
+
+        """
+        return self.count
+
+    def __contains__(self, value: Any) -> bool:
+        """
+        Returns True if the item is in the tree; False otherwise.
+
+        """
+        return self.__getitem__(value) is not None
+
+    def _iterate_preorder(self, node: Node):
+        yield node.value
+        if node.left is not None:
+            yield from self._iterate_preorder(node.left)
+        if node.right is not None:
+            yield from self._iterate_preorder(node.right)
+
+    def _iterate_inorder(self, node: Node):
+        if node.left is not None:
+            yield from self._iterate_inorder(node.left)
+        yield node.value
+        if node.right is not None:
+            yield from self._iterate_inorder(node.right)
+
+    def _iterate_postorder(self, node: Node):
+        if node.left is not None:
+            yield from self._iterate_postorder(node.left)
+        if node.right is not None:
+            yield from self._iterate_postorder(node.right)
+        yield node.value
+
+    def iterate_preorder(self):
+        """
+        Yield the tree's items in a preorder traversal sequence.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+
+        >>> for value in t.iterate_preorder():
+        ...     print(value)
+        50
+        25
+        22
+        13
+        75
+        66
+
+        """
+        if self.root is not None:
+            yield from self._iterate_preorder(self.root)
+
+    def iterate_inorder(self):
+        """
+        Yield the tree's items in an inorder traversal sequence.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+        >>> t.insert(24)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     ├──13
+        │     └──24
+        └──75
+           └──66
+
+        >>> for value in t.iterate_inorder():
+        ...     print(value)
+        13
+        22
+        24
+        25
+        50
+        66
+        75
+
+        """
+        if self.root is not None:
+            yield from self._iterate_inorder(self.root)
+
+    def iterate_postorder(self):
+        """
+        Yield the tree's items in a postorder traversal sequence.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+
+        >>> for value in t.iterate_postorder():
+        ...     print(value)
+        13
+        22
+        25
+        66
+        75
+        50
+
+        """
+        if self.root is not None:
+            yield from self._iterate_postorder(self.root)
+
+    def _iterate_leaves(self, node: Node):
+        if node.left is not None:
+            yield from self._iterate_leaves(node.left)
+        if node.right is not None:
+            yield from self._iterate_leaves(node.right)
+        if node.left is None and node.right is None:
+            yield node.value
+
+    def iterate_leaves(self):
+        """
+        Iterate only the leaf nodes in the tree.
+
+        >>> t = BinarySearchTree()
+        >>> t.insert(50)
+        >>> t.insert(75)
+        >>> t.insert(25)
+        >>> t.insert(66)
+        >>> t.insert(22)
+        >>> t.insert(13)
+
+        >>> for value in t.iterate_leaves():
+        ...     print(value)
+        13
+        66
+
+        """
+        if self.root is not None:
+            yield from self._iterate_leaves(self.root)
+
+    def _iterate_by_depth(self, node: Node, depth: int):
+        if depth == 0:
+            yield node.value
+        else:
+            assert depth > 0
+            if node.left is not None:
+                yield from self._iterate_by_depth(node.left, depth - 1)
+            if node.right is not None:
+                yield from self._iterate_by_depth(node.right, depth - 1)
+
+    def iterate_nodes_by_depth(self, depth: int) -> Generator[Any, None, None]:
+        """
+        Iterate only the nodes at the given depth in the tree (the root
+        is at depth 0).
+ + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + + >>> for value in t.iterate_nodes_by_depth(2): + ... print(value) + 22 + 66 + + >>> for value in t.iterate_nodes_by_depth(3): + ... print(value) + 13 + + """ + if self.root is not None: + yield from self._iterate_by_depth(self.root, depth) + + def get_next_node(self, node: Node) -> Node: + """ + Given a tree node, get the next greater node in the tree. + + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(75) + >>> t.insert(25) + >>> t.insert(66) + >>> t.insert(22) + >>> t.insert(13) + >>> t.insert(23) + >>> t + 50 + ├──25 + │ └──22 + │ ├──13 + │ └──23 + └──75 + └──66 + + >>> n = t[23] + >>> t.get_next_node(n).value + 25 + + >>> n = t[50] + >>> t.get_next_node(n).value + 66 + + """ + if node.right is not None: + x = node.right + while x.left is not None: + x = x.left + return x + + path = self.parent_path(node) + assert path[-1] is not None + assert path[-1] == node + path = path[:-1] + path.reverse() + for ancestor in path: + assert ancestor is not None + if node != ancestor.right: + return ancestor + node = ancestor + raise Exception() + + def _depth(self, node: Node, sofar: int) -> int: + depth_left = sofar + 1 + depth_right = sofar + 1 + if node.left is not None: + depth_left = self._depth(node.left, sofar + 1) + if node.right is not None: + depth_right = self._depth(node.right, sofar + 1) + return max(depth_left, depth_right) + + def depth(self): + """ + Returns the max height (depth) of the tree in plies (edge distance + from root). + + >>> t = BinarySearchTree() + >>> t.depth() + 0 + + >>> t.insert(50) + >>> t.depth() + 1 + + >>> t.insert(65) + >>> t.depth() + 2 + + >>> t.insert(33) + >>> t.depth() + 2 + + >>> t.insert(2) + >>> t.insert(1) + >>> t.depth() + 4 + + """ + if self.root is None: + return 0 + return self._depth(self.root, 0) + + def height(self): + return self.depth() + + def repr_traverse( + self, + padding: str, + pointer: str, + node: Optional[Node], + has_right_sibling: bool, + ) -> str: + if node is not None: + viz = f'\n{padding}{pointer}{node.value}' + if has_right_sibling: + padding += "│ " + else: + padding += ' ' + + pointer_right = "└──" + if node.right is not None: + pointer_left = "├──" + else: + pointer_left = "└──" + + viz += self.repr_traverse( + padding, pointer_left, node.left, node.right is not None + ) + viz += self.repr_traverse(padding, pointer_right, node.right, False) + return viz + return "" + + def __repr__(self): + """ + Draw the tree in ASCII. 
+ + >>> t = BinarySearchTree() + >>> t.insert(50) + >>> t.insert(25) + >>> t.insert(75) + >>> t.insert(12) + >>> t.insert(33) + >>> t.insert(88) + >>> t.insert(55) + >>> t + 50 + ├──25 + │ ├──12 + │ └──33 + └──75 + ├──55 + └──88 + """ + if self.root is None: + return "" + + ret = f'{self.root.value}' + pointer_right = "└──" + if self.root.right is None: + pointer_left = "└──" + else: + pointer_left = "├──" + + ret += self.repr_traverse( + '', pointer_left, self.root.left, self.root.left is not None + ) + ret += self.repr_traverse('', pointer_right, self.root.right, False) + return ret diff --git a/src/pyutils/collectionz/shared_dict.py b/src/pyutils/collectionz/shared_dict.py new file mode 100644 index 0000000..ef74f93 --- /dev/null +++ b/src/pyutils/collectionz/shared_dict.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +""" +The MIT License (MIT) + +Copyright (c) 2020 LuizaLabs +Additions/Modifications Copyright (c) 2022 Scott Gasch + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +This class is based on https://github.com/luizalabs/shared-memory-dict. +For details about what is preserved from the original and what was changed +by Scott, see NOTICE at the root of this module. +""" + +import pickle +from contextlib import contextmanager +from multiprocessing import RLock, shared_memory +from typing import ( + Any, + Dict, + Hashable, + ItemsView, + Iterator, + KeysView, + Optional, + Tuple, + ValuesView, +) + + +class PickleSerializer: + """A serializer that uses pickling. Used to read/write bytes in the shared + memory region and interpret them as a dict.""" + + def dumps(self, obj: Dict[Hashable, Any]) -> bytes: + try: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + except pickle.PicklingError as e: + raise Exception from e + + def loads(self, data: bytes) -> Dict[Hashable, Any]: + try: + return pickle.loads(data) + except pickle.UnpicklingError as e: + raise Exception from e + + +# TODOs: profile the serializers and figure out the fastest one. Can +# we use a ChainMap to avoid the constant de/re-serialization of the +# whole thing? + + +class SharedDict(object): + """This class emulates the dict container but uses a + Multiprocessing.SharedMemory region to back the dict such that it + can be read and written by multiple independent processes at the + same time. Because it constantly de/re-serializes the dict, it is + much slower than a normal dict. 
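+
+    A rough usage sketch (the name and size below are illustrative, not
+    defaults; see __init__ for create vs. attach semantics)::
+
+        d = SharedDict('settings', 4096)   # first process creates it
+        d['key'] = 'value'
+        e = SharedDict('settings')         # other processes attach by name
+        assert e['key'] == 'value'
+        e.close()                          # every process closes its mapping
+        d.cleanup()                        # exactly one process unlinks it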
+ + """ + + NULL_BYTE = b'\x00' + LOCK = RLock() + + def __init__( + self, + name: Optional[str] = None, + size_bytes: Optional[int] = None, + ) -> None: + """ + Creates or attaches a shared dictionary back by a SharedMemory buffer. + For create semantics, a unique name (string) and a max dictionary size + (expressed in bytes) must be provided. For attach semantics, these are + ignored. + + The first process that creates the SharedDict is responsible for + (optionally) naming it and deciding the max size (in bytes) that + it may be. It does this via args to the c'tor. + + Subsequent processes may safely omit name and size args. + + """ + assert size_bytes is None or size_bytes > 0 + self._serializer = PickleSerializer() + self.shared_memory = self._get_or_create_memory_block(name, size_bytes) + self._ensure_memory_initialization() + self.name = self.shared_memory.name + + def get_name(self): + """Returns the name of the shared memory buffer backing the dict.""" + return self.name + + def _get_or_create_memory_block( + self, + name: Optional[str] = None, + size_bytes: Optional[int] = None, + ) -> shared_memory.SharedMemory: + try: + return shared_memory.SharedMemory(name=name) + except FileNotFoundError: + assert size_bytes is not None + return shared_memory.SharedMemory(name=name, create=True, size=size_bytes) + + def _ensure_memory_initialization(self): + with SharedDict.LOCK: + memory_is_empty = ( + bytes(self.shared_memory.buf).split(SharedDict.NULL_BYTE, 1)[0] == b'' + ) + if memory_is_empty: + self.clear() + + def _write_memory(self, db: Dict[Hashable, Any]) -> None: + data = self._serializer.dumps(db) + with SharedDict.LOCK: + try: + self.shared_memory.buf[: len(data)] = data + except ValueError as e: + raise ValueError("exceeds available storage") from e + + def _read_memory(self) -> Dict[Hashable, Any]: + with SharedDict.LOCK: + return self._serializer.loads(self.shared_memory.buf.tobytes()) + + @contextmanager + def _modify_dict(self): + with SharedDict.LOCK: + db = self._read_memory() + yield db + self._write_memory(db) + + def close(self) -> None: + """Unmap the shared dict and memory behind it from this + process. Called by automatically __del__""" + if not hasattr(self, 'shared_memory'): + return + self.shared_memory.close() + + def cleanup(self) -> None: + """Unlink the shared dict and memory behind it. Only the last process should + invoke this. 
Not called automatically.""" + if not hasattr(self, 'shared_memory'): + return + with SharedDict.LOCK: + self.shared_memory.unlink() + + def clear(self) -> None: + """Clear the dict.""" + self._write_memory({}) + + def copy(self) -> Dict[Hashable, Any]: + """Returns a shallow copy of the dict.""" + return self._read_memory() + + def __getitem__(self, key: Hashable) -> Any: + return self._read_memory()[key] + + def __setitem__(self, key: Hashable, value: Any) -> None: + with self._modify_dict() as db: + db[key] = value + + def __len__(self) -> int: + return len(self._read_memory()) + + def __delitem__(self, key: Hashable) -> None: + with self._modify_dict() as db: + del db[key] + + def __iter__(self) -> Iterator[Hashable]: + return iter(self._read_memory()) + + def __reversed__(self) -> Iterator[Hashable]: + return reversed(self._read_memory()) + + def __del__(self) -> None: + self.close() + + def __contains__(self, key: Hashable) -> bool: + return key in self._read_memory() + + def __eq__(self, other: Any) -> bool: + return self._read_memory() == other + + def __ne__(self, other: Any) -> bool: + return self._read_memory() != other + + def __str__(self) -> str: + return str(self._read_memory()) + + def __repr__(self) -> str: + return repr(self._read_memory()) + + def get(self, key: str, default: Optional[Any] = None) -> Any: + """Gets the value associated with key or a default.""" + return self._read_memory().get(key, default) + + def keys(self) -> KeysView[Hashable]: + return self._read_memory().keys() + + def values(self) -> ValuesView[Any]: + return self._read_memory().values() + + def items(self) -> ItemsView[Hashable, Any]: + return self._read_memory().items() + + def popitem(self) -> Tuple[Hashable, Any]: + """Remove and return the last added item.""" + with self._modify_dict() as db: + return db.popitem() + + def pop(self, key: Hashable, default: Optional[Any] = None) -> Any: + """Remove and return the value associated with key or a default""" + with self._modify_dict() as db: + if default is None: + return db.pop(key) + return db.pop(key, default) + + def update(self, other=(), /, **kwds): + with self._modify_dict() as db: + db.update(other, **kwds) + + def setdefault(self, key: Hashable, default: Optional[Any] = None): + with self._modify_dict() as db: + return db.setdefault(key, default) diff --git a/src/pyutils/collectionz/trie.py b/src/pyutils/collectionz/trie.py new file mode 100644 index 0000000..607f531 --- /dev/null +++ b/src/pyutils/collectionz/trie.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""This is a Trie class, see: https://en.wikipedia.org/wiki/Trie. + +It attempts to follow Pythonic container patterns. See doctests +for examples. + +""" + +import logging +from typing import Any, Generator, Sequence + +logger = logging.getLogger(__name__) + + +class Trie(object): + """ + This is a Trie class, see: https://en.wikipedia.org/wiki/Trie. + + It attempts to follow Pythonic container patterns. See doctests + for examples. + + """ + + def __init__(self): + self.root = {} + self.end = "~END~" + self.length = 0 + self.viz = '' + self.content_generator: Generator[str] = None + + def insert(self, item: Sequence[Any]): + """ + Insert an item. 
+ + >>> t = Trie() + >>> t.insert('test') + >>> t.__contains__('test') + True + + """ + current = self.root + for child in item: + current = current.setdefault(child, {}) + current[self.end] = self.end + self.length += 1 + + def __contains__(self, item: Sequence[Any]) -> bool: + """ + Check whether an item is in the Trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.__contains__('test') + True + >>> t.__contains__('testing') + False + >>> 'test' in t + True + + """ + current = self.__traverse__(item) + if current is None: + return False + else: + return self.end in current + + def contains_prefix(self, item: Sequence[Any]): + """ + Check whether a prefix is in the Trie. The prefix may or may not + be a full item. + + >>> t = Trie() + >>> t.insert('testicle') + >>> t.contains_prefix('test') + True + >>> t.contains_prefix('testicle') + True + >>> t.contains_prefix('tessel') + False + + """ + current = self.__traverse__(item) + return current is not None + + def __traverse__(self, item: Sequence[Any]): + current = self.root + for child in item: + if child in current: + current = current[child] + else: + return None + return current + + def __getitem__(self, item: Sequence[Any]): + """Given an item, return its Trie node which contains all + of the successor (child) node pointers. If the item is not + a node in the Trie, raise a KeyError. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('testicle') + >>> t.insert('tessera') + >>> t.insert('tesack') + >>> t['tes'] + {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}} + + """ + ret = self.__traverse__(item) + if ret is None: + raise KeyError(f"Node '{item}' is not in the trie") + return ret + + def delete_recursively(self, node, item: Sequence[Any]) -> bool: + if len(item) == 1: + del node[item] + if len(node) == 0 and node is not self.root: + del node + return True + else: + return False + else: + car = item[0] + cdr = item[1:] + lower = node[car] + if self.delete_recursively(lower, cdr): + return self.delete_recursively(node, car) + return False + + def __delitem__(self, item: Sequence[Any]): + """ + Delete an item from the Trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tess') + >>> t.insert('tessel') + >>> len(t) + 3 + >>> t.root + {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}} + >>> t.__delitem__('test') + >>> len(t) + 2 + >>> t.root + {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}} + >>> for x in t: + ... print(x) + tess + tessel + >>> t.__delitem__('tessel') + >>> len(t) + 1 + >>> t.root + {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}} + >>> for x in t: + ... print(x) + tess + >>> t.__delitem__('tess') + >>> len(t) + 0 + >>> t.root + {} + >>> t.insert('testy') + >>> len(t) + 1 + + """ + if item not in self: + raise KeyError(f"Node '{item}' is not in the trie") + self.delete_recursively(self.root, item) + self.length -= 1 + + def __len__(self): + """ + Returns a count of the Trie's item population. + + >>> t = Trie() + >>> len(t) + 0 + >>> t.insert('test') + >>> len(t) + 1 + >>> t.insert('testicle') + >>> len(t) + 2 + + """ + return self.length + + def __iter__(self): + self.content_generator = self.generate_recursively(self.root, '') + return self + + def generate_recursively(self, node, path: Sequence[Any]): + """ + Generate items in the trie one by one. 
+ + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tickle') + >>> for item in t.generate_recursively(t.root, ''): + ... print(item) + test + tickle + + """ + for child in node: + if child == self.end: + yield path + else: + yield from self.generate_recursively(node[child], path + child) + + def __next__(self): + """ + Iterate through the contents of the trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tickle') + >>> for item in t: + ... print(item) + test + tickle + + """ + ret = next(self.content_generator) + if ret is not None: + return ret + raise StopIteration + + def successors(self, item: Sequence[Any]): + """ + Return a list of the successors of an item. + + >>> t = Trie() + >>> t.insert('what') + >>> t.insert('who') + >>> t.insert('when') + >>> t.successors('wh') + ['a', 'o', 'e'] + + >>> u = Trie() + >>> u.insert(['this', 'is', 'a', 'test']) + >>> u.insert(['this', 'is', 'a', 'robbery']) + >>> u.insert(['this', 'is', 'a', 'walrus']) + >>> u.successors(['this', 'is', 'a']) + ['test', 'robbery', 'walrus'] + + """ + node = self.__traverse__(item) + if node is None: + return None + return [x for x in node if x != self.end] + + def repr_fancy( + self, + padding: str, + pointer: str, + node: Any, + has_sibling: bool, + ): + if node is None: + return '' + if node is not self.root: + ret = f'\n{padding}{pointer}' + if has_sibling: + padding += '│ ' + else: + padding += ' ' + else: + ret = f'{pointer}' + + child_count = 0 + for child in node: + if child != self.end: + child_count += 1 + + for child in node: + if child != self.end: + if child_count > 1: + pointer = "├──" + has_sibling = True + else: + pointer = "└──" + has_sibling = False + pointer += f'{child}' + child_count -= 1 + ret += self.repr_fancy(padding, pointer, node[child], has_sibling) + return ret + + def repr_brief(self, node, delimiter): + """ + A friendly string representation of the contents of the Trie. + + >>> t = Trie() + >>> t.insert([10, 0, 0, 1]) + >>> t.insert([10, 0, 0, 2]) + >>> t.insert([10, 10, 10, 1]) + >>> t.insert([10, 10, 10, 2]) + >>> t.repr_brief(t.root, '.') + '10.[0.0.[1,2],10.10.[1,2]]' + + """ + child_count = 0 + my_rep = '' + for child in node: + if child != self.end: + child_count += 1 + child_rep = self.repr_brief(node[child], delimiter) + if len(child_rep) > 0: + my_rep += str(child) + delimiter + child_rep + "," + else: + my_rep += str(child) + "," + if len(my_rep) > 1: + my_rep = my_rep[:-1] + if child_count > 1: + my_rep = f'[{my_rep}]' + return my_rep + + def __repr__(self): + """ + A friendly string representation of the contents of the Trie. Under + the covers uses repr_fancy. 
+
+        >>> t = Trie()
+        >>> t.insert([10, 0, 0, 1])
+        >>> t.insert([10, 0, 0, 2])
+        >>> t.insert([10, 10, 10, 1])
+        >>> t.insert([10, 10, 10, 2])
+        >>> print(t)
+        *
+        └──10
+           ├──0
+           │  └──0
+           │     ├──1
+           │     └──2
+           └──10
+              └──10
+                 ├──1
+                 └──2
+
+        """
+        return self.repr_fancy('', '*', self.root, False)
diff --git a/src/pyutils/compress/__init__.py b/src/pyutils/compress/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/pyutils/compress/letter_compress.py b/src/pyutils/compress/letter_compress.py
new file mode 100644
index 0000000..c631803
--- /dev/null
+++ b/src/pyutils/compress/letter_compress.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""A simple toy compression helper for lowercase ascii text."""
+
+import bitstring
+
+from pyutils.collectionz.bidict import BiDict
+
+special_characters = BiDict(
+    {
+        ' ': 27,
+        '.': 28,
+        ',': 29,
+        "-": 30,
+        '"': 31,
+    }
+)
+
+
+def compress(uncompressed: str) -> bytes:
+    """Compress a word sequence into a stream of bytes. The compressed
+    form will be 5/8 the size of the original. Input characters must be
+    lowercase letters or special_characters (above).
+
+    >>> import binascii
+    >>> binascii.hexlify(compress('this is a test'))
+    b'a2133da67b0ee859d0'
+
+    >>> binascii.hexlify(compress('scot'))
+    b'98df40'
+
+    >>> binascii.hexlify(compress('scott')) # Note the last byte
+    b'98df4a00'
+
+    """
+    compressed = bitstring.BitArray()
+    for letter in uncompressed:
+        if 'a' <= letter <= 'z':
+            bits = ord(letter) - ord('a') + 1  # 1..26
+        else:
+            if letter not in special_characters:
+                raise Exception(
+                    f'"{uncompressed}" contains uncompressable char="{letter}"'
+                )
+            bits = special_characters[letter]
+        compressed.append(f"uint:5={bits}")
+    while len(compressed) % 8 != 0:
+        compressed.append("uint:1=0")
+    return compressed.bytes
+
+
+def decompress(kompressed: bytes) -> str:
+    """
+    Decompress a previously compressed stream of bytes back into
+    its original form.
+
+    >>> import binascii
+    >>> decompress(binascii.unhexlify(b'a2133da67b0ee859d0'))
+    'this is a test'
+
+    >>> decompress(binascii.unhexlify(b'98df4a00'))
+    'scott'
+
+    """
+    decompressed = ''
+    compressed = bitstring.BitArray(kompressed)
+
+    # There are compressed messages that legitimately end with the
+    # byte 0x00. The message "scott" is an example; compressed it is
+    # 0x98df4a00. It's 5 characters long which means there are 5 x 5
+    # bits of compressed info (25 bits, just over 3 bytes). The last
+    # (25th) bit in the stream happens to be a zero. The compress code
+    # padded out the compressed message by adding seven more zeros to
+    # complete the partial 4th byte. In the 4th byte, however, one
+    # bit is information and seven are padding.
+    #
+    # It's likely that this API's client code may treat a zero byte as
+    # a termination character and not regard it as a legitimate part
+    # of the message. This is a bug in that client code, to be clear.
+    #
+    # However, it's a bug we can work around:
+    #
+    # Here, I'm appending an extra 0x00 byte to the compressed message
+    # passed in. If the client code dropped the last 0x00 byte (and,
+    # with it, some of the legitimate message bits) by treating it as
+    # a termination mark, this 0x00 will replace it (and the missing
+    # message bits). If the client code didn't drop the last 0x00 (or
+    # if the compressed message didn't end in 0x00), adding an extra
+    # 0x00 is a no-op because the codepoint 0b00000 is a "stop" message
+    # so we'll ignore the extras.
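+    #
+    # (Illustrative walk-through, using the docstring example: 'scott'
+    # is 5 letters x 5 bits = 25 bits, padded to 32 bits = 4 bytes, so
+    # its last byte is 0x00. A client that dropped that "terminator"
+    # byte would hand us only 3 bytes; the 0x00 appended below stands
+    # in for the dropped byte.)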
+ compressed.append("uint:8=0") + + for chunk in compressed.cut(5): + chunk = chunk.uint + if chunk == 0: + break + elif 1 <= chunk <= 26: + letter = chr(chunk - 1 + ord('a')) + else: + letter = special_characters.inverse[chunk][0] + decompressed += letter + return decompressed + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/config.py b/src/pyutils/config.py new file mode 100644 index 0000000..e173e71 --- /dev/null +++ b/src/pyutils/config.py @@ -0,0 +1,749 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Global configuration driven by commandline arguments, environment variables, +saved configuration files, and zookeeper-based dynamic configurations. This +works across several modules. + +Example usage: + + In your file.py:: + + from pyutils import config + + parser = config.add_commandline_args( + "Module", + "Args related to module doing the thing.", + ) + parser.add_argument( + "--module_do_the_thing", + type=bool, + default=True, + help="Should the module do the thing?" + ) + + In your main.py:: + + from pyutils import config + + parser = config.add_commandline_args( + "Main", + "A program that does the thing.", + ) + parser.add_argument( + "--dry_run", + type=bool, + default=False, + help="Should we really do the thing?" + ) + + def main() -> None: + config.parse() # Very important, this must be invoked! + + If you set this up and remember to invoke config.parse(), all commandline + arguments will play nicely together. This is done automatically for you + if you're using the :meth:`bootstrap.initialize` decorator on + your program's entry point. See :meth:`python_modules.bootstrap.initialize` + for more details.:: + + from pyutils import bootstrap + + @bootstrap.initialize + def main(): + whatever + + if __name__ == '__main__': + main() + + Either way, you'll get this behavior from the commandline:: + + % main.py -h + usage: main.py [-h] + [--module_do_the_thing MODULE_DO_THE_THING] + [--dry_run DRY_RUN] + + Module: + Args related to module doing the thing. + + --module_do_the_thing MODULE_DO_THE_THING + Should the module do the thing? + + Main: + A program that does the thing + + --dry_run + Should we really do the thing? + + Arguments themselves should be accessed via + :code:`config.config['arg_name']`. e.g.:: + + if not config.config['dry_run']: + module.do_the_thing() +""" + +import argparse +import logging +import os +import pprint +import re +import sys +from typing import Any, Dict, List, Optional, Tuple + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +# Make a copy of the original program arguments immediately upon module load. +PROGRAM_NAME: str = os.path.basename(sys.argv[0]) +ORIG_ARGV: List[str] = sys.argv.copy() + + +class OptionalRawFormatter(argparse.HelpFormatter): + """This formatter has the same bahavior as the normal argparse text + formatter except when the help text of an argument begins with + "RAW|". In that case, the line breaks are preserved and the text + is not wrapped. + + Use this, for example, when you need the helptext of an argument + to have its spacing preserved exactly, e.g.:: + + args.add_argument( + '--mode', + type=str, + default='PLAY', + choices=['CHEAT', 'AUTOPLAY', 'SELFTEST', 'PRECOMPUTE', 'PLAY'], + metavar='MODE', + help='''RAW|Our mode of operation. One of: + + PLAY = play wordle with me! Pick a random solution or + specify a solution with --template. 
+
+            CHEAT = given a --template and, optionally, --letters_in_word
+            and/or --letters_to_avoid, return the best guess word;
+
+            AUTOPLAY = given a complete word in --template, guess it step
+            by step showing work;
+
+            SELFTEST = autoplay every possible solution keeping track of
+            wins/losses and average number of guesses;
+
+            PRECOMPUTE = populate hash table with optimal guesses.
+            ''',
+        )
+    """
+
+    def _split_lines(self, text, width):
+        if text.startswith('RAW|'):
+            return text[4:].splitlines()
+        return argparse.HelpFormatter._split_lines(self, text, width)
+
+
+# A global argparser that we will collect arguments in. Each module (including
+# us) will add arguments to a separate argument group.
+ARGS = argparse.ArgumentParser(
+    description=None,
+    formatter_class=OptionalRawFormatter,
+    fromfile_prefix_chars="@",
+    epilog=f'{PROGRAM_NAME} uses config.py ({__file__}) for global, cross-module configuration setup and parsing.',
+    # I don't fully understand why but when loaded by sphinx sometimes
+    # the same module is loaded many times causing any arguments it
+    # registers via module-level code to be redefined. Work around
+    # this iff the program is 'sphinx-build'
+    conflict_handler='resolve' if PROGRAM_NAME == 'sphinx-build' else 'error',
+)
+
+# Arguments specific to config.py. Other users should get their own group by
+# invoking config.add_commandline_args.
+GROUP = ARGS.add_argument_group(
+    f'Global Config ({__file__})',
+    'Args that control the global config itself; how meta!',
+)
+GROUP.add_argument(
+    '--config_loadfile',
+    metavar='FILENAME',
+    default=None,
+    help='Config file (populated via --config_savefile) from which to read args in lieu of or in addition to those passed via the commandline. Note that if the given path begins with "zk:" then it is interpreted as a zookeeper path instead of as a filesystem path. When loading config from zookeeper, any argument with the string "dynamic" in the name (e.g. --module_dynamic_url) may be modified at runtime by changes made to zookeeper (using --config_savefile=zk:path). You should therefore either write your code to handle dynamic argument changes or avoid naming arguments "dynamic" if you use zookeeper configuration paths.',
+)
+GROUP.add_argument(
+    '--config_dump',
+    default=False,
+    action='store_true',
+    help='Display the global configuration (possibly derived from multiple sources) on STDERR at program startup time.',
+)
+GROUP.add_argument(
+    '--config_savefile',
+    type=str,
+    metavar='FILENAME',
+    default=None,
+    help='Populate a config file (compatible with --config_loadfile) with the given path for later use. If the given path begins with "zk:" it is interpreted as a zookeeper path instead of a filesystem path. When updating zookeeper-based configs, all running programs that read their configuration from zookeeper (via --config_loadfile=zk:path) at startup time will see their configuration dynamically updated; flags with "dynamic" in their names (e.g. --my_dynamic_flag) may have their values changed. You should therefore either write your code to handle dynamic argument changes or avoid naming arguments "dynamic" if you use zookeeper configuration paths.',
+)
+GROUP.add_argument(
+    '--config_rejects_unrecognized_arguments',
+    default=False,
+    action='store_true',
+    help='If present, config will raise an exception if it doesn\'t recognize an argument.
The default behavior is to ignore unknown arguments so as to allow interoperability with programs that want to use their own argparse calls to parse their own, separate commandline args.', +) +GROUP.add_argument( + '--config_exit_after_parse', + default=False, + action='store_true', + help='If present, halt the program after parsing config. Useful, for example, to write a --config_savefile and then terminate.', +) + + +class Config: + """ + Everything in the config module used to be module-level functions and + variables but it made the code ugly and harder to maintain. Now, this + class does the heavy lifting. We still rely on some globals, though: + + ARGS and GROUP to interface with argparse + PROGRAM_NAME stores argv[0] close to program invocation + ORIG_ARGV stores the original argv list close to program invocation + CONFIG and config: hold the (singleton) instance of this class. + + """ + + def __init__(self): + # Has our parse() method been invoked yet? + self.config_parse_called = False + + # A configuration dictionary that will contain parsed + # arguments. This is the data that is most interesting to our + # callers as it will hold the configuration result. + self.config: Dict[str, Any] = {} + + # Defer logging messages until later when logging has been + # initialized. + self.saved_messages: List[str] = [] + + # A zookeeper client that is lazily created so as to not incur + # the latency of connecting to zookeeper for programs that are + # not reading or writing their config data into zookeeper. + self.zk: Optional[Any] = None + + # Per known zk file, what is the max version we have seen? + self.max_version: Dict[str, int] = {} + + def __getitem__(self, key: str) -> Optional[Any]: + """If someone uses []'s on us, pass it onto self.config.""" + return self.config.get(key, None) + + def __setitem__(self, key: str, value: Any) -> None: + self.config[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.config + + def get(self, key: str, default: Any = None) -> Optional[Any]: + return self.config.get(key, default) + + @staticmethod + def add_commandline_args( + title: str, description: str = "" + ) -> argparse._ArgumentGroup: + """Create a new context for arguments and return a handle. + + Args: + title: A title for your module's commandline arguments group. + description: A helpful description of your module. + + Returns: + An argparse._ArgumentGroup to be populated by the caller. + """ + return ARGS.add_argument_group(title, description) + + @staticmethod + def overwrite_argparse_epilog(msg: str) -> None: + """Allows your code to override the default epilog created by + argparse. + + Args: + msg: The epilog message to substitute for the default. + """ + ARGS.epilog = msg + + @staticmethod + def is_flag_already_in_argv(var: str) -> bool: + """Returns true if a particular flag is passed on the commandline + and false otherwise. + + Args: + var: The flag to search for. + """ + for _ in sys.argv: + if var in _: + return True + return False + + @staticmethod + def print_usage() -> None: + """Prints the normal help usage message out.""" + ARGS.print_help() + + @staticmethod + def usage() -> str: + """ + Returns: + program usage help text as a string. + """ + return ARGS.format_usage() + + @staticmethod + def _reorder_arg_action_groups_before_help(entry_module: Optional[str]): + """Internal. Used to reorder the arguments before dumping out a + generated help string such that the main program's arguments come + last. 
+ + """ + reordered_action_groups = [] + for grp in ARGS._action_groups: + if entry_module is not None and entry_module in grp.title: # type: ignore + reordered_action_groups.append(grp) + elif PROGRAM_NAME in GROUP.title: # type: ignore + reordered_action_groups.append(grp) + else: + reordered_action_groups.insert(0, grp) + return reordered_action_groups + + @staticmethod + def _parse_arg_into_env(arg: str) -> Optional[Tuple[str, str, List[str]]]: + """Internal helper to parse commandling args into environment vars.""" + arg = arg.strip() + if not arg.startswith('['): + return None + arg = arg.strip('[') + if not arg.endswith(']'): + return None + arg = arg.strip(']') + + chunks = arg.split() + if len(chunks) > 1: + var = chunks[0] + else: + var = arg + + # Environment vars the same as flag names without + # the initial -'s and in UPPERCASE. + env = var.upper() + while env[0] == '-': + env = env[1:] + return var, env, chunks + + @staticmethod + def _to_bool(in_str: str) -> bool: + """ + Args: + in_str: the string to convert to boolean + + Returns: + A boolean equivalent of the original string based on its contents. + All conversion is case insensitive. A positive boolean (True) is + returned if the string value is any of the following: + + * "true" + * "t" + * "1" + * "yes" + * "y" + * "on" + + Otherwise False is returned. + + >>> to_bool('True') + True + + >>> to_bool('1') + True + + >>> to_bool('yes') + True + + >>> to_bool('no') + False + + >>> to_bool('huh?') + False + + >>> to_bool('on') + True + """ + return in_str.lower() in ("true", "1", "yes", "y", "t", "on") + + def _augment_sys_argv_from_environment_variables(self): + """Internal. Look at the system environment for variables that match + commandline arg names. This is done via some munging such that: + + :code:`--argument_to_match` + + ...is matched by: + + :code:`ARGUMENT_TO_MATCH` + + This allows users to set args via shell environment variables + in lieu of passing them on the cmdline. + + """ + usage_message = Config.usage() + optional = False + arg = '' + + # Foreach valid optional commandline option (chunk) generate + # its analogous environment variable. + for chunk in usage_message.split(): + if chunk[0] == '[': + optional = True + if optional: + arg += f'{chunk} ' + if chunk[-1] == ']': + optional = False + _ = Config._parse_arg_into_env(arg) + if _: + var, env, chunks = _ + if env in os.environ: + if not Config.is_flag_already_in_argv(var): + value = os.environ[env] + self.saved_messages.append( + f'Initialized from environment: {var} = {value}' + ) + if len(chunks) == 1 and Config._to_bool(value): + sys.argv.append(var) + elif len(chunks) > 1: + sys.argv.append(var) + sys.argv.append(value) + arg = '' + + def _process_dynamic_args(self, event): + """Invoked as a callback when a zk-based config changed.""" + + if not self.zk: + return + logger = logging.getLogger(__name__) + try: + contents, meta = self.zk.get(event.path, watch=self._process_dynamic_args) + logger.debug('Update for %s at version=%d.', event.path, meta.version) + logger.debug( + 'Max known version for %s is %d.', + event.path, + self.max_version.get(event.path, 0), + ) + except Exception as e: + raise Exception('Error reading data from zookeeper') from e + + # Make sure we process changes in order. 
+ if meta.version > self.max_version.get(event.path, 0): + self.max_version[event.path] = meta.version + contents = contents.decode() + temp_argv = [] + for arg in contents.split(): + + # Our rule is that arguments must contain the word + # 'dynamic' if we are going to allow them to change at + # runtime as a signal that the programmer is expecting + # this. + if 'dynamic' in arg: + temp_argv.append(arg) + logger.info("Updating %s from zookeeper async config change.", arg) + + if len(temp_argv) > 0: + old_argv = sys.argv + sys.argv = temp_argv + known, _ = ARGS.parse_known_args() + sys.argv = old_argv + self.config.update(vars(known)) + + def _read_config_from_zookeeper(self, zkpath: str) -> Optional[str]: + from pyutils import zookeeper + + if not zkpath.startswith('/config/'): + zkpath = '/config/' + zkpath + zkpath = re.sub(r'//+', '/', zkpath) + + try: + if self.zk is None: + self.zk = zookeeper.get_started_zk_client() + if not self.zk.exists(zkpath): + return None + + # Note: we're putting a watch on this config file. Our + # _process_dynamic_args routine will be called to reparse + # args when/if they change. + contents, meta = self.zk.get(zkpath, watch=self._process_dynamic_args) + contents = contents.decode() + self.saved_messages.append( + f'Setting {zkpath}\'s max_version to {meta.version}' + ) + self.max_version[zkpath] = meta.version + self.saved_messages.append(f'Read config from zookeeper {zkpath}.') + return contents + except Exception as e: + self.saved_messages.append( + f'Failed to read {zkpath} from zookeeper: exception {e}' + ) + return None + + def _read_config_from_disk(self, filepath: str) -> Optional[str]: + if not os.path.exists(filepath): + return None + with open(filepath, 'r') as rf: + self.saved_messages.append(f'Read config from disk file {filepath}') + return rf.read() + + def _augment_sys_argv_from_loadfile(self): + """Internal. Augment with arguments persisted in a saved file.""" + + # Check for --config_loadfile in the args manually; argparse isn't + # invoked yet and can't be yet. + loadfile = None + saw_other_args = False + grab_next_arg = False + for arg in sys.argv[1:]: + if 'config_loadfile' in arg: + pieces = arg.split('=') + if len(pieces) > 1: + loadfile = pieces[1] + else: + grab_next_arg = True + elif grab_next_arg: + loadfile = arg + else: + saw_other_args = True + + if not loadfile or len(loadfile) == 0: + return + + # Get contents from wherever. + contents = None + if loadfile[:3] == 'zk:': + contents = self._read_config_from_zookeeper(loadfile[3:]) + else: + contents = self._read_config_from_disk(loadfile) + + if contents: + if saw_other_args: + msg = f'Augmenting commandline arguments with those from {loadfile}.' + else: + msg = f'Reading commandline arguments from {loadfile}.' + print(msg, file=sys.stderr) + self.saved_messages.append(msg) + else: + msg = f'Failed to read/parse contents from {loadfile}' + print(msg, file=sys.stderr) + self.saved_messages.append(msg) + return + + # Augment args with new ones. 
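+        # (Illustrative: a savefile written during a run of
+        # "main.py --dry_run 1" contains one argv token per line,
+        # i.e. "--dry_run" then "1"; we re-split it and append the
+        # tokens to sys.argv below.)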
+        newargs = [
+            arg.strip('\n')
+            for arg in contents.split('\n')
+            if 'config_savefile' not in arg
+        ]
+        sys.argv += newargs
+
+    def dump_config(self):
+        """Print the current config to stderr."""
+        print("Global Configuration:", file=sys.stderr)
+        pprint.pprint(self.config, stream=sys.stderr)
+        print()
+
+    def _write_config_to_disk(self, data: str, filepath: str) -> None:
+        with open(filepath, 'w') as wf:
+            wf.write(data)
+
+    def _write_config_to_zookeeper(self, data: str, zkpath: str) -> None:
+        if not zkpath.startswith('/config/'):
+            zkpath = '/config/' + zkpath
+            zkpath = re.sub(r'//+', '/', zkpath)
+        try:
+            if not self.zk:
+                from pyutils import zookeeper
+
+                self.zk = zookeeper.get_started_zk_client()
+            encoded_data = data.encode()
+            if len(encoded_data) > 1024 * 1024:
+                raise Exception(
+                    f'Saved args are too large ({len(encoded_data)} bytes exceeds zk limit)'
+                )
+            if not self.zk.exists(zkpath):
+                self.zk.create(zkpath, encoded_data)
+                self.saved_messages.append(
+                    f'Just created {zkpath}; setting its max_version to 0'
+                )
+                self.max_version[zkpath] = 0
+            else:
+                meta = self.zk.set(zkpath, encoded_data)
+                self.saved_messages.append(
+                    f'Setting {zkpath}\'s max_version to {meta.version}'
+                )
+                self.max_version[zkpath] = meta.version
+        except Exception as e:
+            raise Exception(f'Failed to create zookeeper path {zkpath}') from e
+        self.saved_messages.append(f'Saved config to zookeeper in {zkpath}')
+
+    def parse(self, entry_module: Optional[str]) -> Dict[str, Any]:
+        """Main program should call this early in main(). Note that the
+        :code:`bootstrap.initialize` wrapper takes care of this automatically.
+        This should only be called once per program invocation.
+
+        """
+        if self.config_parse_called:
+            return self.config
+
+        # If we're about to do the usage message dump, put the main
+        # module's argument group last in the list (if possible) so that
+        # when the user passes -h or --help, it will be visible on the
+        # screen w/o scrolling.
+        for arg in sys.argv:
+            if arg in ('--help', '-h'):
+                if entry_module is not None:
+                    entry_module = os.path.basename(entry_module)
+                ARGS._action_groups = Config._reorder_arg_action_groups_before_help(
+                    entry_module
+                )
+
+        # Examine the environment for variables that match known flags.
+        # For a flag called --example_flag the corresponding environment
+        # variable would be called EXAMPLE_FLAG. If found, hackily add
+        # these into sys.argv to be parsed.
+        self._augment_sys_argv_from_environment_variables()
+
+        # Look for loadfile and read/parse it if present. This also
+        # works by jamming these values onto sys.argv.
+        self._augment_sys_argv_from_loadfile()
+
+        # Parse (possibly augmented, possibly completely overwritten)
+        # commandline args with argparse normally and populate config.
+        known, unknown = ARGS.parse_known_args()
+        self.config.update(vars(known))
+
+        # Reconstruct the argv with unrecognized flags for the benefit of
+        # future argument parsers. For example, unittest_main in python
+        # has some of its own flags. If we didn't recognize it, maybe
+        # someone else will.
+        if len(unknown) > 0:
+            if config['config_rejects_unrecognized_arguments']:
+                raise Exception(
+                    f'Encountered unrecognized config argument(s) {unknown} with --config_rejects_unrecognized_arguments enabled; halting.'
+                )
+            self.saved_messages.append(
+                f'Config encountered unrecognized commandline arguments: {unknown}'
+            )
+        sys.argv = sys.argv[:1] + unknown
+
+        # Check for savefile and populate it if requested. Note: the
+        # helpers below take (data, path), in that order.
+        savefile = config['config_savefile']
+        if savefile and len(savefile) > 0:
+            data = '\n'.join(ORIG_ARGV[1:])
+            if savefile[:3] == 'zk:':
+                self._write_config_to_zookeeper(data, savefile[3:])
+            else:
+                self._write_config_to_disk(data, savefile)
+
+        # Also dump the config on stderr if requested.
+        if config['config_dump']:
+            self.dump_config()
+
+        self.config_parse_called = True
+        if config['config_exit_after_parse']:
+            print("Exiting because of --config_exit_after_parse.")
+            if self.zk:
+                self.zk.stop()
+            sys.exit(0)
+        return self.config
+
+    def has_been_parsed(self) -> bool:
+        """Returns True iff the global config has already been parsed"""
+        return self.config_parse_called
+
+    def late_logging(self):
+        """Log messages saved earlier now that logging has been initialized."""
+        logger = logging.getLogger(__name__)
+        logger.debug('Original commandline was: %s', ORIG_ARGV)
+        for _ in self.saved_messages:
+            logger.debug(_)
+
+
+# A global singleton instance of the Config class.
+CONFIG = Config()
+
+# A lot of client code uses config.config['whatever'] to look up
+# configuration; to preserve that interface we expose the singleton
+# as config.config (which has a __getitem__ method).
+config = CONFIG
+
+# Config didn't use to be a class; it was a mess of module-level
+# functions and data. The functions below preserve the old interface
+# so that existing clients do not need to be changed. As you can see,
+# they mostly just thunk into the config class.
+
+
+def add_commandline_args(title: str, description: str = "") -> argparse._ArgumentGroup:
+    """Create a new context for arguments and return a handle. An alias
+    for config.config.add_commandline_args.
+
+    Args:
+        title: A title for your module's commandline arguments group.
+        description: A helpful description of your module.
+
+    Returns:
+        An argparse._ArgumentGroup to be populated by the caller.
+    """
+    return CONFIG.add_commandline_args(title, description)
+
+
+def parse(entry_module: Optional[str]) -> Dict[str, Any]:
+    """Main program should call this early in main(). Note that the
+    :code:`bootstrap.initialize` wrapper takes care of this automatically.
+    This should only be called once per program invocation. Subsequent
+    calls do not reparse the configuration settings but rather just
+    return the current state.
+    """
+    return CONFIG.parse(entry_module)
+
+
+def has_been_parsed() -> bool:
+    """Returns True iff the global config has already been parsed"""
+    return CONFIG.has_been_parsed()
+
+
+def late_logging() -> None:
+    """Log messages saved earlier now that logging has been initialized."""
+    CONFIG.late_logging()
+
+
+def dump_config() -> None:
+    """Print the current config to stderr."""
+    CONFIG.dump_config()
+
+
+def overwrite_argparse_epilog(msg: str) -> None:
+    """Allows your code to override the default epilog created by
+    argparse.
+
+    Args:
+        msg: The epilog message to substitute for the default.
+    """
+    Config.overwrite_argparse_epilog(msg)
+
+
+def is_flag_already_in_argv(var: str) -> bool:
+    """Returns true if a particular flag is passed on the commandline
+    and false otherwise.
+
+    Args:
+        var: The flag to search for.
+    """
+    return Config.is_flag_already_in_argv(var)
+
+
+def print_usage() -> None:
+    """Prints the normal help usage message out."""
+    Config.print_usage()
+
+
+def usage() -> str:
+    """
+    Returns:
+        program usage help text as a string.
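+
+    A hypothetical call (the exact text depends on which args your
+    program registered)::
+
+        print(usage())  # "usage: main.py [-h] [--dry_run DRY_RUN] ..."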
+ """ + return Config.usage() diff --git a/src/pyutils/datetimez/.gitignore b/src/pyutils/datetimez/.gitignore new file mode 100644 index 0000000..2da2229 --- /dev/null +++ b/src/pyutils/datetimez/.gitignore @@ -0,0 +1,7 @@ +dateparse_utils.interp +dateparse_utils.tokens +dateparse_utilsLexer.interp +dateparse_utilsLexer.py +dateparse_utilsLexer.tokens +dateparse_utilsListener.py +dateparse_utilsParser.py diff --git a/src/pyutils/datetimez/__init__.py b/src/pyutils/datetimez/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/datetimez/constants.py b/src/pyutils/datetimez/constants.py new file mode 100644 index 0000000..0b3fed1 --- /dev/null +++ b/src/pyutils/datetimez/constants.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Universal date/time constants.""" + +from typing import Final + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +# Date/time based constants +SECONDS_PER_MINUTE: Final = 60 +SECONDS_PER_HOUR: Final = 60 * SECONDS_PER_MINUTE +SECONDS_PER_DAY: Final = 24 * SECONDS_PER_HOUR +SECONDS_PER_WEEK: Final = 7 * SECONDS_PER_DAY +MINUTES_PER_HOUR: Final = 60 +MINUTES_PER_DAY: Final = 24 * MINUTES_PER_HOUR +MINUTES_PER_WEEK: Final = 7 * MINUTES_PER_DAY +HOURS_PER_DAY: Final = 24 +HOURS_PER_WEEK: Final = 7 * HOURS_PER_DAY +DAYS_PER_WEEK: Final = 7 diff --git a/src/pyutils/datetimez/dateparse_utils.g4 b/src/pyutils/datetimez/dateparse_utils.g4 new file mode 100644 index 0000000..364aa0f --- /dev/null +++ b/src/pyutils/datetimez/dateparse_utils.g4 @@ -0,0 +1,683 @@ +// © Copyright 2021-2022, Scott Gasch +// +// antlr4 -Dlanguage=Python3 ./dateparse_utils.g4 +// +// Hi, self. In ANTLR grammars, there are two separate types of symbols: those +// for the lexer and those for the parser. The former begin with a CAPITAL +// whereas the latter begin with lowercase. The order of the lexer symbols +// is the order that the lexer will recognize them in. There's a good tutorial +// on this shit at: +// +// https://tomassetti.me/antlr-mega-tutorial/ +// +// There are also a zillion premade grammars at: +// +// https://github.com/antlr/grammars-v4 + +grammar dateparse_utils; + +parse + : SPACE* dateExpr + | SPACE* timeExpr + | SPACE* dateExpr SPACE* dtdiv? SPACE* timeExpr + | SPACE* timeExpr SPACE* tddiv? SPACE+ dateExpr + ; + +dateExpr + : singleDateExpr + | baseAndOffsetDateExpr + ; + +timeExpr + : singleTimeExpr + | baseAndOffsetTimeExpr + ; + +singleTimeExpr + : twentyFourHourTimeExpr + | twelveHourTimeExpr + | specialTimeExpr + ; + +twentyFourHourTimeExpr + : hour ((SPACE|tdiv)+ minute ((SPACE|tdiv)+ second ((SPACE|tdiv)+ micros)? )? )? SPACE* tzExpr? + ; + +twelveHourTimeExpr + : hour ((SPACE|tdiv)+ minute ((SPACE|tdiv)+ second ((SPACE|tdiv)+ micros)? )? )? SPACE* ampm SPACE* tzExpr? + ; + +ampm: ('a'|'am'|'p'|'pm'|'AM'|'PM'|'A'|'P'); + +singleDateExpr + : monthDayMaybeYearExpr + | dayMonthMaybeYearExpr + | yearMonthDayExpr + | specialDateMaybeYearExpr + | nthWeekdayInMonthMaybeYearExpr + | firstLastWeekdayInMonthMaybeYearExpr + | deltaDateExprRelativeToTodayImplied + | dayName (SPACE|ddiv)+ monthDayMaybeYearExpr (SPACE|ddiv)* singleTimeExpr* + | dayName + ; + +monthDayMaybeYearExpr + : monthExpr (SPACE|ddiv)+ dayOfMonth ((SPACE|ddiv)+ year)? + ; + +dayMonthMaybeYearExpr + : dayOfMonth (SPACE|ddiv)+ monthName ((SPACE|ddiv)+ year)? 
+ ; + +yearMonthDayExpr + : year (SPACE|ddiv)+ monthExpr (SPACE|ddiv)+ dayOfMonth + ; + +nthWeekdayInMonthMaybeYearExpr + : nth SPACE+ dayName SPACE+ ('in'|'of') SPACE+ monthName ((ddiv|SPACE)+ year)? + ; + +firstLastWeekdayInMonthMaybeYearExpr + : firstOrLast SPACE+ dayName (SPACE+ ('in'|'of'))? SPACE+ monthName ((ddiv|SPACE)+ year)? + ; + +specialDateMaybeYearExpr + : (thisNextLast SPACE+)? specialDate ((SPACE|ddiv)+ year)? + ; + +thisNextLast: (THIS|NEXT|LAST) ; + +baseAndOffsetDateExpr + : baseDate SPACE+ deltaPlusMinusExpr + | deltaPlusMinusExpr SPACE+ baseDate + ; + +deltaDateExprRelativeToTodayImplied + : nFoosFromTodayAgoExpr + | deltaRelativeToTodayExpr + ; + +deltaRelativeToTodayExpr + : thisNextLast SPACE+ deltaUnit + ; + +nFoosFromTodayAgoExpr + : unsignedInt SPACE+ deltaUnit SPACE+ AGO_FROM_NOW + ; + +baseDate: singleDateExpr ; + +baseAndOffsetTimeExpr + : deltaPlusMinusTimeExpr SPACE+ baseTime + | baseTime SPACE+ deltaPlusMinusTimeExpr + ; + +baseTime: singleTimeExpr ; + +deltaPlusMinusExpr + : nth SPACE+ deltaUnit (SPACE+ deltaBeforeAfter)? + ; + +deltaNextLast + : (NEXT|LAST) ; + +deltaPlusMinusTimeExpr + : countUnitsBeforeAfterTimeExpr + | fractionBeforeAfterTimeExpr + ; + +countUnitsBeforeAfterTimeExpr + : nth (SPACE+ deltaTimeUnit)? SPACE+ deltaTimeBeforeAfter + ; + +fractionBeforeAfterTimeExpr + : deltaTimeFraction SPACE+ deltaTimeBeforeAfter + ; + +deltaUnit: (YEAR|MONTH|WEEK|DAY|WEEKDAY|WORKDAY) ; + +deltaTimeUnit: (SECOND|MINUTE|HOUR|WORKDAY) ; + +deltaBeforeAfter: (BEFORE|AFTER) ; + +deltaTimeBeforeAfter: (BEFORE|AFTER) ; + +monthExpr + : monthName + | monthNumber + ; + +year: DIGIT DIGIT DIGIT DIGIT ; + +specialDate: SPECIAL_DATE ; + +dayOfMonth + : DIGIT DIGIT? ('st'|'ST'|'nd'|'ND'|'rd'|'RD'|'th'|'TH')? + | KALENDS (SPACE+ 'of')? + | IDES (SPACE+ 'of')? + | NONES (SPACE+ 'of')? + ; + +firstOrLast: (FIRST|LAST) ; + +nth: (DASH|PLUS)? DIGIT+ ('st'|'ST'|'nd'|'ND'|'rd'|'RD'|'th'|'TH')? ; + +unsignedInt: DIGIT+ ; + +deltaTimeFraction: DELTA_TIME_FRACTION ; + +specialTimeExpr: specialTime (SPACE+ tzExpr)? ; + +specialTime: SPECIAL_TIME ; + +dayName: WEEKDAY ; + +monthName: MONTH_NAME ; + +monthNumber: DIGIT DIGIT? ; + +hour: DIGIT DIGIT? ; + +minute: DIGIT DIGIT ; + +second: DIGIT DIGIT ; + +micros: DIGIT DIGIT? DIGIT? DIGIT? DIGIT? DIGIT? DIGIT? ; + +ddiv: (SLASH|DASH|',') ; + +tdiv: (COLON|DOT) ; + +dtdiv: ('T'|'t'|'at'|'AT'|','|';') ; + +tddiv: ('on'|'ON'|','|';') ; + +tzExpr + : ntz + | ltz + ; + +ntz: (PLUS|DASH) DIGIT DIGIT COLON? DIGIT DIGIT ; + +ltz: UPPERCASE_STRING ; + +// ---------------------------------- + +SPACE: [ \t\r\n] ; + +COMMENT: '#' ~[\r\n]* -> skip ; + +THE: ('the'|'The') SPACE* -> skip ; + +DASH: '-' ; + +PLUS: '+' ; + +SLASH: '/' ; + +DOT: '.' 
; + +COLON: ':' ; + +MONTH_NAME: (JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC) ; + +JAN : 'jan' + | 'Jan' + | 'JAN' + | 'January' + | 'january' + ; + +FEB : 'feb' + | 'Feb' + | 'FEB' + | 'February' + | 'february' + ; + +MAR : 'mar' + | 'Mar' + | 'MAR' + | 'March' + | 'march' + ; + +APR : 'apr' + | 'Apr' + | 'APR' + | 'April' + | 'april' + ; + +MAY : 'may' + | 'May' + | 'MAY' + ; + +JUN : 'jun' + | 'Jun' + | 'JUN' + | 'June' + | 'june' + ; + +JUL : 'jul' + | 'Jul' + | 'JUL' + | 'July' + | 'july' + ; + +AUG : 'aug' + | 'Aug' + | 'AUG' + | 'August' + | 'august' + ; + +SEP : 'sep' + | 'Sep' + | 'SEP' + | 'sept' + | 'Sept' + | 'SEPT' + | 'September' + | 'september' + ; + +OCT : 'oct' + | 'Oct' + | 'OCT' + | 'October' + | 'october' + ; + +NOV : 'nov' + | 'Nov' + | 'NOV' + | 'November' + | 'november' + ; + +DEC : 'dec' + | 'Dec' + | 'DEC' + | 'December' + | 'december' + ; + +WEEKDAY: (SUN|MON|TUE|WED|THU|FRI|SAT) ; + +SUN : 'sun' + | 'Sun' + | 'SUN' + | 'suns' + | 'Suns' + | 'SUNS' + | 'sunday' + | 'Sunday' + | 'sundays' + | 'Sundays' + ; + +MON : 'mon' + | 'Mon' + | 'MON' + | 'mons' + | 'Mons' + | 'MONS' + | 'monday' + | 'Monday' + | 'mondays' + | 'Mondays' + ; + +TUE : 'tue' + | 'Tue' + | 'TUE' + | 'tues' + | 'Tues' + | 'TUES' + | 'tuesday' + | 'Tuesday' + | 'tuesdays' + | 'Tuesdays' + ; + +WED : 'wed' + | 'Wed' + | 'WED' + | 'weds' + | 'Weds' + | 'WEDS' + | 'wednesday' + | 'Wednesday' + | 'wednesdays' + | 'Wednesdays' + ; + +THU : 'thu' + | 'Thu' + | 'THU' + | 'thur' + | 'Thur' + | 'THUR' + | 'thurs' + | 'Thurs' + | 'THURS' + | 'thursday' + | 'Thursday' + | 'thursdays' + | 'Thursdays' + ; + +FRI : 'fri' + | 'Fri' + | 'FRI' + | 'fris' + | 'Fris' + | 'FRIS' + | 'friday' + | 'Friday' + | 'fridays' + | 'Fridays' + ; + +SAT : 'sat' + | 'Sat' + | 'SAT' + | 'sats' + | 'Sats' + | 'SATS' + | 'saturday' + | 'Saturday' + | 'saturdays' + | 'Saturdays' + ; + +WEEK + : 'week' + | 'Week' + | 'weeks' + | 'Weeks' + | 'wks' + ; + +DAY + : 'day' + | 'Day' + | 'days' + | 'Days' + ; + +HOUR + : 'hour' + | 'Hour' + | 'hours' + | 'Hours' + | 'hrs' + ; + +MINUTE + : 'min' + | 'Min' + | 'MIN' + | 'mins' + | 'Mins' + | 'MINS' + | 'minute' + | 'Minute' + | 'minutes' + | 'Minutes' + ; + +SECOND + : 'sec' + | 'Sec' + | 'SEC' + | 'secs' + | 'Secs' + | 'SECS' + | 'second' + | 'Second' + | 'seconds' + | 'Seconds' + ; + +MONTH + : 'month' + | 'Month' + | 'months' + | 'Months' + ; + +YEAR + : 'year' + | 'Year' + | 'years' + | 'Years' + | 'yrs' + ; + +SPECIAL_DATE + : TODAY + | YESTERDAY + | TOMORROW + | NEW_YEARS_EVE + | NEW_YEARS_DAY + | MARTIN_LUTHER_KING_DAY + | PRESIDENTS_DAY + | EASTER + | MEMORIAL_DAY + | INDEPENDENCE_DAY + | LABOR_DAY + | COLUMBUS_DAY + | VETERANS_DAY + | HALLOWEEN + | THANKSGIVING_DAY + | CHRISTMAS_EVE + | CHRISTMAS + ; + +SPECIAL_TIME + : NOON + | MIDNIGHT + ; + +NOON + : ('noon'|'Noon'|'midday'|'Midday') + ; + +MIDNIGHT + : ('midnight'|'Midnight') + ; + +// today +TODAY + : ('today'|'Today'|'now'|'Now') + ; + +// yeste +YESTERDAY + : ('yesterday'|'Yesterday') + ; + +// tomor +TOMORROW + : ('tomorrow'|'Tomorrow') + ; + +// easte +EASTER + : 'easter' SUN? + | 'Easter' SUN? 
+ ; + +// newye +NEW_YEARS_DAY + : 'new years' + | 'New Years' + | 'new years day' + | 'New Years Day' + | 'new year\'s' + | 'New Year\'s' + | 'new year\'s day' + | 'New year\'s Day' + ; + +// newyeeve +NEW_YEARS_EVE + : 'nye' + | 'NYE' + | 'new years eve' + | 'New Years Eve' + | 'new year\'s eve' + | 'New Year\'s Eve' + ; + +// chris +CHRISTMAS + : 'christmas' + | 'Christmas' + | 'christmas day' + | 'Christmas Day' + | 'xmas' + | 'Xmas' + | 'xmas day' + | 'Xmas Day' + ; + +// chriseve +CHRISTMAS_EVE + : 'christmas eve' + | 'Christmas Eve' + | 'xmas eve' + | 'Xmas Eve' + ; + +// mlk +MARTIN_LUTHER_KING_DAY + : 'martin luther king day' + | 'Martin Luther King Day' + | 'mlk day' + | 'MLK Day' + | 'MLK day' + | 'mlk' + | 'MLK' + ; + +// memor +MEMORIAL_DAY + : 'memorial' + | 'Memorial' + | 'memorial day' + | 'Memorial Day' + ; + +// indep +INDEPENDENCE_DAY + : 'independence day' + | 'Independence day' + | 'Independence Day' + ; + +// labor +LABOR_DAY + : 'labor' + | 'Labor' + | 'labor day' + | 'Labor Day' + ; + +// presi +PRESIDENTS_DAY + : 'presidents\' day' + | 'president\'s day' + | 'presidents day' + | 'presidents' + | 'president\'s' + | 'presidents\'' + | 'Presidents\' Day' + | 'President\'s Day' + | 'Presidents Day' + | 'Presidents' + | 'President\'s' + | 'Presidents\'' + ; + +// colum +COLUMBUS_DAY + : 'columbus' + | 'columbus day' + | 'indiginous peoples day' + | 'indiginous peoples\' day' + | 'Columbus' + | 'Columbus Day' + | 'Indiginous Peoples Day' + | 'Indiginous Peoples\' Day' + ; + +// veter +VETERANS_DAY + : 'veterans' + | 'veterans day' + | 'veterans\' day' + | 'Veterans' + | 'Veterans Day' + | 'Veterans\' Day' + ; + +// hallo +HALLOWEEN + : 'halloween' + | 'Halloween' + ; + +// thank +THANKSGIVING_DAY + : 'thanksgiving' + | 'thanksgiving day' + | 'Thanksgiving' + | 'Thanksgiving Day' + ; + +FIRST: ('first'|'First') ; + +LAST: ('last'|'Last'|'this past') ; + +THIS: ('this'|'This'|'this coming') ; + +NEXT: ('next'|'Next') ; + +AGO_FROM_NOW: (AGO|FROM_NOW) ; + +AGO: ('ago'|'Ago'|'back'|'Back') ; + +FROM_NOW: ('from now'|'From Now') ; + +BEFORE: ('to'|'To'|'before'|'Before'|'til'|'until'|'Until') ; + +AFTER: ('after'|'After'|'from'|'From'|'past'|'Past') ; + +DELTA_TIME_FRACTION: ('quarter'|'Quarter'|'half'|'Half') ; + +DIGIT: ('0'..'9') ; + +IDES: ('ides'|'Ides') ; + +NONES: ('nones'|'Nones') ; + +KALENDS: ('kalends'|'Kalends') ; + +WORKDAY + : 'workday' + | 'workdays' + | 'work days' + | 'working days' + | 'Workday' + | 'Workdays' + | 'Work Days' + | 'Working Days' + ; + +UPPERCASE_STRING: [A-Z]+ ; diff --git a/src/pyutils/datetimez/dateparse_utils.py b/src/pyutils/datetimez/dateparse_utils.py new file mode 100755 index 0000000..3ae2e1f --- /dev/null +++ b/src/pyutils/datetimez/dateparse_utils.py @@ -0,0 +1,1055 @@ +#!/usr/bin/env python3 +# type: ignore +# pylint: disable=W0201 +# pylint: disable=R0904 + +# © Copyright 2021-2022, Scott Gasch + +"""Parse dates in a variety of formats.""" + +import datetime +import functools +import logging +import re +import sys +from typing import Any, Callable, Dict, Optional + +import antlr4 # type: ignore +import dateutil.easter +import dateutil.tz +import holidays # type: ignore +import pytz + +from pyutils import bootstrap, decorator_utils +from pyutils.datetimez.dateparse_utilsLexer import dateparse_utilsLexer # type: ignore +from pyutils.datetimez.dateparse_utilsListener import ( + dateparse_utilsListener, +) # type: ignore +from pyutils.datetimez.dateparse_utilsParser import ( + dateparse_utilsParser, +) # type: ignore +from 
pyutils.datetimez.datetime_utils import ( + TimeUnit, + date_to_datetime, + datetime_to_date, + n_timeunits_from_base, +) +from pyutils.security import acl + +logger = logging.getLogger(__name__) + + +def debug_parse(enter_or_exit_f: Callable[[Any, Any], None]): + @functools.wraps(enter_or_exit_f) + def debug_parse_wrapper(*args, **kwargs): + # slf = args[0] + ctx = args[1] + depth = ctx.depth() + logger.debug( + ' ' * (depth - 1) + + f'Entering {enter_or_exit_f.__name__} ({ctx.invokingState} / {ctx.exception})' + ) + for c in ctx.getChildren(): + logger.debug(' ' * (depth - 1) + f'{c} {type(c)}') + retval = enter_or_exit_f(*args, **kwargs) + return retval + + return debug_parse_wrapper + + +class ParseException(Exception): + """An exception thrown during parsing because of unrecognized input.""" + + def __init__(self, message: str) -> None: + super().__init__() + self.message = message + + +class RaisingErrorListener(antlr4.DiagnosticErrorListener): + """An error listener that raises ParseExceptions.""" + + def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): + raise ParseException(msg) + + def reportAmbiguity( + self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs + ): + pass + + def reportAttemptingFullContext( + self, recognizer, dfa, startIndex, stopIndex, conflictingAlts, configs + ): + pass + + def reportContextSensitivity( + self, recognizer, dfa, startIndex, stopIndex, prediction, configs + ): + pass + + +@decorator_utils.decorate_matching_methods_with( + debug_parse, + acl=acl.StringWildcardBasedACL( + allowed_patterns=[ + 'enter*', + 'exit*', + ], + denied_patterns=['enterEveryRule', 'exitEveryRule'], + order_to_check_allow_deny=acl.Order.DENY_ALLOW, + default_answer=False, + ), +) +class DateParser(dateparse_utilsListener): + """A class to parse dates expressed in human language.""" + + PARSE_TYPE_SINGLE_DATE_EXPR = 1 + PARSE_TYPE_BASE_AND_OFFSET_EXPR = 2 + PARSE_TYPE_SINGLE_TIME_EXPR = 3 + PARSE_TYPE_BASE_AND_OFFSET_TIME_EXPR = 4 + + def __init__(self, *, override_now_for_test_purposes=None) -> None: + """C'tor. Passing a value to override_now_for_test_purposes can be + used to force this instance to use a custom date/time for its + idea of "now" so that the code can be more easily unittested. + Leave as None for real use cases. + """ + self.month_name_to_number = { + 'jan': 1, + 'feb': 2, + 'mar': 3, + 'apr': 4, + 'may': 5, + 'jun': 6, + 'jul': 7, + 'aug': 8, + 'sep': 9, + 'oct': 10, + 'nov': 11, + 'dec': 12, + } + + # Used only for ides/nones. Month length on a non-leap year. + self.typical_days_per_month = { + 1: 31, + 2: 28, + 3: 31, + 4: 30, + 5: 31, + 6: 30, + 7: 31, + 8: 31, + 9: 30, + 10: 31, + 11: 30, + 12: 31, + } + + # N.B. day number is also synched with datetime_utils.TimeUnit values + # which allows expressions like "3 wednesdays from now" to work. + self.day_name_to_number = { + 'mon': 0, + 'tue': 1, + 'wed': 2, + 'thu': 3, + 'fri': 4, + 'sat': 5, + 'sun': 6, + } + + # These TimeUnits are defined in datetime_utils and are used as params + # to datetime_utils.n_timeunits_from_base. + self.time_delta_unit_to_constant = { + 'hou': TimeUnit.HOURS, + 'min': TimeUnit.MINUTES, + 'sec': TimeUnit.SECONDS, + } + self.delta_unit_to_constant = { + 'day': TimeUnit.DAYS, + 'wor': TimeUnit.WORKDAYS, + 'wee': TimeUnit.WEEKS, + 'mon': TimeUnit.MONTHS, + 'yea': TimeUnit.YEARS, + } + self.override_now_for_test_purposes = override_now_for_test_purposes + + # Note: _reset defines several class fields. 
It is used both here
+ # in the c'tor and in between parse operations to restore the
+ # class' state and allow it to be reused.
+ #
+ self._reset()
+
+ def parse(self, date_string: str) -> Optional[datetime.datetime]:
+ """Parse a date/time expression and return a timezone-agnostic
+ datetime on success. Also sets self.datetime, self.date and
+ self.time, which can each be accessed via other methods on the
+ class: get_datetime(), get_date() and get_time(). Raises a
+ ParseException with a helpful(?) message on parse error or
+ confusion.
+
+ To get an idea of what expressions can be parsed, check out
+ the unittest and the grammar.
+
+ Usage:
+
+ txt = '3 weeks before last tues at 9:15am'
+ dp = DateParser()
+ dt1 = dp.parse(txt)
+ dt2 = dp.get_datetime(tz=pytz.timezone('US/Pacific'))
+
+ # dt1 and dt2 will be identical other than the fact that
+ # the latter's tzinfo will be set to PST/PDT.
+
+ This is the main entrypoint to this class for caller code.
+ """
+ date_string = date_string.strip()
+ date_string = re.sub(r'\s+', ' ', date_string)
+ self._reset()
+ listener = RaisingErrorListener()
+ input_stream = antlr4.InputStream(date_string)
+ lexer = dateparse_utilsLexer(input_stream)
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(listener)
+ stream = antlr4.CommonTokenStream(lexer)
+ parser = dateparse_utilsParser(stream)
+ parser.removeErrorListeners()
+ parser.addErrorListener(listener)
+ tree = parser.parse()
+ walker = antlr4.ParseTreeWalker()
+ walker.walk(self, tree)
+ return self.datetime
+
+ def get_date(self) -> Optional[datetime.date]:
+ """Return the date part or None."""
+ return self.date
+
+ def get_time(self) -> Optional[datetime.time]:
+ """Return the time part or None."""
+ return self.time
+
+ def get_datetime(self, *, tz=None) -> Optional[datetime.datetime]:
+ """Return as a datetime. Parsed date expressions without any time
+ part return hours = minutes = seconds = microseconds = 0 (i.e. at
+ midnight that day). Parsed time expressions without any date part
+ default to date = today.
+
+ The optional tz param allows the caller to request the datetime be
+ timezone aware and sets the tzinfo to the indicated zone. Defaults
+ to timezone naive (i.e. tzinfo = None).
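+
+ A short usage sketch (illustrative only, not a doctest):
+
+ dp = DateParser()
+ dp.parse('9:15am')
+ dt = dp.get_datetime(tz=pytz.timezone('US/Pacific'))
+ # dt is the parsed datetime with tzinfo set to US/Pacific.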
+ """ + dt = self.datetime + if dt is not None: + if tz is not None: + dt = dt.replace(tzinfo=None).astimezone(tz=tz) + return dt + + # -- helpers -- + + def _reset(self): + """Reset at init and between parses.""" + if self.override_now_for_test_purposes is None: + self.now_datetime = datetime.datetime.now() + self.today = datetime.date.today() + else: + self.now_datetime = self.override_now_for_test_purposes + self.today = datetime_to_date(self.override_now_for_test_purposes) + self.date: Optional[datetime.date] = None + self.time: Optional[datetime.time] = None + self.datetime: Optional[datetime.datetime] = None + self.context: Dict[str, Any] = {} + self.timedelta = datetime.timedelta(seconds=0) + self.saw_overt_year = False + + @staticmethod + def _normalize_special_day_name(name: str) -> str: + """String normalization / canonicalization for date expressions.""" + name = name.lower() + name = name.replace("'", '') + name = name.replace('xmas', 'christmas') + name = name.replace('mlk', 'martin luther king') + name = name.replace(' ', '') + eve = 'eve' if name[-3:] == 'eve' else '' + name = name[:5] + eve + name = name.replace('washi', 'presi') + return name + + def _figure_out_date_unit(self, orig: str) -> TimeUnit: + """Figure out what unit a date expression piece is talking about.""" + if 'month' in orig: + return TimeUnit.MONTHS + txt = orig.lower()[:3] + if txt in self.day_name_to_number: + return TimeUnit(self.day_name_to_number[txt]) + elif txt in self.delta_unit_to_constant: + return TimeUnit(self.delta_unit_to_constant[txt]) + raise ParseException(f'Invalid date unit: {orig}') + + def _figure_out_time_unit(self, orig: str) -> int: + """Figure out what unit a time expression piece is talking about.""" + txt = orig.lower()[:3] + if txt in self.time_delta_unit_to_constant: + return self.time_delta_unit_to_constant[txt] + raise ParseException(f'Invalid time unit: {orig}') + + def _parse_special_date(self, name: str) -> Optional[datetime.date]: + """Parse what we think is a special date name and return its datetime + (or None if it can't be parsed). + """ + today = self.today + year = self.context.get('year', today.year) + name = DateParser._normalize_special_day_name(self.context['special']) + + # Yesterday, today, tomorrow -- ignore any next/last + if name in ('today', 'now'): + return today + if name == 'yeste': + return today + datetime.timedelta(days=-1) + if name == 'tomor': + return today + datetime.timedelta(days=+1) + + next_last = self.context.get('special_next_last', '') + if next_last == 'next': + year += 1 + self.saw_overt_year = True + elif next_last == 'last': + year -= 1 + self.saw_overt_year = True + + # Holiday names + if name == 'easte': + return dateutil.easter.easter(year=year) + elif name == 'hallo': + return datetime.date(year=year, month=10, day=31) + + for holiday_date, holiday_name in sorted(holidays.US(years=year).items()): + if 'Observed' not in holiday_name: + holiday_name = DateParser._normalize_special_day_name(holiday_name) + if name == holiday_name: + return holiday_date + if name == 'chriseve': + return datetime.date(year=year, month=12, day=24) + elif name == 'newyeeve': + return datetime.date(year=year, month=12, day=31) + return None + + def _resolve_ides_nones(self, day: str, month_number: int) -> int: + """Handle date expressions like "the ides of March" which require + both the "ides" and the month since the definition of the "ides" + changes based on the length of the month. 
+ """ + assert 'ide' in day or 'non' in day + assert month_number in self.typical_days_per_month + typical_days_per_month = self.typical_days_per_month[month_number] + + # "full" month + if typical_days_per_month == 31: + if self.context['day'] == 'ide': + return 15 + else: + return 7 + + # "hollow" month + else: + if self.context['day'] == 'ide': + return 13 + else: + return 5 + + def _parse_normal_date(self) -> datetime.date: + if 'dow' in self.context and 'month' not in self.context: + d = self.today + while d.weekday() != self.context['dow']: + d += datetime.timedelta(days=1) + return d + + if 'month' not in self.context: + raise ParseException('Missing month') + if 'day' not in self.context: + raise ParseException('Missing day') + if 'year' not in self.context: + self.context['year'] = self.today.year + self.saw_overt_year = False + else: + self.saw_overt_year = True + + # Handling "ides" and "nones" requires both the day and month. + if self.context['day'] == 'ide' or self.context['day'] == 'non': + self.context['day'] = self._resolve_ides_nones( + self.context['day'], self.context['month'] + ) + + return datetime.date( + year=self.context['year'], + month=self.context['month'], + day=self.context['day'], + ) + + @staticmethod + def _parse_tz(txt: str) -> Any: + if txt == 'Z': + txt = 'UTC' + + # Try pytz + try: + tz1 = pytz.timezone(txt) + if tz1 is not None: + return tz1 + except Exception: + pass + + # Try dateutil + try: + tz2 = dateutil.tz.gettz(txt) + if tz2 is not None: + return tz2 + except Exception: + pass + + # Try constructing an offset in seconds + try: + txt_sign = txt[0] + if txt_sign in ('-', '+'): + sign = +1 if txt_sign == '+' else -1 + hour = int(txt[1:3]) + minute = int(txt[-2:]) + offset = sign * (hour * 60 * 60) + sign * (minute * 60) + tzoffset = dateutil.tz.tzoffset(txt, offset) + return tzoffset + except Exception: + pass + return None + + @staticmethod + def _get_int(txt: str) -> int: + while not txt[0].isdigit() and txt[0] != '-' and txt[0] != '+': + txt = txt[1:] + while not txt[-1].isdigit(): + txt = txt[:-1] + return int(txt) + + # -- overridden methods invoked by parse walk. Note: not part of the class' + # public API(!!) -- + + def visitErrorNode(self, node: antlr4.ErrorNode) -> None: + pass + + def visitTerminal(self, node: antlr4.TerminalNode) -> None: + pass + + def exitParse(self, ctx: dateparse_utilsParser.ParseContext) -> None: + """Populate self.datetime.""" + if self.date is None: + self.date = self.today + year = self.date.year + month = self.date.month + day = self.date.day + + if self.time is None: + self.time = datetime.time(0, 0, 0) + hour = self.time.hour + minute = self.time.minute + second = self.time.second + micros = self.time.microsecond + + self.datetime = datetime.datetime( + year, + month, + day, + hour, + minute, + second, + micros, + tzinfo=self.time.tzinfo, + ) + + # Apply resudual adjustments to times here when we have a + # datetime. 
+ self.datetime = self.datetime + self.timedelta + assert self.datetime is not None + self.time = datetime.time( + self.datetime.hour, + self.datetime.minute, + self.datetime.second, + self.datetime.microsecond, + self.datetime.tzinfo, + ) + + def enterDateExpr(self, ctx: dateparse_utilsParser.DateExprContext): + self.date = None + if ctx.singleDateExpr() is not None: + self.main_type = DateParser.PARSE_TYPE_SINGLE_DATE_EXPR + elif ctx.baseAndOffsetDateExpr() is not None: + self.main_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_EXPR + + def enterTimeExpr(self, ctx: dateparse_utilsParser.TimeExprContext): + self.time = None + if ctx.singleTimeExpr() is not None: + self.time_type = DateParser.PARSE_TYPE_SINGLE_TIME_EXPR + elif ctx.baseAndOffsetTimeExpr() is not None: + self.time_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_TIME_EXPR + + def exitDateExpr(self, ctx: dateparse_utilsParser.DateExprContext) -> None: + """When we leave the date expression, populate self.date.""" + if 'special' in self.context: + self.date = self._parse_special_date(self.context['special']) + else: + self.date = self._parse_normal_date() + assert self.date is not None + + # For a single date, just return the date we pulled out. + if self.main_type == DateParser.PARSE_TYPE_SINGLE_DATE_EXPR: + return + + # Otherwise treat self.date as a base date that we're modifying + # with an offset. + if 'delta_int' not in self.context: + raise ParseException('Missing delta_int?!') + count = self.context['delta_int'] + if count == 0: + return + + # Adjust count's sign based on the presence of 'before' or 'after'. + if 'delta_before_after' in self.context: + before_after = self.context['delta_before_after'].lower() + if before_after in ('before', 'until', 'til', 'to'): + count = -count + + # What are we counting units of? + if 'delta_unit' not in self.context: + raise ParseException('Missing delta_unit?!') + unit = self.context['delta_unit'] + dt = n_timeunits_from_base(count, TimeUnit(unit), date_to_datetime(self.date)) + self.date = datetime_to_date(dt) + + def exitTimeExpr(self, ctx: dateparse_utilsParser.TimeExprContext) -> None: + # Simple time? + self.time = datetime.time( + self.context['hour'], + self.context['minute'], + self.context['seconds'], + self.context['micros'], + tzinfo=self.context.get('tz', None), + ) + if self.time_type == DateParser.PARSE_TYPE_SINGLE_TIME_EXPR: + return + + # If we get here there (should be) a relative adjustment to + # the time. + if 'nth' in self.context: + count = self.context['nth'] + elif 'time_delta_int' in self.context: + count = self.context['time_delta_int'] + else: + raise ParseException('Missing delta in relative time.') + if count == 0: + return + + # Adjust count's sign based on the presence of 'before' or 'after'. + if 'time_delta_before_after' in self.context: + before_after = self.context['time_delta_before_after'].lower() + if before_after in ('before', 'until', 'til', 'to'): + count = -count + + # What are we counting units of... assume minutes. 
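+ # (e.g. the fraction words 'quarter' and 'half' arrive here as
+ # a count of 15 or 30 with unit MINUTES, set by
+ # exitDeltaTimeFraction.)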
+ if 'time_delta_unit' not in self.context: + self.timedelta += datetime.timedelta(minutes=count) + else: + unit = self.context['time_delta_unit'] + if unit == TimeUnit.SECONDS: + self.timedelta += datetime.timedelta(seconds=count) + elif unit == TimeUnit.MINUTES: + self.timedelta = datetime.timedelta(minutes=count) + elif unit == TimeUnit.HOURS: + self.timedelta = datetime.timedelta(hours=count) + else: + raise ParseException(f'Invalid Unit: "{unit}"') + + def exitDeltaPlusMinusExpr( + self, ctx: dateparse_utilsParser.DeltaPlusMinusExprContext + ) -> None: + try: + n = ctx.nth() + if n is None: + raise ParseException(f'Bad N in Delta +/- Expr: {ctx.getText()}') + n = n.getText() + n = DateParser._get_int(n) + unit = self._figure_out_date_unit(ctx.deltaUnit().getText().lower()) + except Exception as e: + raise ParseException(f'Invalid Delta +/-: {ctx.getText()}') from e + else: + self.context['delta_int'] = n + self.context['delta_unit'] = unit + + def exitNextLastUnit(self, ctx: dateparse_utilsParser.DeltaUnitContext) -> None: + try: + unit = self._figure_out_date_unit(ctx.getText().lower()) + except Exception as e: + raise ParseException(f'Bad delta unit: {ctx.getText()}') from e + else: + self.context['delta_unit'] = unit + + def exitDeltaNextLast( + self, ctx: dateparse_utilsParser.DeltaNextLastContext + ) -> None: + try: + txt = ctx.getText().lower() + except Exception as e: + raise ParseException(f'Bad next/last: {ctx.getText()}') from e + if 'month' in self.context or 'day' in self.context or 'year' in self.context: + raise ParseException( + 'Next/last expression expected to be relative to today.' + ) + if txt[:4] == 'next': + self.context['delta_int'] = +1 + self.context['day'] = self.now_datetime.day + self.context['month'] = self.now_datetime.month + self.context['year'] = self.now_datetime.year + self.saw_overt_year = True + elif txt[:4] == 'last': + self.context['delta_int'] = -1 + self.context['day'] = self.now_datetime.day + self.context['month'] = self.now_datetime.month + self.context['year'] = self.now_datetime.year + self.saw_overt_year = True + else: + raise ParseException(f'Bad next/last: {ctx.getText()}') + + def exitCountUnitsBeforeAfterTimeExpr( + self, ctx: dateparse_utilsParser.CountUnitsBeforeAfterTimeExprContext + ) -> None: + if 'nth' not in self.context: + raise ParseException(f'Bad count expression: {ctx.getText()}') + try: + unit = self._figure_out_time_unit(ctx.deltaTimeUnit().getText().lower()) + self.context['time_delta_unit'] = unit + except Exception as e: + raise ParseException(f'Bad delta unit: {ctx.getText()}') from e + if 'time_delta_before_after' not in self.context: + raise ParseException(f'Bad Before/After: {ctx.getText()}') + + def exitDeltaTimeFraction( + self, ctx: dateparse_utilsParser.DeltaTimeFractionContext + ) -> None: + try: + txt = ctx.getText().lower()[:4] + if txt == 'quar': + self.context['time_delta_int'] = 15 + self.context['time_delta_unit'] = TimeUnit.MINUTES + elif txt == 'half': + self.context['time_delta_int'] = 30 + self.context['time_delta_unit'] = TimeUnit.MINUTES + else: + raise ParseException(f'Bad time fraction {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad time fraction {ctx.getText()}') from e + + def exitDeltaBeforeAfter( + self, ctx: dateparse_utilsParser.DeltaBeforeAfterContext + ) -> None: + try: + txt = ctx.getText().lower() + except Exception as e: + raise ParseException(f'Bad delta before|after: {ctx.getText()}') from e + else: + self.context['delta_before_after'] = txt + + def 
exitDeltaTimeBeforeAfter( + self, ctx: dateparse_utilsParser.DeltaBeforeAfterContext + ) -> None: + try: + txt = ctx.getText().lower() + except Exception as e: + raise ParseException(f'Bad delta before|after: {ctx.getText()}') from e + else: + self.context['time_delta_before_after'] = txt + + def exitNthWeekdayInMonthMaybeYearExpr( + self, ctx: dateparse_utilsParser.NthWeekdayInMonthMaybeYearExprContext + ) -> None: + """Do a bunch of work to convert expressions like... + + 'the 2nd Friday of June' -and- + 'the last Wednesday in October' + + ...into base + offset expressions instead. + """ + try: + if 'nth' not in self.context: + raise ParseException(f'Missing nth number: {ctx.getText()}') + n = self.context['nth'] + if n < 1 or n > 5: # months never have more than 5 Foodays + if n != -1: + raise ParseException(f'Invalid nth number: {ctx.getText()}') + del self.context['nth'] + self.context['delta_int'] = n + + year = self.context.get('year', self.today.year) + if 'month' not in self.context: + raise ParseException(f'Missing month expression: {ctx.getText()}') + month = self.context['month'] + + dow = self.context['dow'] + del self.context['dow'] + self.context['delta_unit'] = dow + + # For the nth Fooday in Month, start at the last day of + # the previous month count ahead N Foodays. For the last + # Fooday in Month, start at the last of the month and + # count back one Fooday. + if n == -1: + month += 1 + if month == 13: + month = 1 + year += 1 + tmp_date = datetime.date(year=year, month=month, day=1) + tmp_date = tmp_date - datetime.timedelta(days=1) + + # The delta adjustment code can handle the case where + # the last day of the month is the day we're looking + # for already. + else: + tmp_date = datetime.date(year=year, month=month, day=1) + tmp_date = tmp_date - datetime.timedelta(days=1) + + self.context['year'] = tmp_date.year + self.context['month'] = tmp_date.month + self.context['day'] = tmp_date.day + self.main_type = DateParser.PARSE_TYPE_BASE_AND_OFFSET_EXPR + except Exception as e: + raise ParseException( + f'Invalid nthWeekday expression: {ctx.getText()}' + ) from e + + def exitFirstLastWeekdayInMonthMaybeYearExpr( + self, + ctx: dateparse_utilsParser.FirstLastWeekdayInMonthMaybeYearExprContext, + ) -> None: + self.exitNthWeekdayInMonthMaybeYearExpr(ctx) + + def exitNth(self, ctx: dateparse_utilsParser.NthContext) -> None: + try: + i = DateParser._get_int(ctx.getText()) + except Exception as e: + raise ParseException(f'Bad nth expression: {ctx.getText()}') from e + else: + self.context['nth'] = i + + def exitFirstOrLast(self, ctx: dateparse_utilsParser.FirstOrLastContext) -> None: + try: + txt = ctx.getText() + if txt == 'first': + txt = 1 + elif txt == 'last': + txt = -1 + else: + raise ParseException(f'Bad first|last expression: {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad first|last expression: {ctx.getText()}') from e + else: + self.context['nth'] = txt + + def exitDayName(self, ctx: dateparse_utilsParser.DayNameContext) -> None: + try: + dow = ctx.getText().lower()[:3] + dow = self.day_name_to_number.get(dow, None) + except Exception as e: + raise ParseException('Bad day of week') from e + else: + self.context['dow'] = dow + + def exitDayOfMonth(self, ctx: dateparse_utilsParser.DayOfMonthContext) -> None: + try: + day = ctx.getText().lower() + if day[:3] == 'ide': + self.context['day'] = 'ide' + return + if day[:3] == 'non': + self.context['day'] = 'non' + return + if day[:3] == 'kal': + self.context['day'] = 1 + return + day = 
DateParser._get_int(day) + if day < 1 or day > 31: + raise ParseException(f'Bad dayOfMonth expression: {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad dayOfMonth expression: {ctx.getText()}') from e + self.context['day'] = day + + def exitMonthName(self, ctx: dateparse_utilsParser.MonthNameContext) -> None: + try: + month = ctx.getText() + while month[0] == '/' or month[0] == '-': + month = month[1:] + month = month[:3].lower() + month = self.month_name_to_number.get(month, None) + if month is None: + raise ParseException(f'Bad monthName expression: {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad monthName expression: {ctx.getText()}') from e + else: + self.context['month'] = month + + def exitMonthNumber(self, ctx: dateparse_utilsParser.MonthNumberContext) -> None: + try: + month = DateParser._get_int(ctx.getText()) + if month < 1 or month > 12: + raise ParseException(f'Bad monthNumber expression: {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad monthNumber expression: {ctx.getText()}') from e + else: + self.context['month'] = month + + def exitYear(self, ctx: dateparse_utilsParser.YearContext) -> None: + try: + year = DateParser._get_int(ctx.getText()) + if year < 1: + raise ParseException(f'Bad year expression: {ctx.getText()}') + except Exception as e: + raise ParseException(f'Bad year expression: {ctx.getText()}') from e + else: + self.saw_overt_year = True + self.context['year'] = year + + def exitSpecialDateMaybeYearExpr( + self, ctx: dateparse_utilsParser.SpecialDateMaybeYearExprContext + ) -> None: + try: + special = ctx.specialDate().getText().lower() + self.context['special'] = special + except Exception as e: + raise ParseException( + f'Bad specialDate expression: {ctx.specialDate().getText()}' + ) from e + try: + mod = ctx.thisNextLast() + if mod is not None: + if mod.THIS() is not None: + self.context['special_next_last'] = 'this' + elif mod.NEXT() is not None: + self.context['special_next_last'] = 'next' + elif mod.LAST() is not None: + self.context['special_next_last'] = 'last' + except Exception as e: + raise ParseException( + f'Bad specialDateNextLast expression: {ctx.getText()}' + ) from e + + def exitNFoosFromTodayAgoExpr( + self, ctx: dateparse_utilsParser.NFoosFromTodayAgoExprContext + ) -> None: + d = self.now_datetime + try: + count = DateParser._get_int(ctx.unsignedInt().getText()) + unit = ctx.deltaUnit().getText().lower() + ago_from_now = ctx.AGO_FROM_NOW().getText() + except Exception as e: + raise ParseException(f'Bad NFoosFromTodayAgoExpr: {ctx.getText()}') from e + + if "ago" in ago_from_now or "back" in ago_from_now: + count = -count + + unit = self._figure_out_date_unit(unit) + d = n_timeunits_from_base(count, TimeUnit(unit), d) + self.context['year'] = d.year + self.context['month'] = d.month + self.context['day'] = d.day + + def exitDeltaRelativeToTodayExpr( + self, ctx: dateparse_utilsParser.DeltaRelativeToTodayExprContext + ) -> None: + # When someone says "next week" they mean a week from now. + # Likewise next month or last year. These expressions are now + # +/- delta. + # + # But when someone says "this Friday" they mean "this coming + # Friday". It would be weird to say "this Friday" if today + # was already Friday but I'm parsing it to mean: the next day + # that is a Friday. So when you say "next Friday" you mean + # the Friday after this coming Friday, or 2 Fridays from now. + # + # This set handles this weirdness. 
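+ # e.g. if "now" is Wednesday 2022/10/12:
+ # 'this friday' -> 2022/10/14 (count +1)
+ # 'next friday' -> 2022/10/21 (count +2)
+ # 'last friday' -> 2022/10/07 (count -1)
+ # 'next month' -> 2022/11/12 (count +1)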
+ weekdays = set( + [ + TimeUnit.MONDAYS, + TimeUnit.TUESDAYS, + TimeUnit.WEDNESDAYS, + TimeUnit.THURSDAYS, + TimeUnit.FRIDAYS, + TimeUnit.SATURDAYS, + TimeUnit.SUNDAYS, + ] + ) + d = self.now_datetime + try: + mod = ctx.thisNextLast() + unit = ctx.deltaUnit().getText().lower() + unit = self._figure_out_date_unit(unit) + if mod.LAST(): + count = -1 + elif mod.THIS(): + if unit in weekdays: + count = +1 + else: + count = 0 + elif mod.NEXT(): + if unit in weekdays: + count = +2 + else: + count = +1 + else: + raise ParseException(f'Bad This/Next/Last modifier: {mod}') + except Exception as e: + raise ParseException( + f'Bad DeltaRelativeToTodayExpr: {ctx.getText()}' + ) from e + d = n_timeunits_from_base(count, TimeUnit(unit), d) + self.context['year'] = d.year + self.context['month'] = d.month + self.context['day'] = d.day + + def exitSpecialTimeExpr( + self, ctx: dateparse_utilsParser.SpecialTimeExprContext + ) -> None: + try: + txt = ctx.specialTime().getText().lower() + except Exception as e: + raise ParseException(f'Bad special time expression: {ctx.getText()}') from e + else: + if txt in ('noon', 'midday'): + self.context['hour'] = 12 + self.context['minute'] = 0 + self.context['seconds'] = 0 + self.context['micros'] = 0 + elif txt == 'midnight': + self.context['hour'] = 0 + self.context['minute'] = 0 + self.context['seconds'] = 0 + self.context['micros'] = 0 + else: + raise ParseException(f'Bad special time expression: {txt}') + + try: + tz = ctx.tzExpr().getText() + self.context['tz'] = DateParser._parse_tz(tz) + except Exception: + pass + + def exitTwelveHourTimeExpr( + self, ctx: dateparse_utilsParser.TwelveHourTimeExprContext + ) -> None: + try: + hour = ctx.hour().getText() + while not hour[-1].isdigit(): + hour = hour[:-1] + hour = DateParser._get_int(hour) + except Exception as e: + raise ParseException(f'Bad hour: {ctx.hour().getText()}') from e + if hour <= 0 or hour > 12: + raise ParseException(f'Bad hour (out of range): {hour}') + + try: + minute = DateParser._get_int(ctx.minute().getText()) + except Exception: + minute = 0 + if minute < 0 or minute > 59: + raise ParseException(f'Bad minute (out of range): {minute}') + self.context['minute'] = minute + + try: + seconds = DateParser._get_int(ctx.second().getText()) + except Exception: + seconds = 0 + if seconds < 0 or seconds > 59: + raise ParseException(f'Bad second (out of range): {seconds}') + self.context['seconds'] = seconds + + try: + micros = DateParser._get_int(ctx.micros().getText()) + except Exception: + micros = 0 + if micros < 0 or micros > 1000000: + raise ParseException(f'Bad micros (out of range): {micros}') + self.context['micros'] = micros + + try: + ampm = ctx.ampm().getText() + except Exception as e: + raise ParseException(f'Bad ampm: {ctx.ampm().getText()}') from e + if hour == 12: + hour = 0 + if ampm[0] == 'p': + hour += 12 + self.context['hour'] = hour + + try: + tz = ctx.tzExpr().getText() + self.context['tz'] = DateParser._parse_tz(tz) + except Exception: + pass + + def exitTwentyFourHourTimeExpr( + self, ctx: dateparse_utilsParser.TwentyFourHourTimeExprContext + ) -> None: + try: + hour = ctx.hour().getText() + while not hour[-1].isdigit(): + hour = hour[:-1] + hour = DateParser._get_int(hour) + except Exception as e: + raise ParseException(f'Bad hour: {ctx.hour().getText()}') from e + if hour < 0 or hour > 23: + raise ParseException(f'Bad hour (out of range): {hour}') + self.context['hour'] = hour + + try: + minute = DateParser._get_int(ctx.minute().getText()) + except Exception: + minute = 0 + if 
minute < 0 or minute > 59: + raise ParseException(f'Bad minute (out of range): {ctx.getText()}') + self.context['minute'] = minute + + try: + seconds = DateParser._get_int(ctx.second().getText()) + except Exception: + seconds = 0 + if seconds < 0 or seconds > 59: + raise ParseException(f'Bad second (out of range): {ctx.getText()}') + self.context['seconds'] = seconds + + try: + micros = DateParser._get_int(ctx.micros().getText()) + except Exception: + micros = 0 + if micros < 0 or micros >= 1000000: + raise ParseException(f'Bad micros (out of range): {ctx.getText()}') + self.context['micros'] = micros + + try: + tz = ctx.tzExpr().getText() + self.context['tz'] = DateParser._parse_tz(tz) + except Exception: + pass + + +@bootstrap.initialize +def main() -> None: + parser = DateParser() + for line in sys.stdin: + line = line.strip() + line = re.sub(r"#.*$", "", line) + if re.match(r"^ *$", line) is not None: + continue + try: + dt = parser.parse(line) + except Exception as e: + logger.exception(e) + print("Unrecognized.") + else: + assert dt is not None + print(dt.strftime('%A %Y/%m/%d %H:%M:%S.%f %Z(%z)')) + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/src/pyutils/datetimez/datetime_utils.py b/src/pyutils/datetimez/datetime_utils.py new file mode 100644 index 0000000..6026d9a --- /dev/null +++ b/src/pyutils/datetimez/datetime_utils.py @@ -0,0 +1,956 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Utilities related to dates, times, and datetimes.""" + +import datetime +import enum +import logging +import re +from typing import Any, NewType, Optional, Tuple + +import holidays # type: ignore +import pytz + +from pyutils.datetimez import constants + +logger = logging.getLogger(__name__) + + +def is_timezone_aware(dt: datetime.datetime) -> bool: + """Returns true if the datetime argument is timezone aware or + False if not. + + See: https://docs.python.org/3/library/datetime.html + #determining-if-an-object-is-aware-or-naive + + Args: + dt: The datetime object to check + + >>> is_timezone_aware(datetime.datetime.now()) + False + + >>> is_timezone_aware(now_pacific()) + True + + """ + return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None + + +def is_timezone_naive(dt: datetime.datetime) -> bool: + """Inverse of is_timezone_aware -- returns true if the dt argument + is timezone naive. + + See: https://docs.python.org/3/library/datetime.html + #determining-if-an-object-is-aware-or-naive + + Args: + dt: The datetime object to check + + >>> is_timezone_naive(datetime.datetime.now()) + True + + >>> is_timezone_naive(now_pacific()) + False + + """ + return not is_timezone_aware(dt) + + +def strip_timezone(dt: datetime.datetime) -> datetime.datetime: + """Remove the timezone from a datetime. + + .. warning:: + + This does not change the hours, minutes, seconds, + months, days, years, etc... Thus the instant to which this + timestamp refers will change. Silently ignores datetimes + which are already timezone naive. + + >>> now = now_pacific() + >>> now.tzinfo == None + False + + >>> dt = strip_timezone(now) + >>> dt == now + False + + >>> dt.tzinfo == None + True + + >>> dt.hour == now.hour + True + + """ + if is_timezone_naive(dt): + return dt + return replace_timezone(dt, None) + + +def add_timezone(dt: datetime.datetime, tz: datetime.tzinfo) -> datetime.datetime: + """ + Adds a timezone to a timezone naive datetime. This does not + change the instant to which the timestamp refers. See also: + replace_timezone. 
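+
+ (add_timezone attaches tzinfo to a naive datetime without touching
+ the wall-clock fields; replace_timezone overwrites the tzinfo on an
+ already-aware datetime and therefore changes the instant it refers
+ to.)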
+
+ >>> now = datetime.datetime.now()
+ >>> is_timezone_aware(now)
+ False
+
+ >>> now_pacific = add_timezone(now, pytz.timezone('US/Pacific'))
+ >>> is_timezone_aware(now_pacific)
+ True
+
+ >>> now.hour == now_pacific.hour
+ True
+ >>> now.minute == now_pacific.minute
+ True
+
+ """
+
+ # This doesn't work on an already-aware dt; attaching tz requires a
+ # timezone naive dt. Two options here:
+ # 1. Use strip_timezone and try again.
+ # 2. Replace the timezone on your dt object via replace_timezone.
+ # Be aware that this changes the instant to which the dt refers
+ # and, further, can introduce weirdness like UTC offsets that
+ # are not an even multiple of an hour, etc...
+ if is_timezone_aware(dt):
+ if dt.tzinfo == tz:
+ return dt
+ raise Exception(
+ f'{dt} is already timezone aware; use replace_timezone or translate_timezone '
+ + 'depending on the semantics you want. See the pydocs / code.'
+ )
+ return dt.replace(tzinfo=tz)
+
+ def replace_timezone(
+ dt: datetime.datetime, tz: Optional[datetime.tzinfo]
+ ) -> datetime.datetime:
+ """Replaces the timezone on a timezone aware datetime object directly
+ (leaving the year, month, day, hour, minute, second, micro,
+ etc... alone).
+
+ Works with timezone aware and timezone naive dts but for the
+ latter it is probably better to use add_timezone or just create it
+ with a tz parameter. Using this can have weird side effects like
+ UTC offsets that are not an even multiple of an hour, etc...
+
+ .. warning::
+
+ This changes the instant to which this dt refers.
+
+ >>> from pytz import UTC
+ >>> d = now_pacific()
+ >>> d.tzinfo.tzname(d)[0] # Note: could be PST or PDT
+ 'P'
+ >>> h = d.hour
+ >>> o = replace_timezone(d, UTC)
+ >>> o.tzinfo.tzname(o)
+ 'UTC'
+ >>> o.hour == h
+ True
+
+ """
+ if is_timezone_aware(dt):
+ logger.warning(
+ '%s already has a timezone; clobbering it anyway.\n Be aware that this operation changes the instant to which the object refers.',
+ dt,
+ )
+ return datetime.datetime(
+ dt.year,
+ dt.month,
+ dt.day,
+ dt.hour,
+ dt.minute,
+ dt.second,
+ dt.microsecond,
+ tzinfo=tz,
+ )
+ else:
+ if tz:
+ return add_timezone(dt, tz)
+ else:
+ return dt
+
+ def replace_time_timezone(t: datetime.time, tz: datetime.tzinfo) -> datetime.time:
+ """Replaces the timezone on a datetime.time directly without performing
+ any translation.
+
+ .. warning::
+
+ Note that, as above, this will change the instant to
+ which the time refers.
+
+ >>> t = datetime.time(8, 15, 12, 0, pytz.UTC)
+ >>> t.tzname()
+ 'UTC'
+
+ >>> t = replace_time_timezone(t, pytz.timezone('US/Pacific'))
+ >>> t.tzname()
+ 'US/Pacific'
+ """
+ return t.replace(tzinfo=tz)
+
+ def translate_timezone(dt: datetime.datetime, tz: datetime.tzinfo) -> datetime.datetime:
+ """
+ Translates dt into a different timezone by adjusting the year, month,
+ day, hour, minute, second, micro, etc... appropriately. The returned
+ dt is the same instant in another timezone.
+
+ >>> import pytz
+ >>> d = now_pacific()
+ >>> d.tzinfo.tzname(d)[0] # Note: could be PST or PDT
+ 'P'
+ >>> h = d.hour
+ >>> o = translate_timezone(d, pytz.timezone('US/Eastern'))
+ >>> o.tzinfo.tzname(o)[0] # Again, could be EST or EDT
+ 'E'
+ >>> o.hour == h
+ False
+ >>> expected = h + 3 # Three hours later in E?T than P?T
+ >>> expected = expected % 24 # Handle edge case
+ >>> expected == o.hour
+ True
+ """
+ return dt.replace().astimezone(tz=tz)
+
+ def now() -> datetime.datetime:
+ """
+ What time is it? Result is a timezone naive datetime.
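+
+ >>> is_timezone_naive(now())
+ True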
+ """ + return datetime.datetime.now() + + +def now_pacific() -> datetime.datetime: + """ + What time is it? Result in US/Pacific time (PST/PDT) + """ + return datetime.datetime.now(pytz.timezone("US/Pacific")) + + +def date_to_datetime(date: datetime.date) -> datetime.datetime: + """ + Given a date, return a datetime with hour/min/sec zero (midnight) + + >>> import datetime + >>> date_to_datetime(datetime.date(2021, 12, 25)) + datetime.datetime(2021, 12, 25, 0, 0) + + """ + return datetime.datetime(date.year, date.month, date.day, 0, 0, 0, 0) + + +def time_to_datetime_today(time: datetime.time) -> datetime.datetime: + """ + Given a time, returns that time as a datetime with a date component + set based on the current date. If the time passed is timezone aware, + the resulting datetime will also be (and will use the same tzinfo). + If the time is timezone naive, the datetime returned will be too. + + >>> t = datetime.time(13, 14, 0) + >>> d = now_pacific().date() + >>> dt = time_to_datetime_today(t) + >>> dt.date() == d + True + + >>> dt.time() == t + True + + >>> dt.tzinfo == t.tzinfo + True + + >>> dt.tzinfo == None + True + + >>> t = datetime.time(8, 15, 12, 0, pytz.UTC) + >>> t.tzinfo == None + False + + >>> dt = time_to_datetime_today(t) + >>> dt.tzinfo == None + False + + """ + tz = time.tzinfo + return datetime.datetime.combine(now_pacific(), time, tz) + + +def date_and_time_to_datetime( + date: datetime.date, time: datetime.time +) -> datetime.datetime: + """ + Given a date and time, merge them and return a datetime. + + >>> import datetime + >>> d = datetime.date(2021, 12, 25) + >>> t = datetime.time(12, 30, 0, 0) + >>> date_and_time_to_datetime(d, t) + datetime.datetime(2021, 12, 25, 12, 30) + + """ + return datetime.datetime( + date.year, + date.month, + date.day, + time.hour, + time.minute, + time.second, + time.microsecond, + ) + + +def datetime_to_date_and_time( + dt: datetime.datetime, +) -> Tuple[datetime.date, datetime.time]: + """Return the component date and time objects of a datetime in a + Tuple given a datetime. + + >>> import datetime + >>> dt = datetime.datetime(2021, 12, 25, 12, 30) + >>> (d, t) = datetime_to_date_and_time(dt) + >>> d + datetime.date(2021, 12, 25) + >>> t + datetime.time(12, 30) + + """ + return (dt.date(), dt.timetz()) + + +def datetime_to_date(dt: datetime.datetime) -> datetime.date: + """Return just the date part of a datetime. + + >>> import datetime + >>> dt = datetime.datetime(2021, 12, 25, 12, 30) + >>> datetime_to_date(dt) + datetime.date(2021, 12, 25) + + """ + return datetime_to_date_and_time(dt)[0] + + +def datetime_to_time(dt: datetime.datetime) -> datetime.time: + """Return just the time part of a datetime. 
+ + >>> import datetime + >>> dt = datetime.datetime(2021, 12, 25, 12, 30) + >>> datetime_to_time(dt) + datetime.time(12, 30) + + """ + return datetime_to_date_and_time(dt)[1] + + +class TimeUnit(enum.IntEnum): + """An enum to represent units with which we can compute deltas.""" + + MONDAYS = 0 + TUESDAYS = 1 + WEDNESDAYS = 2 + THURSDAYS = 3 + FRIDAYS = 4 + SATURDAYS = 5 + SUNDAYS = 6 + SECONDS = 7 + MINUTES = 8 + HOURS = 9 + DAYS = 10 + WORKDAYS = 11 + WEEKS = 12 + MONTHS = 13 + YEARS = 14 + + @classmethod + def is_valid(cls, value: Any): + if isinstance(value, int): + return cls(value) is not None + elif isinstance(value, TimeUnit): + return cls(value.value) is not None + elif isinstance(value, str): + return cls.__members__[value] is not None + else: + print(type(value)) + return False + + +def n_timeunits_from_base( + count: int, unit: TimeUnit, base: datetime.datetime +) -> datetime.datetime: + """Return a datetime that is N units before/after a base datetime. + e.g. 3 Wednesdays from base datetime, 2 weeks from base date, 10 + years before base datetime, 13 minutes after base datetime, etc... + Note: to indicate before/after the base date, use a positive or + negative count. + + >>> base = string_to_datetime("2021/09/10 11:24:51AM-0700")[0] + + The next (1) Monday from the base datetime: + >>> n_timeunits_from_base(+1, TimeUnit.MONDAYS, base) + datetime.datetime(2021, 9, 13, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Ten (10) years after the base datetime: + >>> n_timeunits_from_base(10, TimeUnit.YEARS, base) + datetime.datetime(2031, 9, 10, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) working days (M..F, not counting holidays) after base datetime: + >>> n_timeunits_from_base(50, TimeUnit.WORKDAYS, base) + datetime.datetime(2021, 11, 23, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) days (including weekends and holidays) after base datetime: + >>> n_timeunits_from_base(50, TimeUnit.DAYS, base) + datetime.datetime(2021, 10, 30, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) months before (note negative count) base datetime: + >>> n_timeunits_from_base(-50, TimeUnit.MONTHS, base) + datetime.datetime(2017, 7, 10, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) hours after base datetime: + >>> n_timeunits_from_base(50, TimeUnit.HOURS, base) + datetime.datetime(2021, 9, 12, 13, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) minutes before base datetime: + >>> n_timeunits_from_base(-50, TimeUnit.MINUTES, base) + datetime.datetime(2021, 9, 10, 10, 34, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Fifty (50) seconds from base datetime: + >>> n_timeunits_from_base(50, TimeUnit.SECONDS, base) + datetime.datetime(2021, 9, 10, 11, 25, 41, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Next month corner case -- it will try to make Feb 31, 2022 then count + backwards. 
+ >>> base = string_to_datetime("2022/01/31 11:24:51AM-0700")[0] + >>> n_timeunits_from_base(1, TimeUnit.MONTHS, base) + datetime.datetime(2022, 2, 28, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + Last month with the same corner case + >>> base = string_to_datetime("2022/03/31 11:24:51AM-0700")[0] + >>> n_timeunits_from_base(-1, TimeUnit.MONTHS, base) + datetime.datetime(2022, 2, 28, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + """ + assert TimeUnit.is_valid(unit) + if count == 0: + return base + + # N days from base + if unit == TimeUnit.DAYS: + timedelta = datetime.timedelta(days=count) + return base + timedelta + + # N hours from base + elif unit == TimeUnit.HOURS: + timedelta = datetime.timedelta(hours=count) + return base + timedelta + + # N minutes from base + elif unit == TimeUnit.MINUTES: + timedelta = datetime.timedelta(minutes=count) + return base + timedelta + + # N seconds from base + elif unit == TimeUnit.SECONDS: + timedelta = datetime.timedelta(seconds=count) + return base + timedelta + + # N workdays from base + elif unit == TimeUnit.WORKDAYS: + if count < 0: + count = abs(count) + timedelta = datetime.timedelta(days=-1) + else: + timedelta = datetime.timedelta(days=1) + skips = holidays.US(years=base.year).keys() + while count > 0: + old_year = base.year + base += timedelta + if base.year != old_year: + skips = holidays.US(years=base.year).keys() + if ( + base.weekday() < 5 + and datetime.date(base.year, base.month, base.day) not in skips + ): + count -= 1 + return base + + # N weeks from base + elif unit == TimeUnit.WEEKS: + timedelta = datetime.timedelta(weeks=count) + base = base + timedelta + return base + + # N months from base + elif unit == TimeUnit.MONTHS: + month_term = count % 12 + year_term = count // 12 + new_month = base.month + month_term + if new_month > 12: + new_month %= 12 + year_term += 1 + new_year = base.year + year_term + day = base.day + while True: + try: + ret = datetime.datetime( + new_year, + new_month, + day, + base.hour, + base.minute, + base.second, + base.microsecond, + base.tzinfo, + ) + break + except ValueError: + day -= 1 + return ret + + # N years from base + elif unit == TimeUnit.YEARS: + new_year = base.year + count + return datetime.datetime( + new_year, + base.month, + base.day, + base.hour, + base.minute, + base.second, + base.microsecond, + base.tzinfo, + ) + + if unit not in set( + [ + TimeUnit.MONDAYS, + TimeUnit.TUESDAYS, + TimeUnit.WEDNESDAYS, + TimeUnit.THURSDAYS, + TimeUnit.FRIDAYS, + TimeUnit.SATURDAYS, + TimeUnit.SUNDAYS, + ] + ): + raise ValueError(unit) + + # N weekdays from base (e.g. 4 wednesdays from today) + direction = 1 if count > 0 else -1 + count = abs(count) + timedelta = datetime.timedelta(days=direction) + start = base + while True: + dow = base.weekday() + if dow == unit.value and start != base: + count -= 1 + if count == 0: + return base + base = base + timedelta + + +def get_format_string( + *, + date_time_separator=" ", + include_timezone=True, + include_dayname=False, + use_month_abbrevs=False, + include_seconds=True, + include_fractional=False, + twelve_hour=True, +) -> str: + """ + Helper to return a format string without looking up the documentation + for strftime. 
+ + >>> get_format_string() + '%Y/%m/%d %I:%M:%S%p%z' + + >>> get_format_string(date_time_separator='@') + '%Y/%m/%d@%I:%M:%S%p%z' + + >>> get_format_string(include_dayname=True) + '%a/%Y/%m/%d %I:%M:%S%p%z' + + >>> get_format_string(include_dayname=True, twelve_hour=False) + '%a/%Y/%m/%d %H:%M:%S%z' + + """ + fstring = "" + if include_dayname: + fstring += "%a/" + + if use_month_abbrevs: + fstring = f"{fstring}%Y/%b/%d{date_time_separator}" + else: + fstring = f"{fstring}%Y/%m/%d{date_time_separator}" + if twelve_hour: + fstring += "%I:%M" + if include_seconds: + fstring += ":%S" + fstring += "%p" + else: + fstring += "%H:%M" + if include_seconds: + fstring += ":%S" + if include_fractional: + fstring += ".%f" + if include_timezone: + fstring += "%z" + return fstring + + +def datetime_to_string( + dt: datetime.datetime, + *, + date_time_separator=" ", + include_timezone=True, + include_dayname=False, + use_month_abbrevs=False, + include_seconds=True, + include_fractional=False, + twelve_hour=True, +) -> str: + """ + A nice way to convert a datetime into a string; arguably better than + just printing it and relying on it __repr__(). + + >>> d = string_to_datetime( + ... "2021/09/10 11:24:51AM-0700", + ... )[0] + >>> d + datetime.datetime(2021, 9, 10, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + >>> datetime_to_string(d) + '2021/09/10 11:24:51AM-0700' + >>> datetime_to_string(d, include_dayname=True, include_seconds=False) + 'Fri/2021/09/10 11:24AM-0700' + + """ + fstring = get_format_string( + date_time_separator=date_time_separator, + include_timezone=include_timezone, + include_dayname=include_dayname, + use_month_abbrevs=use_month_abbrevs, + include_seconds=include_seconds, + include_fractional=include_fractional, + twelve_hour=twelve_hour, + ) + return dt.strftime(fstring).strip() + + +def string_to_datetime( + txt: str, + *, + date_time_separator=" ", + include_timezone=True, + include_dayname=False, + use_month_abbrevs=False, + include_seconds=True, + include_fractional=False, + twelve_hour=True, +) -> Tuple[datetime.datetime, str]: + """A nice way to convert a string into a datetime. Returns both the + datetime and the format string used to parse it. Also consider + dateparse.dateparse_utils for a full parser alternative. + + >>> d = string_to_datetime( + ... "2021/09/10 11:24:51AM-0700", + ... ) + >>> d + (datetime.datetime(2021, 9, 10, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))), '%Y/%m/%d %I:%M:%S%p%z') + + """ + fstring = get_format_string( + date_time_separator=date_time_separator, + include_timezone=include_timezone, + include_dayname=include_dayname, + use_month_abbrevs=use_month_abbrevs, + include_seconds=include_seconds, + include_fractional=include_fractional, + twelve_hour=twelve_hour, + ) + return (datetime.datetime.strptime(txt, fstring), fstring) + + +def timestamp() -> str: + """Return a timestamp for right now in Pacific timezone.""" + ts = datetime.datetime.now(tz=pytz.timezone("US/Pacific")) + return datetime_to_string(ts, include_timezone=True) + + +def time_to_string( + dt: datetime.datetime, + *, + include_seconds=True, + include_fractional=False, + include_timezone=False, + twelve_hour=True, +) -> str: + """A nice way to convert a datetime into a time (only) string. + This ignores the date part of the datetime. + + >>> d = string_to_datetime( + ... "2021/09/10 11:24:51AM-0700", + ... 
)[0] + >>> d + datetime.datetime(2021, 9, 10, 11, 24, 51, tzinfo=datetime.timezone(datetime.timedelta(days=-1, seconds=61200))) + + >>> time_to_string(d) + '11:24:51AM' + + >>> time_to_string(d, include_seconds=False) + '11:24AM' + + >>> time_to_string(d, include_seconds=False, include_timezone=True) + '11:24AM-0700' + + """ + fstring = "" + if twelve_hour: + fstring += "%l:%M" + if include_seconds: + fstring += ":%S" + fstring += "%p" + else: + fstring += "%H:%M" + if include_seconds: + fstring += ":%S" + if include_fractional: + fstring += ".%f" + if include_timezone: + fstring += "%z" + return dt.strftime(fstring).strip() + + +def seconds_to_timedelta(seconds: int) -> datetime.timedelta: + """Convert a delta in seconds into a timedelta.""" + return datetime.timedelta(seconds=seconds) + + +MinuteOfDay = NewType("MinuteOfDay", int) + + +def minute_number(hour: int, minute: int) -> MinuteOfDay: + """ + Convert hour:minute into minute number from start of day. + + >>> minute_number(0, 0) + 0 + + >>> minute_number(9, 15) + 555 + + >>> minute_number(23, 59) + 1439 + + """ + return MinuteOfDay(hour * 60 + minute) + + +def datetime_to_minute_number(dt: datetime.datetime) -> MinuteOfDay: + """ + Convert a datetime into a minute number (of the day). Note that + this ignores the date part of the datetime and only uses the time + part. + + >>> d = string_to_datetime( + ... "2021/09/10 11:24:51AM-0700", + ... )[0] + + >>> datetime_to_minute_number(d) + 684 + + """ + return minute_number(dt.hour, dt.minute) + + +def time_to_minute_number(t: datetime.time) -> MinuteOfDay: + """ + Convert a datetime.time into a minute number. + + >>> t = datetime.time(5, 15) + >>> time_to_minute_number(t) + 315 + + """ + return minute_number(t.hour, t.minute) + + +def minute_number_to_time_string(minute_num: MinuteOfDay) -> str: + """ + Convert minute number from start of day into hour:minute am/pm + string. + + >>> minute_number_to_time_string(315) + ' 5:15a' + + >>> minute_number_to_time_string(684) + '11:24a' + + """ + hour = minute_num // 60 + minute = minute_num % 60 + ampm = "a" + if hour > 12: + hour -= 12 + ampm = "p" + if hour == 12: + ampm = "p" + if hour == 0: + hour = 12 + return f"{hour:2}:{minute:02}{ampm}" + + +def parse_duration(duration: str) -> int: + """ + Parse a duration in string form into a delta seconds. + + >>> parse_duration('15 days, 2 hours') + 1303200 + + >>> parse_duration('15d 2h') + 1303200 + + >>> parse_duration('100s') + 100 + + >>> parse_duration('3min 2sec') + 182 + + """ + if duration.isdigit(): + return int(duration) + seconds = 0 + m = re.search(r'(\d+) *d[ays]*', duration) + if m is not None: + seconds += int(m.group(1)) * 60 * 60 * 24 + m = re.search(r'(\d+) *h[ours]*', duration) + if m is not None: + seconds += int(m.group(1)) * 60 * 60 + m = re.search(r'(\d+) *m[inutes]*', duration) + if m is not None: + seconds += int(m.group(1)) * 60 + m = re.search(r'(\d+) *s[econds]*', duration) + if m is not None: + seconds += int(m.group(1)) + return seconds + + +def describe_duration(seconds: int, *, include_seconds=False) -> str: + """ + Describe a duration represented as a count of seconds nicely. 
+ + >>> describe_duration(182) + '3 minutes' + + >>> describe_duration(182, include_seconds=True) + '3 minutes, and 2 seconds' + + >>> describe_duration(100, include_seconds=True) + '1 minute, and 40 seconds' + + describe_duration(1303200) + '15 days, 2 hours' + + """ + days = divmod(seconds, constants.SECONDS_PER_DAY) + hours = divmod(days[1], constants.SECONDS_PER_HOUR) + minutes = divmod(hours[1], constants.SECONDS_PER_MINUTE) + + descr = "" + if days[0] > 1: + descr = f"{int(days[0])} days, " + elif days[0] == 1: + descr = "1 day, " + + if hours[0] > 1: + descr = descr + f"{int(hours[0])} hours, " + elif hours[0] == 1: + descr = descr + "1 hour, " + + if not include_seconds and len(descr) > 0: + descr = descr + "and " + + if minutes[0] == 1: + descr = descr + "1 minute" + else: + descr = descr + f"{int(minutes[0])} minutes" + + if include_seconds: + descr = descr + ', ' + if len(descr) > 0: + descr = descr + 'and ' + s = minutes[1] + if s == 1: + descr = descr + '1 second' + else: + descr = descr + f'{s} seconds' + return descr + + +def describe_timedelta(delta: datetime.timedelta) -> str: + """ + Describe a duration represented by a timedelta object. + + >>> d = datetime.timedelta(1, 600) + >>> describe_timedelta(d) + '1 day, and 10 minutes' + + """ + return describe_duration(int(delta.total_seconds())) # Note: drops milliseconds + + +def describe_duration_briefly(seconds: int, *, include_seconds=False) -> str: + """ + Describe a duration briefly. + + >>> describe_duration_briefly(182) + '3m' + + >>> describe_duration_briefly(182, include_seconds=True) + '3m 2s' + + >>> describe_duration_briefly(100, include_seconds=True) + '1m 40s' + + describe_duration_briefly(1303200) + '15d 2h' + + """ + days = divmod(seconds, constants.SECONDS_PER_DAY) + hours = divmod(days[1], constants.SECONDS_PER_HOUR) + minutes = divmod(hours[1], constants.SECONDS_PER_MINUTE) + + descr = '' + if days[0] > 0: + descr = f'{int(days[0])}d ' + if hours[0] > 0: + descr = descr + f'{int(hours[0])}h ' + if minutes[0] > 0 or (len(descr) == 0 and not include_seconds): + descr = descr + f'{int(minutes[0])}m ' + if minutes[1] > 0 and include_seconds: + descr = descr + f'{int(minutes[1])}s' + return descr.strip() + + +def describe_timedelta_briefly(delta: datetime.timedelta) -> str: + """ + Describe a duration represented by a timedelta object. + + >>> d = datetime.timedelta(1, 600) + >>> describe_timedelta_briefly(d) + '1d 10m' + + """ + return describe_duration_briefly( + int(delta.total_seconds()) + ) # Note: drops milliseconds + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/decorator_utils.py b/src/pyutils/decorator_utils.py new file mode 100644 index 0000000..e8d2249 --- /dev/null +++ b/src/pyutils/decorator_utils.py @@ -0,0 +1,839 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch +# Portions (marked) below retain the original author's copyright. + +"""Useful(?) decorators.""" + +import enum +import functools +import inspect +import logging +import math +import multiprocessing +import random +import signal +import sys +import threading +import time +import traceback +import warnings +from typing import Any, Callable, List, Optional + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +logger = logging.getLogger(__name__) + + +def timed(func: Callable) -> Callable: + """Print the runtime of the decorated function. + + >>> @timed + ... def foo(): + ... import time + ... 
time.sleep(0.01) + + >>> foo() # doctest: +ELLIPSIS + Finished foo in ... + + """ + + @functools.wraps(func) + def wrapper_timer(*args, **kwargs): + start_time = time.perf_counter() + value = func(*args, **kwargs) + end_time = time.perf_counter() + run_time = end_time - start_time + msg = f"Finished {func.__qualname__} in {run_time:.4f}s" + print(msg) + logger.info(msg) + return value + + return wrapper_timer + + +def invocation_logged(func: Callable) -> Callable: + """Log the call of a function on stdout and the info log. + + >>> @invocation_logged + ... def foo(): + ... print('Hello, world.') + + >>> foo() + Entered foo + Hello, world. + Exited foo + + """ + + @functools.wraps(func) + def wrapper_invocation_logged(*args, **kwargs): + msg = f"Entered {func.__qualname__}" + print(msg) + logger.info(msg) + ret = func(*args, **kwargs) + msg = f"Exited {func.__qualname__}" + print(msg) + logger.info(msg) + return ret + + return wrapper_invocation_logged + + +def rate_limited(n_calls: int, *, per_period_in_seconds: float = 1.0) -> Callable: + """Limit invocation of a wrapped function to n calls per time period. + Thread safe. In testing this was relatively fair with multiple + threads using it though that hasn't been measured in detail. + + >>> import time + >>> from pyutils import decorator_utils + >>> from pyutils.parallelize import thread_utils + + >>> calls = 0 + + >>> @decorator_utils.rate_limited(10, per_period_in_seconds=1.0) + ... def limited(x: int): + ... global calls + ... calls += 1 + + >>> @thread_utils.background_thread + ... def a(stop): + ... for _ in range(3): + ... limited(_) + + >>> @thread_utils.background_thread + ... def b(stop): + ... for _ in range(3): + ... limited(_) + + >>> start = time.time() + >>> (t1, e1) = a() + >>> (t2, e2) = b() + >>> t1.join() + >>> t2.join() + >>> end = time.time() + >>> dur = end - start + >>> dur > 0.5 + True + + >>> calls + 6 + + """ + min_interval_seconds = per_period_in_seconds / float(n_calls) + + def wrapper_rate_limited(func: Callable) -> Callable: + cv = threading.Condition() + last_invocation_timestamp = [0.0] + + def may_proceed() -> float: + now = time.time() + last_invocation = last_invocation_timestamp[0] + if last_invocation != 0.0: + elapsed_since_last = now - last_invocation + wait_time = min_interval_seconds - elapsed_since_last + else: + wait_time = 0.0 + logger.debug('@%.4f> wait_time = %.4f', time.time(), wait_time) + return wait_time + + def wrapper_wrapper_rate_limited(*args, **kargs) -> Any: + with cv: + while True: + if cv.wait_for( + lambda: may_proceed() <= 0.0, + timeout=may_proceed(), + ): + break + with cv: + logger.debug('@%.4f> calling it...', time.time()) + ret = func(*args, **kargs) + last_invocation_timestamp[0] = time.time() + logger.debug( + '@%.4f> Last invocation <- %.4f', + time.time(), + last_invocation_timestamp[0], + ) + cv.notify() + return ret + + return wrapper_wrapper_rate_limited + + return wrapper_rate_limited + + +def debug_args(func: Callable) -> Callable: + """Print the function signature and return value at each call. + + >>> @debug_args + ... def foo(a, b, c): + ... print(a) + ... print(b) + ... print(c) + ... 
return (a + b, c)
+
+    >>> foo(1, 2.0, "test")
+    Calling foo(1:<class 'int'>, 2.0:<class 'float'>, 'test':<class 'str'>)
+    1
+    2.0
+    test
+    foo returned (3.0, 'test'):<class 'tuple'>
+    (3.0, 'test')
+    """
+
+    @functools.wraps(func)
+    def wrapper_debug_args(*args, **kwargs):
+        args_repr = [f"{repr(a)}:{type(a)}" for a in args]
+        kwargs_repr = [f"{k}={v!r}:{type(v)}" for k, v in kwargs.items()]
+        signature = ", ".join(args_repr + kwargs_repr)
+        msg = f"Calling {func.__qualname__}({signature})"
+        print(msg)
+        logger.info(msg)
+        value = func(*args, **kwargs)
+        msg = f"{func.__qualname__} returned {value!r}:{type(value)}"
+        print(msg)
+        logger.info(msg)
+        return value
+
+    return wrapper_debug_args
+
+
+def debug_count_calls(func: Callable) -> Callable:
+    """Count function invocations and print a message before every call.
+
+    >>> @debug_count_calls
+    ... def factorial(x):
+    ...     if x == 1:
+    ...         return 1
+    ...     return x * factorial(x - 1)
+
+    >>> factorial(5)
+    Call #1 of 'factorial'
+    Call #2 of 'factorial'
+    Call #3 of 'factorial'
+    Call #4 of 'factorial'
+    Call #5 of 'factorial'
+    120
+
+    """
+
+    @functools.wraps(func)
+    def wrapper_debug_count_calls(*args, **kwargs):
+        wrapper_debug_count_calls.num_calls += 1
+        msg = f"Call #{wrapper_debug_count_calls.num_calls} of {func.__name__!r}"
+        print(msg)
+        logger.info(msg)
+        return func(*args, **kwargs)
+
+    wrapper_debug_count_calls.num_calls = 0  # type: ignore
+    return wrapper_debug_count_calls
+
+
+class DelayWhen(enum.IntEnum):
+    """When should we delay: before or after calling the function (or
+    both)?
+
+    """
+
+    BEFORE_CALL = 1
+    AFTER_CALL = 2
+    BEFORE_AND_AFTER = 3
+
+
+def delay(
+    _func: Optional[Callable] = None,
+    *,
+    seconds: float = 1.0,
+    when: DelayWhen = DelayWhen.BEFORE_CALL,
+) -> Callable:
+    """Slow down a function by inserting a delay before and/or after its
+    invocation.
+
+    >>> import time
+
+    >>> @delay(seconds=1.0)
+    ... def foo():
+    ...     pass
+
+    >>> start = time.time()
+    >>> foo()
+    >>> dur = time.time() - start
+    >>> dur >= 1.0
+    True
+
+    """
+
+    def decorator_delay(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper_delay(*args, **kwargs):
+            if when & DelayWhen.BEFORE_CALL:
+                logger.debug("@delay for %fs BEFORE_CALL to %s", seconds, func.__name__)
+                time.sleep(seconds)
+            retval = func(*args, **kwargs)
+            if when & DelayWhen.AFTER_CALL:
+                logger.debug("@delay for %fs AFTER_CALL to %s", seconds, func.__name__)
+                time.sleep(seconds)
+            return retval
+
+        return wrapper_delay
+
+    if _func is None:
+        return decorator_delay
+    else:
+        return decorator_delay(_func)
+
+
+class _SingletonWrapper:
+    """
+    A singleton wrapper class.  One instance of it is created for each
+    decorated class.
+
+    """
+
+    def __init__(self, cls):
+        self.__wrapped__ = cls
+        self._instance = None
+
+    def __call__(self, *args, **kwargs):
+        """Returns the single instance of the decorated class"""
+        logger.debug(
+            '@singleton returning global instance of %s', self.__wrapped__.__name__
+        )
+        if self._instance is None:
+            self._instance = self.__wrapped__(*args, **kwargs)
+        return self._instance
+
+
+def singleton(cls):
+    """
+    A singleton decorator.  Returns a wrapper object; a call on that
+    object returns a single instance object of the decorated class.  Use
+    the __wrapped__ attribute to access the decorated class directly in
+    unit tests.
+
+    >>> @singleton
+    ... class foo(object):
+    ...     pass
+
+    >>> a = foo()
+    >>> b = foo()
+    >>> a is b
+    True
+
+    >>> id(a) == id(b)
+    True
+
+    """
+    return _SingletonWrapper(cls)
+
+
+def memoized(func: Callable) -> Callable:
+    """Keep a cache of previous function call results.
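+
+    All arguments (positional and keyword) must be hashable, since they
+    are used to form the cache key.  Note that keyword order matters:
+    f(a=1, b=2) and f(b=2, a=1) are cached separately.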
+ + The cache here is a dict with a key based on the arguments to the + call. Consider also: functools.cache for a more advanced + implementation. See: + https://docs.python.org/3/library/functools.html#functools.cache + + >>> import time + + >>> @memoized + ... def expensive(arg) -> int: + ... # Simulate something slow to compute or lookup + ... time.sleep(1.0) + ... return arg * arg + + >>> start = time.time() + >>> expensive(5) # Takes about 1 sec + 25 + + >>> expensive(3) # Also takes about 1 sec + 9 + + >>> expensive(5) # Pulls from cache, fast + 25 + + >>> expensive(3) # Pulls from cache again, fast + 9 + + >>> dur = time.time() - start + >>> dur < 3.0 + True + + """ + + @functools.wraps(func) + def wrapper_memoized(*args, **kwargs): + cache_key = args + tuple(kwargs.items()) + if cache_key not in wrapper_memoized.cache: + value = func(*args, **kwargs) + logger.debug('Memoizing %s => %s for %s', cache_key, value, func.__name__) + wrapper_memoized.cache[cache_key] = value + else: + logger.debug('Returning memoized value for %s', {func.__name__}) + return wrapper_memoized.cache[cache_key] + + wrapper_memoized.cache = {} # type: ignore + return wrapper_memoized + + +def retry_predicate( + tries: int, + *, + predicate: Callable[..., bool], + delay_sec: float = 3.0, + backoff: float = 2.0, +): + """Retries a function or method up to a certain number of times with a + prescribed initial delay period and backoff rate (multiplier). + + Args: + tries: the maximum number of attempts to run the function + delay_sec: sets the initial delay period in seconds + backoff: a multiplier (must be >=1.0) used to modify the + delay at each subsequent invocation + predicate: a Callable that will be passed the retval of + the decorated function and must return True to indicate + that we should stop calling or False to indicate a retry + is necessary + """ + + if backoff < 1.0: + msg = f"backoff must be greater than or equal to 1, got {backoff}" + logger.critical(msg) + raise ValueError(msg) + + tries = math.floor(tries) + if tries < 0: + msg = f"tries must be 0 or greater, got {tries}" + logger.critical(msg) + raise ValueError(msg) + + if delay_sec <= 0: + msg = f"delay_sec must be greater than 0, got {delay_sec}" + logger.critical(msg) + raise ValueError(msg) + + def deco_retry(f): + @functools.wraps(f) + def f_retry(*args, **kwargs): + mtries, mdelay = tries, delay_sec # make mutable + logger.debug('deco_retry: will make up to %d attempts...', mtries) + retval = f(*args, **kwargs) + while mtries > 0: + if predicate(retval) is True: + logger.debug('Predicate succeeded, deco_retry is done.') + return retval + logger.debug("Predicate failed, sleeping and retrying.") + mtries -= 1 + time.sleep(mdelay) + mdelay *= backoff + retval = f(*args, **kwargs) + return retval + + return f_retry + + return deco_retry + + +def retry_if_false(tries: int, *, delay_sec=3.0, backoff=2.0): + """A helper for @retry_predicate that retries a decorated + function as long as it keeps returning False. + + >>> import time + + >>> counter = 0 + >>> @retry_if_false(5, delay_sec=1.0, backoff=1.1) + ... def foo(): + ... global counter + ... counter += 1 + ... 
return counter >= 3 + + >>> start = time.time() + >>> foo() # fail, delay 1.0, fail, delay 1.1, succeed + True + + >>> dur = time.time() - start + >>> counter + 3 + >>> dur > 2.0 + True + >>> dur < 2.3 + True + + """ + return retry_predicate( + tries, + predicate=lambda x: x is True, + delay_sec=delay_sec, + backoff=backoff, + ) + + +def retry_if_none(tries: int, *, delay_sec=3.0, backoff=2.0): + """Another helper for @retry_predicate above. Retries up to N + times so long as the wrapped function returns None with a delay + between each retry and a backoff that can increase the delay. + """ + + return retry_predicate( + tries, + predicate=lambda x: x is not None, + delay_sec=delay_sec, + backoff=backoff, + ) + + +def deprecated(func): + """This is a decorator which can be used to mark functions + as deprecated. It will result in a warning being emitted + when the function is used. + """ + + @functools.wraps(func) + def wrapper_deprecated(*args, **kwargs): + msg = f"Call to deprecated function {func.__qualname__}" + logger.warning(msg) + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) + print(msg, file=sys.stderr) + return func(*args, **kwargs) + + return wrapper_deprecated + + +def thunkify(func): + """ + Make a function immediately return a function of no args which, + when called, waits for the result, which will start being + processed in another thread. + """ + + @functools.wraps(func) + def lazy_thunked(*args, **kwargs): + wait_event = threading.Event() + + result = [None] + exc: List[Any] = [False, None] + + def worker_func(): + try: + func_result = func(*args, **kwargs) + result[0] = func_result + except Exception: + exc[0] = True + exc[1] = sys.exc_info() # (type, value, traceback) + msg = f"Thunkify has thrown an exception (will be raised on thunk()):\n{traceback.format_exc()}" + logger.warning(msg) + finally: + wait_event.set() + + def thunk(): + wait_event.wait() + if exc[0]: + assert exc[1] + raise exc[1][0](exc[1][1]) + return result[0] + + threading.Thread(target=worker_func).start() + return thunk + + return lazy_thunked + + +############################################################ +# Timeout +############################################################ + +# http://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/ +# Used work of Stephen "Zero" Chappell +# in https://code.google.com/p/verse-quiz/source/browse/trunk/timeout.py + +# Original work is covered by PSF-2.0: + +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing +# and otherwise using this software ("Python") in source or binary +# form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, +# PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display +# publicly, prepare derivative works, distribute, and otherwise use +# Python alone or in any derivative version, provided, however, that +# PSF's License Agreement and PSF's notice of copyright, i.e., +# "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006 Python Software +# Foundation; All Rights Reserved" are retained in Python alone or in +# any derivative version prepared by Licensee. + +# 3. 
In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make the +# derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary +# of the changes made to Python. + +# (N.B. See NOTICE file in the root of this module for a list of +# changes) + +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR +# FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL +# NOT INFRINGE ANY THIRD PARTY RIGHTS. + +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A +# RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +# 6. This License Agreement will automatically terminate upon a +# material breach of its terms and conditions. + +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF +# and Licensee. This License Agreement does not grant permission to +# use PSF trademarks or trade name in a trademark sense to endorse or +# promote products or services of Licensee, or any third party. + +# 8. By copying, installing or otherwise using Python, Licensee agrees +# to be bound by the terms and conditions of this License Agreement. + + +def _raise_exception(exception, error_message: Optional[str]): + if error_message is None: + raise Exception(exception) + else: + raise Exception(error_message) + + +def _target(queue, function, *args, **kwargs): + """Run a function with arguments and return output via a queue. + + This is a helper function for the Process created in _Timeout. It runs + the function with positional arguments and keyword arguments and then + returns the function's output by way of a queue. If an exception gets + raised, it is returned to _Timeout to be raised by the value property. + """ + try: + queue.put((True, function(*args, **kwargs))) + except Exception: + queue.put((False, sys.exc_info()[1])) + + +class _Timeout(object): + """Wrap a function and add a timeout to it. + + Instances of this class are automatically generated by the add_timeout + function defined below. Do not use directly. + """ + + def __init__( + self, + function: Callable, + timeout_exception: Exception, + error_message: str, + seconds: float, + ): + self.__limit = seconds + self.__function = function + self.__timeout_exception = timeout_exception + self.__error_message = error_message + self.__name__ = function.__name__ + self.__doc__ = function.__doc__ + self.__timeout = time.time() + self.__process = multiprocessing.Process() + self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue() + + def __call__(self, *args, **kwargs): + """Execute the embedded function object asynchronously. + + The function given to the constructor is transparently called and + requires that "ready" be intermittently polled. If and when it is + True, the "value" property may then be checked for returned data. 
+ """ + self.__limit = kwargs.pop("timeout", self.__limit) + self.__queue = multiprocessing.Queue(1) + args = (self.__queue, self.__function) + args + self.__process = multiprocessing.Process( + target=_target, args=args, kwargs=kwargs + ) + self.__process.daemon = True + self.__process.start() + if self.__limit is not None: + self.__timeout = self.__limit + time.time() + while not self.ready: + time.sleep(0.1) + return self.value + + def cancel(self): + """Terminate any possible execution of the embedded function.""" + if self.__process.is_alive(): + self.__process.terminate() + _raise_exception(self.__timeout_exception, self.__error_message) + + @property + def ready(self): + """Read-only property indicating status of "value" property.""" + if self.__limit and self.__timeout < time.time(): + self.cancel() + return self.__queue.full() and not self.__queue.empty() + + @property + def value(self): + """Read-only property containing data returned from function.""" + if self.ready is True: + flag, load = self.__queue.get() + if flag: + return load + raise load + return None + + +def timeout( + seconds: float = 1.0, + use_signals: Optional[bool] = None, + timeout_exception=TimeoutError, + error_message="Function call timed out", +): + """Add a timeout parameter to a function and return the function. + + Note: the use_signals parameter is included in order to support + multiprocessing scenarios (signal can only be used from the process' + main thread). When not using signals, timeout granularity will be + rounded to the nearest 0.1s. + + Beware that an @timeout on a function inside a module will be + evaluated at module load time and not when the wrapped function is + invoked. This can lead to problems when relying on the automatic + main thread detection code (use_signals=None, the default) since + the import probably happens on the main thread and the invocation + can happen on a different thread (which can't use signals). + + Raises an exception when/if the timeout is reached. + + It is illegal to pass anything other than a function as the first + parameter. The function is wrapped and returned to the caller. + + >>> @timeout(0.2) + ... def foo(delay: float): + ... time.sleep(delay) + ... return "ok" + + >>> foo(0) + 'ok' + + >>> foo(1.0) + Traceback (most recent call last): + ... 
+ Exception: Function call timed out + + """ + if use_signals is None: + import pyutils.parallelize.thread_utils as tu + + use_signals = tu.is_current_thread_main_thread() + + def decorate(function): + if use_signals: + + def handler(unused_signum, unused_frame): + _raise_exception(timeout_exception, error_message) + + @functools.wraps(function) + def new_function(*args, **kwargs): + new_seconds = kwargs.pop("timeout", seconds) + if new_seconds: + old = signal.signal(signal.SIGALRM, handler) + signal.setitimer(signal.ITIMER_REAL, new_seconds) + + if not seconds: + return function(*args, **kwargs) + + try: + return function(*args, **kwargs) + finally: + if new_seconds: + signal.setitimer(signal.ITIMER_REAL, 0) + signal.signal(signal.SIGALRM, old) + + return new_function + else: + + @functools.wraps(function) + def new_function(*args, **kwargs): + timeout_wrapper = _Timeout( + function, timeout_exception, error_message, seconds + ) + return timeout_wrapper(*args, **kwargs) + + return new_function + + return decorate + + +def synchronized(lock): + """Emulates java's synchronized keyword: given a lock, require that + threads take that lock (or wait) before invoking the wrapped + function and automatically releases the lock afterwards. + """ + + def wrap(f): + @functools.wraps(f) + def _gatekeeper(*args, **kw): + lock.acquire() + try: + return f(*args, **kw) + finally: + lock.release() + + return _gatekeeper + + return wrap + + +def call_with_sample_rate(sample_rate: float) -> Callable: + """Calls the wrapped function probabilistically given a rate between + 0.0 and 1.0 inclusive (0% probability and 100% probability). + """ + + if not 0.0 <= sample_rate <= 1.0: + msg = f"sample_rate must be between [0, 1]. Got {sample_rate}." + logger.critical(msg) + raise ValueError(msg) + + def decorator(f): + @functools.wraps(f) + def _call_with_sample_rate(*args, **kwargs): + if random.uniform(0, 1) < sample_rate: + return f(*args, **kwargs) + else: + logger.debug("@call_with_sample_rate skipping a call to %s", f.__name__) + return None + + return _call_with_sample_rate + + return decorator + + +def decorate_matching_methods_with(decorator, acl=None): + """Apply the given decorator to all methods in a class whose names + begin with prefix. If prefix is None (default), decorate all + methods in the class. + """ + + def decorate_the_class(cls): + for name, m in inspect.getmembers(cls, inspect.isfunction): + if acl is None: + setattr(cls, name, decorator(m)) + else: + if acl(name): + setattr(cls, name, decorator(m)) + return cls + + return decorate_the_class + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/dict_utils.py b/src/pyutils/dict_utils.py new file mode 100644 index 0000000..a5f6290 --- /dev/null +++ b/src/pyutils/dict_utils.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Helper functions for dealing with dictionaries.""" + +from itertools import islice +from typing import Any, Callable, Dict, Iterator, List, Tuple + + +def init_or_inc( + d: Dict[Any, Any], + key: Any, + *, + init_value: Any = 1, + inc_function: Callable[..., Any] = lambda x: x + 1, +) -> bool: + """ + Initialize a dict value (if it doesn't exist) or increments it (using the + inc_function, which is customizable) if it already does exist. Returns + True if the key already existed or False otherwise. 
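+
+    For simple counting, collections.Counter or defaultdict(int) may be
+    simpler; this helper earns its keep with a custom inc_function,
+    e.g. (illustrative only)::
+
+        init_or_inc(d, key, init_value=1.0, inc_function=lambda x: x * 2.0)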
+
+    >>> d = {}
+    >>> init_or_inc(d, "test")
+    False
+    >>> init_or_inc(d, "test")
+    True
+    >>> init_or_inc(d, 'ing')
+    False
+    >>> d
+    {'test': 2, 'ing': 1}
+
+    """
+    if key in d.keys():
+        d[key] = inc_function(d[key])
+        return True
+    d[key] = init_value
+    return False
+
+
+def shard(d: Dict[Any, Any], size: int) -> Iterator[Dict[Any, Any]]:
+    """
+    Shards a dict into N subdicts which, together, contain all keys/values
+    from the original unsharded dict.
+    """
+    items = d.items()
+    for x in range(0, len(d), size):
+        yield dict(islice(items, x, x + size))
+
+
+def coalesce_by_creating_list(_, new_value, old_value):
+    """Helper for use with :meth:`coalesce` that creates a list on
+    collision."""
+    from pyutils.list_utils import flatten
+
+    return flatten([new_value, old_value])
+
+
+def coalesce_by_creating_set(key, new_value, old_value):
+    """Helper for use with :meth:`coalesce` that creates a set on
+    collision."""
+    return set(coalesce_by_creating_list(key, new_value, old_value))
+
+
+def coalesce_last_write_wins(_, new_value, discarded_old_value):
+    """Helper for use with :meth:`coalesce` that clobbers the old
+    value with the new one on collision."""
+    return new_value
+
+
+def coalesce_first_write_wins(_, discarded_new_value, old_value):
+    """Helper for use with :meth:`coalesce` that preserves the old
+    value and discards the new one on collision."""
+    return old_value
+
+
+def raise_on_duplicated_keys(key, new_value, old_value):
+    """Helper for use with :meth:`coalesce` that raises an exception
+    when a collision is detected.
+    """
+    raise Exception(f'Key {key} is duplicated in more than one input dict.')
+
+
+def coalesce(
+    inputs: Iterator[Dict[Any, Any]],
+    *,
+    aggregation_function: Callable[[Any, Any, Any], Any] = coalesce_by_creating_list,
+) -> Dict[Any, Any]:
+    """Merge N dicts into one dict containing the union of all keys /
+    values in the input dicts.  When keys collide, apply the
+    aggregation_function which, by default, creates a list of values.
+    See also several other alternative functions for coalescing values:
+
+    * :meth:`coalesce_by_creating_set`
+    * :meth:`coalesce_first_write_wins`
+    * :meth:`coalesce_last_write_wins`
+    * :meth:`raise_on_duplicated_keys`
+    * or provide your own collision resolution code.
+
+    >>> a = {'a': 1, 'b': 2}
+    >>> b = {'b': 1, 'c': 2, 'd': 3}
+    >>> c = {'c': 1, 'd': 2}
+    >>> coalesce([a, b, c])
+    {'a': 1, 'b': [1, 2], 'c': [1, 2], 'd': [2, 3]}
+
+    >>> coalesce([a, b, c], aggregation_function=coalesce_last_write_wins)
+    {'a': 1, 'b': 1, 'c': 1, 'd': 2}
+
+    >>> coalesce([a, b, c], aggregation_function=raise_on_duplicated_keys)
+    Traceback (most recent call last):
+    ...
+    Exception: Key b is duplicated in more than one input dict.
+
+    """
+    out: Dict[Any, Any] = {}
+    for d in inputs:
+        for key in d:
+            if key in out:
+                value = aggregation_function(key, d[key], out[key])
+            else:
+                value = d[key]
+            out[key] = value
+    return out
+
+
+def item_with_max_value(d: Dict[Any, Any]) -> Tuple[Any, Any]:
+    """Returns the key and value of the item with the max value in a dict.
+
+    >>> d = {'a': 1, 'b': 2, 'c': 3}
+    >>> item_with_max_value(d)
+    ('c', 3)
+    >>> item_with_max_value({})
+    Traceback (most recent call last):
+    ...
+    ValueError: max() arg is an empty sequence
+
+    """
+    return max(d.items(), key=lambda _: _[1])
+
+
+def item_with_min_value(d: Dict[Any, Any]) -> Tuple[Any, Any]:
+    """Returns the key and value of the item with the min value in a dict.
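+
+    Like :meth:`item_with_max_value`, raises ValueError when passed an
+    empty dict.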
+ + >>> d = {'a': 1, 'b': 2, 'c': 3} + >>> item_with_min_value(d) + ('a', 1) + + """ + return min(d.items(), key=lambda _: _[1]) + + +def key_with_max_value(d: Dict[Any, Any]) -> Any: + """Returns the key with the max value in the dict. + + >>> d = {'a': 1, 'b': 2, 'c': 3} + >>> key_with_max_value(d) + 'c' + + """ + return item_with_max_value(d)[0] + + +def key_with_min_value(d: Dict[Any, Any]) -> Any: + """Returns the key with the min value in the dict. + + >>> d = {'a': 1, 'b': 2, 'c': 3} + >>> key_with_min_value(d) + 'a' + + """ + return item_with_min_value(d)[0] + + +def max_value(d: Dict[Any, Any]) -> Any: + """Returns the maximum value in the dict. + + >>> d = {'a': 1, 'b': 2, 'c': 3} + >>> max_value(d) + 3 + + """ + return item_with_max_value(d)[1] + + +def min_value(d: Dict[Any, Any]) -> Any: + """Returns the minimum value in the dict. + + >>> d = {'a': 1, 'b': 2, 'c': 3} + >>> min_value(d) + 1 + + """ + return item_with_min_value(d)[1] + + +def max_key(d: Dict[Any, Any]) -> Any: + """Returns the maximum key in dict (ignoring values totally) + + >>> d = {'a': 3, 'b': 2, 'c': 1} + >>> max_key(d) + 'c' + + """ + return max(d.keys()) + + +def min_key(d: Dict[Any, Any]) -> Any: + """Returns the minimum key in dict (ignoring values totally) + + >>> d = {'a': 3, 'b': 2, 'c': 1} + >>> min_key(d) + 'a' + + """ + return min(d.keys()) + + +def parallel_lists_to_dict(keys: List[Any], values: List[Any]) -> Dict[Any, Any]: + """Given two parallel lists (keys and values), create and return + a dict. + + >>> k = ['name', 'phone', 'address', 'zip'] + >>> v = ['scott', '555-1212', '123 main st.', '12345'] + >>> parallel_lists_to_dict(k, v) + {'name': 'scott', 'phone': '555-1212', 'address': '123 main st.', 'zip': '12345'} + + """ + if len(keys) != len(values): + raise Exception("Parallel keys and values lists must have the same length") + return dict(zip(keys, values)) + + +def dict_to_key_value_lists(d: Dict[Any, Any]) -> Tuple[List[Any], List[Any]]: + """ + >>> d = {'name': 'scott', 'phone': '555-1212', 'address': '123 main st.', 'zip': '12345'} + >>> (k, v) = dict_to_key_value_lists(d) + >>> k + ['name', 'phone', 'address', 'zip'] + >>> v + ['scott', '555-1212', '123 main st.', '12345'] + + """ + r: Tuple[List[Any], List[Any]] = ([], []) + for (k, v) in d.items(): + r[0].append(k) + r[1].append(v) + return r + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/exec_utils.py b/src/pyutils/exec_utils.py new file mode 100644 index 0000000..49484c6 --- /dev/null +++ b/src/pyutils/exec_utils.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Helper methods concerned with executing subprocesses.""" + +import atexit +import logging +import os +import selectors +import shlex +import subprocess +import sys +from typing import List, Optional + +logger = logging.getLogger(__file__) + + +def cmd_showing_output( + command: str, + *, + timeout_seconds: Optional[float] = None, +) -> int: + """Kick off a child process. Capture and emit all output that it + produces on stdout and stderr in a character by character manner + so that we don't have to wait on newlines. This was done to + capture the output of a subprocess that created dots to show + incremental progress on a task and render it correctly. + + Args: + command: the command to execute + timeout_seconds: terminate the subprocess if it takes longer + than N seconds; None means to wait as long as it takes. 
+
+    Returns:
+        the exit status of the subprocess once the subprocess has
+        exited.  If the timeout expires, the child process is killed
+        and its (non-zero) exit status is returned.
+
+    Side effects:
+        prints all output of the child process (stdout or stderr)
+    """

+    def timer_expired(p):
+        p.kill()
+        raise subprocess.TimeoutExpired(command, timeout_seconds)
+
+    line_enders = set([b'\n', b'\r'])
+    sel = selectors.DefaultSelector()
+    with subprocess.Popen(
+        command,
+        shell=True,
+        bufsize=0,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        universal_newlines=False,
+    ) as p:
+        timer = None
+        if timeout_seconds:
+            import threading
+
+            # Pass the callable and its argument separately; invoking
+            # timer_expired(p) here would kill the child immediately.
+            # Note: the TimeoutExpired raised in timer_expired surfaces
+            # on the timer thread, not in the caller.
+            timer = threading.Timer(timeout_seconds, timer_expired, args=(p,))
+            timer.start()
+        try:
+            sel.register(p.stdout, selectors.EVENT_READ)  # type: ignore
+            sel.register(p.stderr, selectors.EVENT_READ)  # type: ignore
+            done = False
+            while not done:
+                for key, _ in sel.select():
+                    char = key.fileobj.read(1)  # type: ignore
+                    if not char:
+                        sel.unregister(key.fileobj)
+                        if len(sel.get_map()) == 0:
+                            sys.stdout.flush()
+                            sys.stderr.flush()
+                            sel.close()
+                            done = True
+                    if key.fileobj is p.stdout:
+                        os.write(sys.stdout.fileno(), char)
+                        if char in line_enders:
+                            sys.stdout.flush()
+                    else:
+                        os.write(sys.stderr.fileno(), char)
+                        if char in line_enders:
+                            sys.stderr.flush()
+            p.wait()
+        finally:
+            if timer:
+                timer.cancel()
+        return p.returncode
+
+
+def cmd_exitcode(command: str, timeout_seconds: Optional[float] = None) -> int:
+    """Run a command silently.  Returns 0 once the command has finished
+    successfully; because this is implemented with check_call, a failing
+    command raises subprocess.CalledProcessError rather than returning a
+    non-zero status.  If timeout_seconds is provided and the command runs
+    too long it will raise a TimeoutExpired exception.
+
+    Args:
+        command: the command to run
+        timeout_seconds: the max number of seconds to allow the subprocess
+            to execute or None to indicate no timeout
+
+    Returns:
+        0 once the subprocess has exited successfully; failures raise
+        exceptions as described above
+
+    >>> cmd_exitcode('/bin/echo foo', 10.0)
+    0
+
+    >>> cmd_exitcode('/bin/sleep 2', 0.01)
+    Traceback (most recent call last):
+    ...
+    subprocess.TimeoutExpired: Command '['/bin/bash', '-c', '/bin/sleep 2']' timed out after 0.01 seconds
+
+    """
+    return subprocess.check_call(["/bin/bash", "-c", command], timeout=timeout_seconds)
+
+
+def cmd(command: str, timeout_seconds: Optional[float] = None) -> str:
+    """Run a command and capture its output (stderr is merged into
+    stdout) into a string buffer.  Return that string as this function's
+    output.  Raises subprocess.CalledProcessError or TimeoutExpired on
+    error.
+
+    Args:
+        command: the command to run
+        timeout_seconds: the max number of seconds to allow the subprocess
+            to execute or None to indicate no timeout
+
+    Returns:
+        The captured output of the subprocess' stdout as a string buffer
+
+    >>> cmd('/bin/echo foo')[:-1]
+    'foo'
+
+    >>> cmd('/bin/sleep 2', 0.01)
+    Traceback (most recent call last):
+    ...
+    subprocess.TimeoutExpired: Command '/bin/sleep 2' timed out after 0.01 seconds
+
+    """
+    ret = subprocess.run(
+        command,
+        shell=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        check=True,
+        timeout=timeout_seconds,
+    ).stdout
+    return ret.decode("utf-8")
+
+
+def run_silently(command: str, timeout_seconds: Optional[float] = None) -> None:
+    """Run a command silently but raise subprocess.CalledProcessError if
+    it fails and TimeoutExpired if it runs too long.
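+
+    The child's stdout and stderr are discarded (sent to /dev/null),
+    not captured.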
+ + Args: + command: the command to run + timeout_seconds: the max number of seconds to allow the subprocess + to execute or None to indicate no timeout + + Returns: + No return value; error conditions (including non-zero child process + exits) produce exceptions. + + >>> run_silently("/usr/bin/true") + + >>> run_silently("/usr/bin/false") + Traceback (most recent call last): + ... + subprocess.CalledProcessError: Command '/usr/bin/false' returned non-zero exit status 1. + + """ + subprocess.run( + command, + shell=True, + stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + capture_output=False, + check=True, + timeout=timeout_seconds, + ) + + +def cmd_in_background(command: str, *, silent: bool = False) -> subprocess.Popen: + """Spawns a child process in the background and registers an exit + handler to make sure we kill it if the parent process (us) is + terminated. + + Args: + command: the command to run + silent: do not allow any output from the child process to be displayed + in the parent process' window + + Returns: + the :class:`Popen` object that can be used to communicate + with the background process. + """ + args = shlex.split(command) + if silent: + subproc = subprocess.Popen( + args, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + else: + subproc = subprocess.Popen(args, stdin=subprocess.DEVNULL) + + def kill_subproc() -> None: + try: + if subproc.poll() is None: + logger.info('At exit handler: killing %s (%s)', subproc, command) + subproc.terminate() + subproc.wait(timeout=10.0) + except BaseException as be: + logger.exception(be) + + atexit.register(kill_subproc) + return subproc + + +def cmd_list(command: List[str]) -> str: + """Run a command with args encapsulated in a list and return the + output text as a string. Raises subprocess.CalledProcessError. + """ + ret = subprocess.run(command, capture_output=True, check=True).stdout + return ret.decode("utf-8") + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/files/__init__.py b/src/pyutils/files/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/files/directory_filter.py b/src/pyutils/files/directory_filter.py new file mode 100644 index 0000000..3d0522b --- /dev/null +++ b/src/pyutils/files/directory_filter.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Two predicates that can help avoid unnecessary disk I/O by +detecting if a particular file is identical to the contents about to +be written or if a particular directory already contains a file that +is identical to the one about to be written. See examples below. +""" + +import hashlib +import logging +import os +from typing import Any, Dict, Optional, Set + +logger = logging.getLogger(__name__) + + +class DirectoryFileFilter(object): + """A predicate that will return False if / when a proposed file's + content to-be-written is identical to the contents of the file on + disk allowing calling code to safely skip the write. + + >>> testfile = '/tmp/directory_filter_text_f39e5b58-c260-40da-9448-ad1c3b2a69c2.txt' + >>> contents = b'This is a test' + >>> with open(testfile, 'wb') as wf: + ... wf.write(contents) + 14 + + >>> d = DirectoryFileFilter('/tmp') + + >>> d.apply(contents, testfile) # False if testfile already contains contents + False + + >>> d.apply(b'That was a test', testfile) # True otherwise + True + + >>> os.remove(testfile) + + """ + + def __init__(self, directory: str): + """C'tor. 
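+
+        A ValueError is raised if the given directory does not exist.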
+ + Args: + directory: the directory we're filtering accesses to + """ + super().__init__() + from pyutils.files import file_utils + + if not file_utils.does_directory_exist(directory): + raise ValueError(directory) + self.directory = directory + self.md5_by_filename: Dict[str, str] = {} + self.mtime_by_filename: Dict[str, float] = {} + self._update() + + def _update(self): + """ + Internal method. Foreach file in the directory, compute its + MD5 checksum via :meth:`_update_file`. + """ + for direntry in os.scandir(self.directory): + if direntry.is_file(follow_symlinks=True): + mtime = direntry.stat(follow_symlinks=True).st_mtime + path = f'{self.directory}/{direntry.name}' + self._update_file(path, mtime) + + def _update_file(self, filename: str, mtime: Optional[float] = None): + """ + Internal method. Given a file and mtime, compute its MD5 checksum + and persist it in an internal map. + """ + from pyutils.files import file_utils + + assert file_utils.does_file_exist(filename) + if mtime is None: + mtime = file_utils.get_file_raw_mtime(filename) + assert mtime is not None + if self.mtime_by_filename.get(filename, 0) != mtime: + md5 = file_utils.get_file_md5(filename) + logger.debug( + 'Computed/stored %s\'s MD5 at ts=%.2f (%s)', filename, mtime, md5 + ) + self.mtime_by_filename[filename] = mtime + self.md5_by_filename[filename] = md5 + + def apply(self, proposed_contents: Any, filename: str) -> bool: + """Call this with the proposed new contents of filename in + memory and we'll compute the checksum of those contents and + return a value that indicates whether they are identical to + the disk contents already (so you can skip the write safely). + + Args: + proposed_contents: the contents about to be written to + filename + filename: the file about to be populated with + proposed_contents + + Returns: + True if the disk contents of the file are identical to + proposed_contents already and False otherwise. + """ + self._update_file(filename) + file_md5 = self.md5_by_filename.get(filename, 0) + logger.debug('%s\'s checksum is %s', filename, file_md5) + mem_hash = hashlib.md5() + mem_hash.update(proposed_contents) + md5 = mem_hash.hexdigest() + logger.debug('Item\'s checksum is %s', md5) + return md5 != file_md5 + + +class DirectoryAllFilesFilter(DirectoryFileFilter): + """A predicate that will return False if a file to-be-written to a + particular directory is identical to any other file in that same + directory (regardless of its name). + + i.e. this is the same as :class:`DirectoryFileFilter` except that + our apply() method will return true not only if the contents to be + written are identical to the contents of filename on the disk but + also it returns true if there exists some other file sitting in + the same directory which already contains those identical + contents. + + >>> testfile = '/tmp/directory_filter_text_f39e5b58-c260-40da-9448-ad1c3b2a69c3.txt' + + >>> contents = b'This is a test' + >>> with open(testfile, 'wb') as wf: + ... wf.write(contents) + 14 + + >>> d = DirectoryAllFilesFilter('/tmp') + + >>> d.apply(contents) # False is _any_ file in /tmp contains contents + False + + >>> d.apply(b'That was a test') # True otherwise + True + + >>> os.remove(testfile) + + """ + + def __init__(self, directory: str): + """C'tor. + + Args: + directory: the directory we're watching + """ + self.all_md5s: Set[str] = set() + super().__init__(directory) + + def _update_file(self, filename: str, mtime: Optional[float] = None): + """Internal method. 
Given a file and its mtime, update internal + state. + """ + from pyutils.files import file_utils + + assert file_utils.does_file_exist(filename) + if mtime is None: + mtime = file_utils.get_file_raw_mtime(filename) + assert mtime is not None + if self.mtime_by_filename.get(filename, 0) != mtime: + md5 = file_utils.get_file_md5(filename) + self.mtime_by_filename[filename] = mtime + self.md5_by_filename[filename] = md5 + self.all_md5s.add(md5) + + def apply(self, proposed_contents: Any, ignored_filename: str = None) -> bool: + """Call this before writing a new file to directory with the + proposed_contents to be written and it will return a value that + indicates whether the identical contents is already sitting in + *any* file in that directory. Useful, e.g., for caching. + + Args: + proposed_contents: the contents about to be persisted to + directory + ignored_filename: unused for now, must be None + + Returns: + True if proposed contents does not yet exist in any file in + directory or False if it does exist in some file already. + """ + assert ignored_filename is None + self._update() + mem_hash = hashlib.md5() + mem_hash.update(proposed_contents) + md5 = mem_hash.hexdigest() + return md5 not in self.all_md5s + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/files/file_utils.py b/src/pyutils/files/file_utils.py new file mode 100644 index 0000000..dd6cf16 --- /dev/null +++ b/src/pyutils/files/file_utils.py @@ -0,0 +1,832 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Utilities for working with files.""" + +import contextlib +import datetime +import errno +import fnmatch +import glob +import hashlib +import logging +import os +import pathlib +import re +import time +from os.path import exists, isfile, join +from typing import Callable, List, Literal, Optional, TextIO +from uuid import uuid4 + +logger = logging.getLogger(__name__) + + +def remove_newlines(x: str) -> str: + """Trivial function to be used as a line_transformer in + :meth:`slurp_file` for no newlines in file contents""" + return x.replace('\n', '') + + +def strip_whitespace(x: str) -> str: + """Trivial function to be used as a line_transformer in + :meth:`slurp_file` for no leading / trailing whitespace in + file contents""" + return x.strip() + + +def remove_hash_comments(x: str) -> str: + """Trivial function to be used as a line_transformer in + :meth:`slurp_file` for no # comments in file contents""" + return re.sub(r'#.*$', '', x) + + +def slurp_file( + filename: str, + *, + skip_blank_lines=False, + line_transformers: Optional[List[Callable[[str], str]]] = None, +): + """Reads in a file's contents line-by-line to a memory buffer applying + each line transformation in turn. + + Args: + filename: file to be read + skip_blank_lines: should reading skip blank lines? + line_transformers: little string->string transformations + """ + + ret = [] + xforms = [] + if line_transformers is not None: + for x in line_transformers: + xforms.append(x) + if not file_is_readable(filename): + raise Exception(f'{filename} can\'t be read.') + with open(filename) as rf: + for line in rf: + for transformation in xforms: + line = transformation(line) + if skip_blank_lines and line == '': + continue + ret.append(line) + return ret + + +def remove(path: str) -> None: + """Deletes a file. Raises if path refers to a directory or a file + that doesn't exist. 
+ + Args: + path: the path of the file to delete + + >>> import os + >>> filename = '/tmp/file_utils_test_file' + >>> os.system(f'touch {filename}') + 0 + >>> does_file_exist(filename) + True + >>> remove(filename) + >>> does_file_exist(filename) + False + """ + os.remove(path) + + +def fix_multiple_slashes(path: str) -> str: + """Fixes multi-slashes in paths or path-like strings + + Args: + path: the path in which to remove multiple slashes + + >>> p = '/usr/local//etc/rc.d///file.txt' + >>> fix_multiple_slashes(p) + '/usr/local/etc/rc.d/file.txt' + + >>> p = 'this is a test' + >>> fix_multiple_slashes(p) == p + True + """ + return re.sub(r'/+', '/', path) + + +def delete(path: str) -> None: + """This is a convenience for my dumb ass who can't remember os.remove + sometimes. + """ + os.remove(path) + + +def without_extension(path: str) -> str: + """Remove one (the last) extension from a file or path. + + Args: + path: the path from which to remove an extension + + Returns: + the path with one extension removed. + + >>> without_extension('foobar.txt') + 'foobar' + + >>> without_extension('/home/scott/frapp.py') + '/home/scott/frapp' + + >>> f = 'a.b.c.tar.gz' + >>> while('.' in f): + ... f = without_extension(f) + ... print(f) + a.b.c.tar + a.b.c + a.b + a + + >>> without_extension('foobar') + 'foobar' + + """ + return os.path.splitext(path)[0] + + +def without_all_extensions(path: str) -> str: + """Removes all extensions from a path; handles multiple extensions + like foobar.tar.gz -> foobar. + + Args: + path: the path from which to remove all extensions + + Returns: + the path with all extensions removed. + + >>> without_all_extensions('/home/scott/foobar.1.tar.gz') + '/home/scott/foobar' + + """ + while '.' in path: + path = without_extension(path) + return path + + +def get_extension(path: str) -> str: + """Extract and return one (the last) extension from a file or path. + + Args: + path: the path from which to extract an extension + + Returns: + The last extension from the file path. + + >>> get_extension('this_is_a_test.txt') + '.txt' + + >>> get_extension('/home/scott/test.py') + '.py' + + >>> get_extension('foobar') + '' + + """ + return os.path.splitext(path)[1] + + +def get_all_extensions(path: str) -> List[str]: + """Return the extensions of a file or path in order. + + Args: + path: the path from which to extract all extensions. + + Returns: + a list containing each extension which may be empty. + + >>> get_all_extensions('/home/scott/foo.tar.gz.1') + ['.tar', '.gz', '.1'] + + >>> get_all_extensions('/home/scott/foobar') + [] + + """ + ret = [] + while True: + ext = get_extension(path) + path = without_extension(path) + if ext: + ret.append(ext) + else: + ret.reverse() + return ret + + +def without_path(filespec: str) -> str: + """Returns the base filename without any leading path. + + Args: + filespec: path to remove leading directories from + + Returns: + filespec without leading dir components. + + >>> without_path('/home/scott/foo.py') + 'foo.py' + + >>> without_path('foo.py') + 'foo.py' + + """ + return os.path.split(filespec)[1] + + +def get_path(filespec: str) -> str: + """Returns just the path of the filespec by removing the filename and + extension. 
+ + Args: + filespec: path to remove filename / extension(s) from + + Returns: + filespec with just the leading directory components and no + filename or extension(s) + + >>> get_path('/home/scott/foobar.py') + '/home/scott' + + >>> get_path('/home/scott/test.1.2.3.gz') + '/home/scott' + + >>> get_path('~scott/frapp.txt') + '~scott' + + """ + return os.path.split(filespec)[0] + + +def get_canonical_path(filespec: str) -> str: + """Returns a canonicalized absolute path. + + Args: + filespec: the path to canonicalize + + Returns: + the canonicalized path + + >>> get_canonical_path('/home/scott/../../home/lynn/../scott/foo.txt') + '/usr/home/scott/foo.txt' + + """ + return os.path.realpath(filespec) + + +def create_path_if_not_exist(path, on_error=None) -> None: + """ + Attempts to create path if it does not exist already. + + .. warning:: + + Files are created with mode 0x0777 (i.e. world read/writeable). + + Args: + path: the path to attempt to create + on_error: If True, it's invoked on error conditions. Otherwise + any exceptions are raised. + + >>> import uuid + >>> import os + >>> path = os.path.join("/tmp", str(uuid.uuid4()), str(uuid.uuid4())) + >>> os.path.exists(path) + False + >>> create_path_if_not_exist(path) + >>> os.path.exists(path) + True + """ + logger.debug("Creating path %s", path) + previous_umask = os.umask(0) + try: + os.makedirs(path) + os.chmod(path, 0o777) + except OSError as ex: + if ex.errno != errno.EEXIST and not os.path.isdir(path): + if on_error is not None: + on_error(path, ex) + else: + raise + finally: + os.umask(previous_umask) + + +def does_file_exist(filename: str) -> bool: + """Returns True if a file exists and is a normal file. + + Args: + filename: filename to check + + Returns: + True if filename exists and is a normal file. + + >>> does_file_exist(__file__) + True + >>> does_file_exist('/tmp/2492043r9203r9230r9230r49230r42390r4230') + False + """ + return os.path.exists(filename) and os.path.isfile(filename) + + +def file_is_readable(filename: str) -> bool: + """True if file exists, is a normal file and is readable by the + current process. False otherwise. + + Args: + filename: the filename to check for read access + """ + return does_file_exist(filename) and os.access(filename, os.R_OK) + + +def file_is_writable(filename: str) -> bool: + """True if file exists, is a normal file and is writable by the + current process. False otherwise. + + Args: + filename: the file to check for write access. + """ + return does_file_exist(filename) and os.access(filename, os.W_OK) + + +def file_is_executable(filename: str) -> bool: + """True if file exists, is a normal file and is executable by the + current process. False otherwise. + + Args: + filename: the file to check for execute access. + """ + return does_file_exist(filename) and os.access(filename, os.X_OK) + + +def does_directory_exist(dirname: str) -> bool: + """Returns True if a file exists and is a directory. + + >>> does_directory_exist('/tmp') + True + >>> does_directory_exist('/xyzq/21341') + False + """ + return os.path.exists(dirname) and os.path.isdir(dirname) + + +def does_path_exist(pathname: str) -> bool: + """Just a more verbose wrapper around os.path.exists.""" + return os.path.exists(pathname) + + +def get_file_size(filename: str) -> int: + """Returns the size of a file in bytes. 
+ + Args: + filename: the filename to size + + Returns: + size of filename in bytes + """ + return os.path.getsize(filename) + + +def is_normal_file(filename: str) -> bool: + """Returns True if filename is a normal file. + + >>> is_normal_file(__file__) + True + """ + return os.path.isfile(filename) + + +def is_directory(filename: str) -> bool: + """Returns True if filename is a directory. + + >>> is_directory('/tmp') + True + """ + return os.path.isdir(filename) + + +def is_symlink(filename: str) -> bool: + """True if filename is a symlink, False otherwise. + + >>> is_symlink('/tmp') + False + + >>> is_symlink('/home') + True + + """ + return os.path.islink(filename) + + +def is_same_file(file1: str, file2: str) -> bool: + """Returns True if the two files are the same inode. + + >>> is_same_file('/tmp', '/tmp/../tmp') + True + + >>> is_same_file('/tmp', '/home') + False + + """ + return os.path.samefile(file1, file2) + + +def get_file_raw_timestamps(filename: str) -> Optional[os.stat_result]: + """Stats the file and returns an os.stat_result or None on error. + + Args: + filename: the file whose timestamps to fetch + + Returns: + the os.stat_result or None to indicate an error occurred + """ + try: + return os.stat(filename) + except Exception as e: + logger.exception(e) + return None + + +def get_file_raw_timestamp( + filename: str, extractor: Callable[[os.stat_result], Optional[float]] +) -> Optional[float]: + """Stat a file and, if successful, use extractor to fetch some + subset of the information in the os.stat_result. See also + :meth:`get_file_raw_atime`, :meth:`get_file_raw_mtime`, and + :meth:`get_file_raw_ctime` which just call this with a lambda + extractor. + + Args: + filename: the filename to stat + extractor: Callable that takes a os.stat_result and produces + something useful(?) with it. + + Returns: + whatever the extractor produced or None on error. + """ + tss = get_file_raw_timestamps(filename) + if tss is not None: + return extractor(tss) + return None + + +def get_file_raw_atime(filename: str) -> Optional[float]: + """Get a file's raw access time or None on error. + + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + and :meth:`get_file_atime_age_seconds`. + """ + return get_file_raw_timestamp(filename, lambda x: x.st_atime) + + +def get_file_raw_mtime(filename: str) -> Optional[float]: + """Get a file's raw modification time or None on error. + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + and :meth:`get_file_mtime_age_seconds`. + """ + return get_file_raw_timestamp(filename, lambda x: x.st_mtime) + + +def get_file_raw_ctime(filename: str) -> Optional[float]: + """Get a file's raw creation time or None on error. + + See also :meth:`get_file_ctime_as_datetime`, + :meth:`get_file_ctime_timedelta`, + and :meth:`get_file_ctime_age_seconds`. + """ + return get_file_raw_timestamp(filename, lambda x: x.st_ctime) + + +def get_file_md5(filename: str) -> str: + """Hashes filename's disk contents and returns the MD5 digest. + + Args: + filename: the file whose contents to hash + + Returns: + the MD5 digest of the file's contents. Raises on errors. + """ + file_hash = hashlib.md5() + with open(filename, "rb") as f: + chunk = f.read(8192) + while chunk: + file_hash.update(chunk) + chunk = f.read(8192) + return file_hash.hexdigest() + + +def set_file_raw_atime(filename: str, atime: float): + """Sets a file's raw access time. 
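+
+    The file's modification time is preserved.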
+ + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + :meth:`get_file_atime_age_seconds`, + and :meth:`get_file_raw_atime`. + """ + mtime = get_file_raw_mtime(filename) + assert mtime is not None + os.utime(filename, (atime, mtime)) + + +def set_file_raw_mtime(filename: str, mtime: float): + """Sets a file's raw modification time. + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + :meth:`get_file_mtime_age_seconds`, + and :meth:`get_file_raw_mtime`. + """ + atime = get_file_raw_atime(filename) + assert atime is not None + os.utime(filename, (atime, mtime)) + + +def set_file_raw_atime_and_mtime(filename: str, ts: float = None): + """Sets both a file's raw modification and access times + + Args: + filename: the file whose times to set + ts: the raw time to set or None to indicate time should be + set to the current time. + """ + if ts is not None: + os.utime(filename, (ts, ts)) + else: + os.utime(filename, None) + + +def convert_file_timestamp_to_datetime( + filename: str, producer +) -> Optional[datetime.datetime]: + """Convert a raw file timestamp into a python datetime.""" + ts = producer(filename) + if ts is not None: + return datetime.datetime.fromtimestamp(ts) + return None + + +def get_file_atime_as_datetime(filename: str) -> Optional[datetime.datetime]: + """Fetch a file's access time as a python datetime. + + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + :meth:`get_file_atime_age_seconds`, + :meth:`describe_file_atime`, + and :meth:`get_file_raw_atime`. + """ + return convert_file_timestamp_to_datetime(filename, get_file_raw_atime) + + +def get_file_mtime_as_datetime(filename: str) -> Optional[datetime.datetime]: + """Fetches a file's modification time as a python datetime. + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + :meth:`get_file_mtime_age_seconds`, + and :meth:`get_file_raw_mtime`. + """ + return convert_file_timestamp_to_datetime(filename, get_file_raw_mtime) + + +def get_file_ctime_as_datetime(filename: str) -> Optional[datetime.datetime]: + """Fetches a file's creation time as a python datetime. + + See also :meth:`get_file_ctime_as_datetime`, + :meth:`get_file_ctime_timedelta`, + :meth:`get_file_ctime_age_seconds`, + and :meth:`get_file_raw_ctime`. + """ + return convert_file_timestamp_to_datetime(filename, get_file_raw_ctime) + + +def get_file_timestamp_age_seconds(filename: str, extractor) -> Optional[int]: + """~Internal helper""" + now = time.time() + ts = get_file_raw_timestamps(filename) + if ts is None: + return None + result = extractor(ts) + return now - result + + +def get_file_atime_age_seconds(filename: str) -> Optional[int]: + """Gets a file's access time as an age in seconds (ago). + + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + :meth:`get_file_atime_age_seconds`, + :meth:`describe_file_atime`, + and :meth:`get_file_raw_atime`. + """ + return get_file_timestamp_age_seconds(filename, lambda x: x.st_atime) + + +def get_file_ctime_age_seconds(filename: str) -> Optional[int]: + """Gets a file's creation time as an age in seconds (ago). + + See also :meth:`get_file_ctime_as_datetime`, + :meth:`get_file_ctime_timedelta`, + :meth:`get_file_ctime_age_seconds`, + and :meth:`get_file_raw_ctime`. 
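+
+    Note: on POSIX filesystems st_ctime is the inode change (metadata)
+    time, not a true creation time.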
+ """ + return get_file_timestamp_age_seconds(filename, lambda x: x.st_ctime) + + +def get_file_mtime_age_seconds(filename: str) -> Optional[int]: + """Gets a file's modification time as seconds (ago). + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + :meth:`get_file_mtime_age_seconds`, + and :meth:`get_file_raw_mtime`. + """ + return get_file_timestamp_age_seconds(filename, lambda x: x.st_mtime) + + +def get_file_timestamp_timedelta( + filename: str, extractor +) -> Optional[datetime.timedelta]: + """~Internal helper""" + age = get_file_timestamp_age_seconds(filename, extractor) + if age is not None: + return datetime.timedelta(seconds=float(age)) + return None + + +def get_file_atime_timedelta(filename: str) -> Optional[datetime.timedelta]: + """How long ago was a file accessed as a timedelta? + + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + :meth:`get_file_atime_age_seconds`, + :meth:`describe_file_atime`, + and :meth:`get_file_raw_atime`. + """ + return get_file_timestamp_timedelta(filename, lambda x: x.st_atime) + + +def get_file_ctime_timedelta(filename: str) -> Optional[datetime.timedelta]: + """How long ago was a file created as a timedelta? + + See also :meth:`get_file_ctime_as_datetime`, + :meth:`get_file_ctime_timedelta`, + :meth:`get_file_ctime_age_seconds`, + and :meth:`get_file_raw_ctime`. + """ + return get_file_timestamp_timedelta(filename, lambda x: x.st_ctime) + + +def get_file_mtime_timedelta(filename: str) -> Optional[datetime.timedelta]: + """ + Gets a file's modification time as a python timedelta. + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + :meth:`get_file_mtime_age_seconds`, + and :meth:`get_file_raw_mtime`. + """ + return get_file_timestamp_timedelta(filename, lambda x: x.st_mtime) + + +def describe_file_timestamp(filename: str, extractor, *, brief=False) -> Optional[str]: + """~Internal helper""" + from pyutils.datetimez.datetime_utils import ( + describe_duration, + describe_duration_briefly, + ) + + age = get_file_timestamp_age_seconds(filename, extractor) + if age is None: + return None + if brief: + return describe_duration_briefly(age) + else: + return describe_duration(age) + + +def describe_file_atime(filename: str, *, brief=False) -> Optional[str]: + """ + Describe how long ago a file was accessed. + + See also :meth:`get_file_atime_as_datetime`, + :meth:`get_file_atime_timedelta`, + :meth:`get_file_atime_age_seconds`, + :meth:`describe_file_atime`, + and :meth:`get_file_raw_atime`. + """ + return describe_file_timestamp(filename, lambda x: x.st_atime, brief=brief) + + +def describe_file_ctime(filename: str, *, brief=False) -> Optional[str]: + """Describes a file's creation time. + + See also :meth:`get_file_ctime_as_datetime`, + :meth:`get_file_ctime_timedelta`, + :meth:`get_file_ctime_age_seconds`, + and :meth:`get_file_raw_ctime`. + """ + return describe_file_timestamp(filename, lambda x: x.st_ctime, brief=brief) + + +def describe_file_mtime(filename: str, *, brief=False) -> Optional[str]: + """ + Describes how long ago a file was modified. + + See also :meth:`get_file_mtime_as_datetime`, + :meth:`get_file_mtime_timedelta`, + :meth:`get_file_mtime_age_seconds`, + and :meth:`get_file_raw_mtime`. 
+    """
+    return describe_file_timestamp(filename, lambda x: x.st_mtime, brief=brief)
+
+
+def touch_file(filename: str, *, mode: Optional[int] = 0o666):
+    """Like unix "touch" command's semantics: update the timestamp
+    of a file to the current time if the file exists.  Create the
+    file if it doesn't exist.
+
+    Args:
+        filename: the filename
+        mode: the mode to create the file with, if it must be created
+    """
+    pathlib.Path(filename).touch(mode=mode)
+
+
+def expand_globs(in_filename: str):
+    """Expands shell globs (* and ? wildcards) to the matching files."""
+    for filename in glob.glob(in_filename):
+        yield filename
+
+
+def get_files(directory: str):
+    """Returns the files in a directory as a generator."""
+    for filename in os.listdir(directory):
+        full_path = join(directory, filename)
+        if isfile(full_path) and exists(full_path):
+            yield full_path
+
+
+def get_matching_files(directory: str, glob: str):
+    """Returns the subset of files whose name matches a glob."""
+    for filename in get_files(directory):
+        if fnmatch.fnmatch(filename, glob):
+            yield filename
+
+
+def get_directories(directory: str):
+    """Returns the subdirectories in a directory as a generator."""
+    for d in os.listdir(directory):
+        full_path = join(directory, d)
+        if not isfile(full_path) and exists(full_path):
+            yield full_path
+
+
+def get_files_recursive(directory: str):
+    """Find the files and directories under a root recursively."""
+    for filename in get_files(directory):
+        yield filename
+    for subdir in get_directories(directory):
+        for file_or_directory in get_files_recursive(subdir):
+            yield file_or_directory
+
+
+def get_matching_files_recursive(directory: str, glob: str):
+    """Returns the subset of files whose name matches a glob under a root recursively."""
+    for filename in get_files_recursive(directory):
+        if fnmatch.fnmatch(filename, glob):
+            yield filename
+
+
+class FileWriter(contextlib.AbstractContextManager):
+    """A helper that writes a file to a temporary location and then moves
+    it atomically to its ultimate destination on close.
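+
+    A sketch of typical use (the filename is hypothetical)::
+
+        with FileWriter('/tmp/some_file.txt') as wf:
+            wf.write('hello!\n')
+
+    Readers of the target file never observe a partially written
+    version; the new contents appear all at once when the context
+    closes.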
+ """ + + def __init__(self, filename: str) -> None: + self.filename = filename + uuid = uuid4() + self.tempfile = f'{filename}-{uuid}.tmp' + self.handle: Optional[TextIO] = None + + def __enter__(self) -> TextIO: + assert not does_path_exist(self.tempfile) + self.handle = open(self.tempfile, mode="w") + return self.handle + + def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]: + if self.handle is not None: + self.handle.close() + cmd = f'/bin/mv -f {self.tempfile} {self.filename}' + ret = os.system(cmd) + if (ret >> 8) != 0: + raise Exception(f'{cmd} failed, exit value {ret>>8}!') + return False + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/files/lockfile.py b/src/pyutils/files/lockfile.py new file mode 100644 index 0000000..11bb100 --- /dev/null +++ b/src/pyutils/files/lockfile.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""File-based locking helper.""" + +from __future__ import annotations + +import contextlib +import datetime +import json +import logging +import os +import signal +import sys +import warnings +from dataclasses import dataclass +from typing import Literal, Optional + +from pyutils import config, decorator_utils +from pyutils.datetimez import datetime_utils + +cfg = config.add_commandline_args(f'Lockfile ({__file__})', 'Args related to lockfiles') +cfg.add_argument( + '--lockfile_held_duration_warning_threshold_sec', + type=float, + default=60.0, + metavar='SECONDS', + help='If a lock is held for longer than this threshold we log a warning', +) +logger = logging.getLogger(__name__) + + +class LockFileException(Exception): + """An exception related to lock files.""" + + pass + + +@dataclass +class LockFileContents: + """The contents we'll write to each lock file.""" + + pid: int + """The pid of the process that holds the lock""" + + commandline: str + """The commandline of the process that holds the lock""" + + expiration_timestamp: Optional[float] + """When this lock will expire as seconds since Epoch""" + + +class LockFile(contextlib.AbstractContextManager): + """A file locking mechanism that has context-manager support so you + can use it in a with statement. e.g.:: + + with LockFile('./foo.lock'): + # do a bunch of stuff... if the process dies we have a signal + # handler to do cleanup. Other code (in this process or another) + # that tries to take the same lockfile will block. There is also + # some logic for detecting stale locks. + """ + + def __init__( + self, + lockfile_path: str, + *, + do_signal_cleanup: bool = True, + expiration_timestamp: Optional[float] = None, + override_command: Optional[str] = None, + ) -> None: + """C'tor. + + Args: + lockfile_path: path of the lockfile to acquire + do_signal_cleanup: handle SIGINT and SIGTERM events by + releasing the lock before exiting + expiration_timestamp: when our lease on the lock should + expire (as seconds since the Epoch). None means the + lock will not expire until we explicltly release it. + override_command: don't use argv to determine our commandline + rather use this instead if provided. 
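+
+        Note: expiration_timestamp is not enforced by the holder; it
+            is written into the lockfile so that other processes can
+            consider the lock stale (and force acquire it) once that
+            time has passed.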
+        """
+        self.is_locked: bool = False
+        self.lockfile: str = lockfile_path
+        self.locktime: Optional[float] = None
+        self.override_command: Optional[str] = override_command
+        if do_signal_cleanup:
+            signal.signal(signal.SIGINT, self._signal)
+            signal.signal(signal.SIGTERM, self._signal)
+        self.expiration_timestamp = expiration_timestamp
+
+    def locked(self):
+        """Is it locked currently?"""
+        return self.is_locked
+
+    def available(self):
+        """Is it available currently?"""
+        return not os.path.exists(self.lockfile)
+
+    def try_acquire_lock_once(self) -> bool:
+        """Attempt to acquire the lock with no blocking.
+
+        Returns:
+            True if the lock was acquired and False otherwise.
+        """
+        logger.debug("Trying to acquire %s.", self.lockfile)
+        try:
+            # Attempt to create the lockfile.  These flags cause
+            # os.open to raise an OSError if the file already
+            # exists.
+            fd = os.open(self.lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+            with os.fdopen(fd, "a") as f:
+                contents = self._get_lockfile_contents()
+                logger.debug(contents)
+                f.write(contents)
+            logger.debug('Success; I own %s.', self.lockfile)
+            self.is_locked = True
+            return True
+        except OSError:
+            pass
+        logger.warning('Couldn\'t acquire %s.', self.lockfile)
+        return False
+
+    def acquire_with_retries(
+        self,
+        *,
+        initial_delay: float = 1.0,
+        backoff_factor: float = 2.0,
+        max_attempts: int = 5,
+    ) -> bool:
+        """Attempt to acquire the lock repeatedly with retries and backoffs.
+
+        Args:
+            initial_delay: how long to wait before retrying the first time
+            backoff_factor: a float >= 1.0 that multiplies the current retry
+                delay each subsequent time we attempt to acquire the lock
+                and fail to do so.
+            max_attempts: maximum number of times to try before giving up
+                and failing.
+
+        Returns:
+            True if the lock was acquired and False otherwise.
+        """
+
+        @decorator_utils.retry_if_false(
+            tries=max_attempts, delay_sec=initial_delay, backoff=backoff_factor
+        )
+        def _try_acquire_lock_with_retries() -> bool:
+            success = self.try_acquire_lock_once()
+            if not success and os.path.exists(self.lockfile):
+                self._detect_stale_lockfile()
+            return success
+
+        if os.path.exists(self.lockfile):
+            self._detect_stale_lockfile()
+        return _try_acquire_lock_with_retries()
+
+    def release(self):
+        """Release the lock"""
+        try:
+            os.unlink(self.lockfile)
+        except Exception as e:
+            logger.exception(e)
+        self.is_locked = False
+
+    def __enter__(self):
+        if self.acquire_with_retries():
+            self.locktime = datetime.datetime.now().timestamp()
+            return self
+        msg = f"Couldn't acquire {self.lockfile}; giving up."
+        logger.warning(msg)
+        raise LockFileException(msg)
+
+    def __exit__(self, _, value, traceback) -> Literal[False]:
+        if self.locktime:
+            ts = datetime.datetime.now().timestamp()
+            duration = ts - self.locktime
+            if (
+                duration
+                >= config.config['lockfile_held_duration_warning_threshold_sec']
+            ):
+                # Note: describe_duration_briefly only has 1s granularity...
+ str_duration = datetime_utils.describe_duration_briefly(int(duration)) + msg = f'Held {self.lockfile} for {str_duration}' + logger.warning(msg) + warnings.warn(msg, stacklevel=2) + self.release() + return False + + def __del__(self): + if self.is_locked: + self.release() + + def _signal(self, *args): + if self.is_locked: + self.release() + + def _get_lockfile_contents(self) -> str: + if self.override_command: + cmd = self.override_command + else: + cmd = ' '.join(sys.argv) + contents = LockFileContents( + pid=os.getpid(), + commandline=cmd, + expiration_timestamp=self.expiration_timestamp, + ) + return json.dumps(contents.__dict__) + + def _detect_stale_lockfile(self) -> None: + try: + with open(self.lockfile, 'r') as rf: + lines = rf.readlines() + if len(lines) == 1: + line = lines[0] + line_dict = json.loads(line) + contents = LockFileContents(**line_dict) + logger.debug('Blocking lock contents="%s"', contents) + + # Does the PID exist still? + try: + os.kill(contents.pid, 0) + except OSError: + logger.warning( + 'Lockfile %s\'s pid (%d) is stale; force acquiring...', + self.lockfile, + contents.pid, + ) + self.release() + + # Has the lock expiration expired? + if contents.expiration_timestamp is not None: + now = datetime.datetime.now().timestamp() + if now > contents.expiration_timestamp: + logger.warning( + 'Lockfile %s\'s expiration time has passed; force acquiring', + self.lockfile, + ) + self.release() + except Exception: + pass # If the lockfile doesn't exist or disappears, good. diff --git a/src/pyutils/function_utils.py b/src/pyutils/function_utils.py new file mode 100644 index 0000000..a8ab0c7 --- /dev/null +++ b/src/pyutils/function_utils.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Helper methods dealing with functions.""" + +from typing import Callable + + +def function_identifier(f: Callable) -> str: + """ + Given a callable function, return a string that identifies it. + Usually that string is just __module__:__name__ but there's a + corner case: when __module__ is __main__ (i.e. the callable is + defined in the same module as __main__). In this case, + f.__module__ returns "__main__" instead of the file that it is + defined in. Work around this using pathlib.Path (see below). + + >>> function_identifier(function_identifier) + 'function_utils:function_identifier' + """ + + if f.__module__ == '__main__': + from pathlib import Path + + import __main__ + + module = __main__.__file__ + module = Path(module).stem + return f'{module}:{f.__name__}' + else: + return f'{f.__module__}:{f.__name__}' diff --git a/src/pyutils/id_generator.py b/src/pyutils/id_generator.py new file mode 100644 index 0000000..4b61a93 --- /dev/null +++ b/src/pyutils/id_generator.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A helper class for generating thread safe monotonically increasing +id numbers. + +""" + +import itertools +import logging + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +logger = logging.getLogger(__name__) +generators = {} + + +def get(name: str, *, start=0) -> int: + """ + Returns a thread-safe, monotonically increasing id suitable for use + as a globally unique identifier. 
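+
+    Args:
+        name: the name of the sequence; each distinct name gets its
+            own monotonically increasing counter.
+        start: the start value for a new sequence's counter; ignored
+            once the named sequence already exists.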
+
+    >>> import id_generator
+    >>> id_generator.get('student_id')
+    0
+    >>> id_generator.get('student_id')
+    1
+    >>> id_generator.get('employee_id', start=10000)
+    10000
+    >>> id_generator.get('employee_id', start=10000)
+    10001
+    """
+    if name not in generators:
+        generators[name] = itertools.count(start, 1)
+    x = next(generators[name])
+    logger.debug("Generated next id %d", x)
+    return x
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
diff --git a/src/pyutils/iter_utils.py b/src/pyutils/iter_utils.py
new file mode 100644
index 0000000..c6daddf
--- /dev/null
+++ b/src/pyutils/iter_utils.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""A collection of :class:`Iterator` subclasses that can be composed
+with another iterator and provide extra functionality.  e.g.
+
+    + :class:`PeekingIterator`
+    + :class:`PushbackIterator`
+    + :class:`SamplingIterator`
+
+"""
+
+import random
+from collections.abc import Iterator
+from typing import Any, List, Optional
+
+
+class PeekingIterator(Iterator):
+    """An iterator that lets you :meth:`peek` at the next item on deck.
+    Returns None when there is no next item (i.e. when
+    :meth:`__next__` will produce a StopIteration exception).
+
+    >>> p = PeekingIterator(iter(range(3)))
+    >>> p.__next__()
+    0
+    >>> p.peek()
+    1
+    >>> p.peek()
+    1
+    >>> p.__next__()
+    1
+    >>> p.__next__()
+    2
+    >>> p.peek() == None
+    True
+    >>> p.__next__()
+    Traceback (most recent call last):
+    ...
+    StopIteration
+
+    """
+
+    def __init__(self, source_iter: Iterator):
+        self.source_iter = source_iter
+        self.on_deck: List[Any] = []
+
+    def __iter__(self) -> Iterator:
+        return self
+
+    def __next__(self) -> Any:
+        if len(self.on_deck) > 0:
+            return self.on_deck.pop()
+        else:
+            item = self.source_iter.__next__()
+            return item
+
+    def peek(self) -> Optional[Any]:
+        if len(self.on_deck) > 0:
+            return self.on_deck[0]
+        try:
+            item = self.source_iter.__next__()
+            self.on_deck.append(item)
+            return self.peek()
+        except StopIteration:
+            return None
+
+
+class PushbackIterator(Iterator):
+    """An iterator that allows you to push items back
+    onto the front of the sequence.  e.g.
+
+    >>> i = PushbackIterator(iter(range(3)))
+    >>> i.__next__()
+    0
+    >>> i.push_back(99)
+    >>> i.push_back(98)
+    >>> i.__next__()
+    98
+    >>> i.__next__()
+    99
+    >>> i.__next__()
+    1
+    >>> i.__next__()
+    2
+    >>> i.push_back(100)
+    >>> i.__next__()
+    100
+    >>> i.__next__()
+    Traceback (most recent call last):
+    ...
+    StopIteration
+    """
+
+    def __init__(self, source_iter: Iterator):
+        self.source_iter = source_iter
+        self.pushed_back: List[Any] = []
+
+    def __iter__(self) -> Iterator:
+        return self
+
+    def __next__(self) -> Any:
+        if len(self.pushed_back) > 0:
+            return self.pushed_back.pop()
+        return self.source_iter.__next__()
+
+    def push_back(self, item: Any):
+        self.pushed_back.append(item)
+
+
+class SamplingIterator(Iterator):
+    """An iterator that simply echoes what source_iter produces but also
+    collects a random sample (of size sample_size) of the stream that can
+    be queried at any time.
+
+    .. note::
+        Until sample_size elements have been seen the sample will be
+        less than sample_size elements in length.
+
+    .. note::
+        If sample_size is > len(source_iter) then it will produce a
+        copy of source_iter.
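+
+    .. note::
+        The sample is maintained with reservoir sampling so that, once
+        n items have been seen, each of them had a roughly equal
+        chance of ending up in the retained sample.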
+
+    >>> import collections
+    >>> import random
+
+    >>> random.seed(22)
+    >>> s = SamplingIterator(iter(range(100)), 10)
+    >>> s.__next__()
+    0
+
+    >>> s.__next__()
+    1
+
+    >>> s.get_sample()
+    [0, 1]
+
+    >>> collections.deque(s)
+    deque([2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
+
+    >>> s.get_sample()
+    [78, 18, 47, 83, 93, 26, 25, 73, 94, 38]
+
+    """
+
+    def __init__(self, source_iter: Iterator, sample_size: int):
+        self.source_iter = source_iter
+        self.sample_size = sample_size
+        self.reservoir: List[Any] = []
+        self.stream_length_so_far = 0
+
+    def __iter__(self) -> Iterator:
+        return self
+
+    def __next__(self) -> Any:
+        item = self.source_iter.__next__()
+        self.stream_length_so_far += 1
+
+        # Filling the reservoir
+        pop = len(self.reservoir)
+        if pop < self.sample_size:
+            self.reservoir.append(item)
+            if self.sample_size == (pop + 1):  # just finished filling...
+                random.shuffle(self.reservoir)
+
+        # Swap this item for one in the reservoir with probability
+        # sample_size / stream_length_so_far.  See:
+        #
+        # https://en.wikipedia.org/wiki/Reservoir_sampling
+        else:
+            r = random.randint(0, self.stream_length_so_far)
+            if r < self.sample_size:
+                self.reservoir[r] = item
+        return item
+
+    def get_sample(self) -> List[Any]:
+        return self.reservoir
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
diff --git a/src/pyutils/list_utils.py b/src/pyutils/list_utils.py
new file mode 100644
index 0000000..c67db7d
--- /dev/null
+++ b/src/pyutils/list_utils.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""Some useful(?) utilities for dealing with Lists."""
+
+import random
+from collections import Counter
+from itertools import chain, combinations, islice
+from typing import Any, Iterator, List, MutableSequence, Sequence, Tuple
+
+
+def shard(lst: List[Any], size: int) -> Iterator[Any]:
+    """
+    Yield successive size-sized shards from lst.
+
+    >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
+    ...     [_ for _ in sublist]
+    [1, 2, 3]
+    [4, 5, 6]
+    [7, 8, 9]
+    [10, 11, 12]
+
+    """
+    for x in range(0, len(lst), size):
+        yield islice(lst, x, x + size)
+
+
+def flatten(lst: List[Any]) -> List[Any]:
+    """
+    Flatten out a list:
+
+    >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
+    [1, 2, 3, 4, 5, 6, 7, 8, 9]
+
+    """
+    if len(lst) == 0:
+        return lst
+    if isinstance(lst[0], list):
+        return flatten(lst[0]) + flatten(lst[1:])
+    return lst[:1] + flatten(lst[1:])
+
+
+def prepend(item: Any, lst: List[Any]) -> List[Any]:
+    """
+    Prepend an item to a list.
+
+    >>> prepend('foo', ['bar', 'baz'])
+    ['foo', 'bar', 'baz']
+
+    """
+    lst.insert(0, item)
+    return lst
+
+
+def remove_list_if_one_element(lst: List[Any]) -> Any:
+    """
+    Remove the list and return the 0th element iff its length is one.
+
+    >>> remove_list_if_one_element([1234])
+    1234
+
+    >>> remove_list_if_one_element([1, 2, 3, 4])
+    [1, 2, 3, 4]
+
+    """
+    if len(lst) == 1:
+        return lst[0]
+    else:
+        return lst
+
+
+def population_counts(lst: Sequence[Any]) -> Counter:
+    """
+    Return a population count mapping for the list (i.e. the keys are
+    list items and the values are the number of occurrences of that
+    list item in the original list).
+ + >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4]) + Counter({1: 3, 3: 3, 2: 2, 4: 1}) + + """ + return Counter(lst) + + +def most_common(lst: List[Any], *, count=1) -> Any: + + """ + Return the most common item in the list. In the case of ties, + which most common item is returned will be random. + + >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4]) + 3 + + >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2) + [3, 1] + + """ + p = population_counts(lst) + return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]]) + + +def least_common(lst: List[Any], *, count=1) -> Any: + """ + Return the least common item in the list. In the case of + ties, which least common item is returned will be random. + + >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4]) + 4 + + >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2) + [4, 2] + + """ + p = population_counts(lst) + mc = p.most_common()[-count:] + mc.reverse() + return remove_list_if_one_element([_[0] for _ in mc]) + + +def dedup_list(lst: List[Any]) -> List[Any]: + """ + Remove duplicates from the list performantly. + + >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1]) + [1, 2, 3, 4, 5] + + """ + return list(set(lst)) + + +def uniq(lst: List[Any]) -> List[Any]: + """ + Alias for dedup_list. + """ + return dedup_list(lst) + + +def contains_duplicates(lst: List[Any]) -> bool: + """ + Does the list contian duplicate elements or not? + + >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4] + >>> contains_duplicates(lst) + True + + >>> contains_duplicates(dedup_list(lst)) + False + + """ + seen = set() + for _ in lst: + if _ in seen: + return True + seen.add(_) + return False + + +def all_unique(lst: List[Any]) -> bool: + """ + Inverted alias for contains_duplicates. + """ + return not contains_duplicates(lst) + + +def transpose(lst: List[Any]) -> List[Any]: + """ + Transpose a list of lists. + + >>> lst = [[1, 2], [3, 4], [5, 6]] + >>> transpose(lst) + [[1, 3, 5], [2, 4, 6]] + + """ + transposed = zip(*lst) + return [list(_) for _ in transposed] + + +def ngrams(lst: Sequence[Any], n): + """ + Return the ngrams in the sequence. + + >>> seq = 'encyclopedia' + >>> for _ in ngrams(seq, 3): + ... _ + 'enc' + 'ncy' + 'cyc' + 'ycl' + 'clo' + 'lop' + 'ope' + 'ped' + 'edi' + 'dia' + + >>> seq = ['this', 'is', 'an', 'awesome', 'test'] + >>> for _ in ngrams(seq, 3): + ... _ + ['this', 'is', 'an'] + ['is', 'an', 'awesome'] + ['an', 'awesome', 'test'] + """ + for i in range(len(lst) - n + 1): + yield lst[i : i + n] + + +def permute(seq: str): + """ + Returns all permutations of a sequence; takes O(N!) time. + + >>> for x in permute('cat'): + ... print(x) + cat + cta + act + atc + tca + tac + + """ + yield from _permute(seq, "") + + +def _permute(seq: str, path: str): + seq_len = len(seq) + if seq_len == 0: + yield path + + for i in range(seq_len): + car = seq[i] + left = seq[0:i] + right = seq[i + 1 :] + cdr = left + right + yield from _permute(cdr, path + car) + + +def shuffle(seq: MutableSequence[Any]) -> MutableSequence[Any]: + """Shuffles a sequence into a random order. 
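+
+    Mutable sequences are shuffled in place (and also returned); strings
+    are immutable, so for str input a new shuffled string is returned
+    instead (see the doctests below).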
+ + >>> random.seed(22) + >>> shuffle([1, 2, 3, 4, 5]) + [3, 4, 1, 5, 2] + + >>> shuffle('example') + 'empaelx' + + """ + if isinstance(seq, str): + from pyutils import string_utils + + return string_utils.shuffle(seq) + else: + random.shuffle(seq) + return seq + + +def scramble(seq: MutableSequence[Any]) -> MutableSequence[Any]: + return shuffle(seq) + + +def binary_search(lst: Sequence[Any], target: Any) -> Tuple[bool, int]: + """Performs a binary search on lst (which must already be sorted). + Returns a Tuple composed of a bool which indicates whether the + target was found and an int which indicates the index closest to + target whether it was found or not. + + >>> a = [1, 4, 5, 6, 7, 9, 10, 11] + >>> binary_search(a, 4) + (True, 1) + + >>> binary_search(a, 12) + (False, 8) + + >>> binary_search(a, 3) + (False, 1) + + >>> binary_search(a, 2) + (False, 1) + + >>> a.append(9) + >>> binary_search(a, 4) + Traceback (most recent call last): + ... + AssertionError + + """ + if __debug__: + last = None + for x in lst: + if last is not None: + assert x >= last # This asserts iff the list isn't sorted + last = x # in ascending order. + return _binary_search(lst, target, 0, len(lst) - 1) + + +def _binary_search( + lst: Sequence[Any], target: Any, low: int, high: int +) -> Tuple[bool, int]: + if high >= low: + mid = (high + low) // 2 + if lst[mid] == target: + return (True, mid) + elif lst[mid] > target: + return _binary_search(lst, target, low, mid - 1) + else: + return _binary_search(lst, target, mid + 1, high) + else: + return (False, low) + + +def powerset(lst: Sequence[Any]) -> Iterator[Sequence[Any]]: + """Returns the powerset of the items in the input sequence. + + >>> for x in powerset([1, 2, 3]): + ... print(x) + () + (1,) + (2,) + (3,) + (1, 2) + (1, 3) + (2, 3) + (1, 2, 3) + """ + return chain.from_iterable(combinations(lst, r) for r in range(len(lst) + 1)) + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/logging_utils.py b/src/pyutils/logging_utils.py new file mode 100644 index 0000000..d13527c --- /dev/null +++ b/src/pyutils/logging_utils.py @@ -0,0 +1,913 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# © Copyright 2021-2022, Scott Gasch + +"""Utilities related to logging. To use it you must invoke +:meth:`initialize_logging`. If you use the +:meth:`bootstrap.initialize` decorator on your program's entry point, +it will call this for you. See :meth:`python_modules.bootstrap.initialize` +for more details. If you use this you get: + +* Ability to set logging level, +* ability to define the logging format, +* ability to tee all logging on stderr, +* ability to tee all logging into a file, +* ability to rotate said file as it grows, +* ability to tee all logging into the system log (syslog) and + define the facility and level used to do so, +* easy automatic pid/tid stamp on logging for debugging threads, +* ability to squelch repeated log messages, +* ability to log probabilistically in code, +* ability to only see log messages from a particular module or + function, +* ability to clear logging handlers added by earlier loaded modules. + +All of these are controlled via commandline arguments to your program, +see the code below for details. 
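+
+A minimal usage sketch (assuming your program's entry point is wrapped
+with the :meth:`bootstrap.initialize` decorator, which calls
+:meth:`initialize_logging` on your behalf)::
+
+    import logging
+
+    from pyutils import bootstrap
+
+    logger = logging.getLogger(__name__)
+
+    @bootstrap.initialize
+    def main():
+        logger.info('Hello!  This message obeys the --logging_* flags.')
+
+    if __name__ == '__main__':
+        main()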
+""" + +import collections +import contextlib +import datetime +import enum +import io +import logging +import os +import random +import sys +from logging.config import fileConfig +from logging.handlers import RotatingFileHandler, SysLogHandler +from typing import Any, Callable, Dict, Iterable, List, Optional + +import pytz +from overrides import overrides + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. +from pyutils import argparse_utils, config + +cfg = config.add_commandline_args(f'Logging ({__file__})', 'Args related to logging') +cfg.add_argument( + '--logging_config_file', + type=argparse_utils.valid_filename, + default=None, + metavar='FILENAME', + help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial', +) +cfg.add_argument( + '--logging_level', + type=str, + default='INFO', + choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + metavar='LEVEL', + help='The global default level below which to squelch log messages; see also --lmodule', +) +cfg.add_argument( + '--logging_format', + type=str, + default=None, + help='The format for lines logged via the logger module. See: https://docs.python.org/3/library/logging.html#formatter-objects', +) +cfg.add_argument( + '--logging_date_format', + type=str, + default='%Y/%m/%dT%H:%M:%S.%f%z', + metavar='DATEFMT', + help='The format of any dates in --logging_format.', +) +cfg.add_argument( + '--logging_console', + action=argparse_utils.ActionNoYes, + default=True, + help='Should we log to the console (stderr)', +) +cfg.add_argument( + '--logging_filename', + type=str, + default=None, + metavar='FILENAME', + help='The filename of the logfile to write.', +) +cfg.add_argument( + '--logging_filename_maxsize', + type=int, + default=(1024 * 1024), + metavar='#BYTES', + help='The maximum size (in bytes) to write to the logging_filename.', +) +cfg.add_argument( + '--logging_filename_count', + type=int, + default=7, + metavar='COUNT', + help='The number of logging_filename copies to keep before deleting.', +) +cfg.add_argument( + '--logging_syslog', + action=argparse_utils.ActionNoYes, + default=False, + help='Should we log to localhost\'s syslog.', +) +cfg.add_argument( + '--logging_syslog_facility', + type=str, + default='USER', + choices=[ + 'NOTSET', + 'AUTH', + 'AUTH_PRIV', + 'CRON', + 'DAEMON', + 'FTP', + 'KERN', + 'LPR', + 'MAIL', + 'NEWS', + 'SYSLOG', + 'USER', + 'UUCP', + 'LOCAL0', + 'LOCAL1', + 'LOCAL2', + 'LOCAL3', + 'LOCAL4', + 'LOCAL5', + 'LOCAL6', + 'LOCAL7', + ], + metavar='SYSLOG_FACILITY_LIST', + help='The default syslog message facility identifier', +) +cfg.add_argument( + '--logging_debug_threads', + action=argparse_utils.ActionNoYes, + default=False, + help='Should we prepend pid/tid data to all log messages?', +) +cfg.add_argument( + '--logging_debug_modules', + action=argparse_utils.ActionNoYes, + default=False, + help='Should we prepend module/function data to all log messages?', +) +cfg.add_argument( + '--logging_info_is_print', + action=argparse_utils.ActionNoYes, + default=False, + help='logging.info also prints to stdout.', +) +cfg.add_argument( + '--logging_squelch_repeats', + action=argparse_utils.ActionNoYes, + default=True, + help='Do we allow code to indicate that it wants to squelch repeated logging messages or should we always log?', +) +cfg.add_argument( + '--logging_probabilistically', + action=argparse_utils.ActionNoYes, + default=True, + help='Do we allow 
probabilistic logging (for code that wants it) or should we always log?',
+)
+# See also: OutputMultiplexer
+cfg.add_argument(
+    '--logging_captures_prints',
+    action=argparse_utils.ActionNoYes,
+    default=False,
+    help='When calling print, also log.info automatically.',
+)
+cfg.add_argument(
+    '--lmodule',
+    type=str,
+    metavar='<SCOPE>=<LEVEL>[,<SCOPE>=<LEVEL>...]',
+    help=(
+        'Allows per-scope logging levels which override the global level set with --logging_level.  '
+        + 'Pass a space separated list of <scope>=<level> where <scope> is one of: module, '
+        + 'module:function, or :function and <level> is a logging level (e.g. INFO, DEBUG...)'
+    ),
+)
+cfg.add_argument(
+    '--logging_clear_preexisting_handlers',
+    action=argparse_utils.ActionNoYes,
+    default=True,
+    help=(
+        'Should logging code clear preexisting global logging handlers and thus insist that it '
+        + 'alone can add handlers.  Use this to work around annoying modules that insert global '
+        + 'handlers with formats and logging levels you might not want.  Caveat emptor, this may '
+        + 'cause you to miss logging messages.'
+    ),
+)
+
+BUILT_IN_PRINT = print
+LOGGING_INITIALIZED = False
+
+
+# A map from logging_callsite_id -> count of logged messages.
+squelched_logging_counts: Dict[str, int] = {}
+
+
+def squelch_repeated_log_messages(squelch_after_n_repeats: int) -> Callable:
+    """
+    A decorator that marks a function as interested in having the logging
+    messages that it produces be squelched (ignored) after it logs the
+    same message more than N times.
+
+    .. note::
+
+        This decorator affects *ALL* logging messages produced
+        within the decorated function.  That said, messages must be
+        identical in order to be squelched.  For example, if the same line
+        of code produces different messages (because of, e.g., a format
+        string), the messages are considered to be different.
+
+    """
+
+    def squelch_logging_wrapper(f: Callable):
+        from pyutils import function_utils
+
+        identifier = function_utils.function_identifier(f)
+        squelched_logging_counts[identifier] = squelch_after_n_repeats
+        return f
+
+    return squelch_logging_wrapper
+
+
+class SquelchRepeatedMessagesFilter(logging.Filter):
+    """A filter that only logs messages from a given site with the same
+    (exact) message at the same logging level N times and ignores
+    subsequent attempts to log.
+
+    This filter only affects logging messages that repeat more than a
+    threshold number of times from functions that are tagged with the
+    @logging_utils.squelch_repeated_log_messages decorator (see above);
+    messages from other functions pass through unaffected.
+
+    This functionality is enabled by default but can be disabled via
+    the :code:`--no_logging_squelch_repeats` commandline flag.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.counters: collections.Counter = collections.Counter()
+
+    @overrides
+    def filter(self, record: logging.LogRecord) -> bool:
+        id1 = f'{record.module}:{record.funcName}'
+        if id1 not in squelched_logging_counts:
+            return True
+        threshold = squelched_logging_counts[id1]
+        logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
+        count = self.counters[logsite]
+        self.counters[logsite] += 1
+        return count < threshold
+
+
+class DynamicPerScopeLoggingLevelFilter(logging.Filter):
+    """This filter only allows logging messages from an allow list of
+    module names or module:function names.  Blocks all others.
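+
+    For example (the module names here are hypothetical)::
+
+        --lmodule 'mymodule=DEBUG,othermodule:frobnicate=WARNING'
+
+    enables DEBUG messages everywhere in mymodule and the WARNING
+    level for just the frobnicate function in othermodule while all
+    other scopes still obey the global --logging_level.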
+    """
+
+    @staticmethod
+    def level_name_to_level(name: str) -> int:
+        numeric_level = getattr(logging, name, None)
+        if not isinstance(numeric_level, int):
+            raise ValueError(f'Invalid level: {name}')
+        return numeric_level
+
+    def __init__(
+        self,
+        default_logging_level: int,
+        per_scope_logging_levels: str,
+    ) -> None:
+        super().__init__()
+        self.valid_levels = set(
+            ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
+        )
+        self.default_logging_level = default_logging_level
+        self.level_by_scope = {}
+        if per_scope_logging_levels is not None:
+            for chunk in per_scope_logging_levels.split(','):
+                if '=' not in chunk:
+                    print(
+                        f'Malformed lmodule directive: "{chunk}", missing "=".  Ignored.',
+                        file=sys.stderr,
+                    )
+                    continue
+                try:
+                    (scope, level) = chunk.split('=')
+                except ValueError:
+                    print(
+                        f'Malformed lmodule directive: "{chunk}".  Ignored.',
+                        file=sys.stderr,
+                    )
+                    continue
+                scope = scope.strip()
+                level = level.strip().upper()
+                if level not in self.valid_levels:
+                    print(
+                        f'Malformed lmodule directive: "{chunk}", bad level.  Ignored.',
+                        file=sys.stderr,
+                    )
+                    continue
+                self.level_by_scope[
+                    scope
+                ] = DynamicPerScopeLoggingLevelFilter.level_name_to_level(level)
+
+    @overrides
+    def filter(self, record: logging.LogRecord) -> bool:
+        """Decides whether or not to log based on an allow list."""
+
+        # First try to find a logging level by scope (--lmodule)
+        if len(self.level_by_scope) > 0:
+            min_level = None
+            for scope in (
+                record.module,
+                f'{record.module}:{record.funcName}',
+                f':{record.funcName}',
+            ):
+                level = self.level_by_scope.get(scope, None)
+                if level is not None:
+                    if min_level is None or level < min_level:
+                        min_level = level
+
+            # If we found one, use it instead of the global default level.
+            if min_level is not None:
+                return record.levelno >= min_level
+
+        # Otherwise, use the global logging level (--logging_level)
+        return record.levelno >= self.default_logging_level
+
+
+# A map from function_identifier -> probability of logging (0.0%..100.0%)
+probabilistic_logging_levels: Dict[str, float] = {}
+
+
+def logging_is_probabilistic(probability_of_logging: float) -> Callable:
+    """A decorator that indicates that all logging statements within the
+    scope of a particular (marked) function are not deterministic
+    (i.e. they do not always unconditionally log) but rather are
+    probabilistic (i.e. they log N% of the time, randomly).
+
+    .. note::
+        This affects *ALL* logging statements within the marked function.
+
+    Note that this functionality can be disabled (forcing all logged
+    messages to produce output) via the
+    :code:`--no_logging_probabilistically` cmdline argument.
+    """
+
+    def probabilistic_logging_wrapper(f: Callable):
+        from pyutils import function_utils
+
+        identifier = function_utils.function_identifier(f)
+        probabilistic_logging_levels[identifier] = probability_of_logging
+        return f
+
+    return probabilistic_logging_wrapper
+
+
+class ProbabilisticFilter(logging.Filter):
+    """
+    A filter that logs messages probabilistically (i.e. randomly at some
+    percent chance).
+
+    This filter only affects logging messages from functions that have
+    been tagged with the @logging_utils.logging_is_probabilistic
+    decorator (see above); others are unaffected.
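+
+    e.g. (assuming a module-level logger; the 50% rate is made up)::
+
+        @logging_utils.logging_is_probabilistic(50.0)
+        def chatty():
+            logger.info('This message is logged only ~half the time.')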
+    """
+
+    @overrides
+    def filter(self, record: logging.LogRecord) -> bool:
+        id1 = f'{record.module}:{record.funcName}'
+        if id1 not in probabilistic_logging_levels:
+            return True
+        threshold = probabilistic_logging_levels[id1]
+        return (random.random() * 100.0) <= threshold
+
+
+class OnlyInfoFilter(logging.Filter):
+    """A filter that only logs messages produced at the INFO logging
+    level.  This is used by the :code:`--logging_info_is_print`
+    commandline option to select a subset of the logging stream to
+    send to a stdout handler.
+    """
+
+    @overrides
+    def filter(self, record: logging.LogRecord):
+        return record.levelno == logging.INFO
+
+
+class MillisecondAwareFormatter(logging.Formatter):
+    """
+    A formatter for adding milliseconds to log messages which, for
+    whatever reason, the default python logger doesn't do.
+    """
+
+    converter = datetime.datetime.fromtimestamp  # type: ignore
+
+    @overrides
+    def formatTime(self, record, datefmt=None):
+        ct = MillisecondAwareFormatter.converter(
+            record.created, pytz.timezone("US/Pacific")
+        )
+        if datefmt:
+            s = ct.strftime(datefmt)
+        else:
+            t = ct.strftime("%Y-%m-%d %H:%M:%S")
+            s = f"{t},{int(record.msecs):03d}"
+        return s
+
+
+def log_about_logging(
+    logger,
+    default_logging_level,
+    preexisting_handlers_count,
+    fmt,
+    facility_name,
+):
+    """Some of the initial messages in the debug log are about how we
+    have set up logging itself."""
+
+    level_name = logging._levelToName.get(
+        default_logging_level, str(default_logging_level)
+    )
+    logger.debug('Initialized global logging; default logging level is %s.', level_name)
+    if (
+        config.config['logging_clear_preexisting_handlers']
+        and preexisting_handlers_count > 0
+    ):
+        logger.debug(
+            'Logging cleared %d global handlers (--logging_clear_preexisting_handlers)',
+            preexisting_handlers_count,
+        )
+    logger.debug('Logging format specification is "%s"', fmt)
+    if config.config['logging_debug_threads']:
+        logger.debug(
+            '...Logging format spec captures tid/pid. (--logging_debug_threads)'
+        )
+    if config.config['logging_debug_modules']:
+        logger.debug(
+            '...Logging format spec captures files/functions/lineno. (--logging_debug_modules)'
+        )
+    if config.config['logging_syslog']:
+        logger.debug(
+            'Logging to syslog as %s with priority mapping based on level. (--logging_syslog)',
+            facility_name,
+        )
+    if config.config['logging_filename']:
+        logger.debug(
+            'Logging to file "%s". (--logging_filename)',
+            config.config["logging_filename"],
+        )
+        logger.debug(
+            '...with %d bytes max file size. (--logging_filename_maxsize)',
+            config.config["logging_filename_maxsize"],
+        )
+        logger.debug(
+            '...and %d rotating backup file count. (--logging_filename_count)',
+            config.config["logging_filename_count"],
+        )
+    if config.config['logging_console']:
+        logger.debug('Logging to the console (stderr). (--logging_console)')
+    if config.config['logging_info_is_print']:
+        logger.debug(
+            'Logging logger.info messages will be repeated on stdout. (--logging_info_is_print)'
+        )
+    if config.config['logging_squelch_repeats']:
+        logger.debug(
+            'Logging code allowed to request repeated messages be squelched. (--logging_squelch_repeats)'
+        )
+    else:
+        logger.debug(
+            'Logging code forbidden to request messages be squelched; all messages logged. (--no_logging_squelch_repeats)'
+        )
+    if config.config['logging_probabilistically']:
+        logger.debug(
+            'Logging code is allowed to request probabilistic logging. 
(--logging_probabilistically)' + ) + else: + logger.debug( + 'Logging code is forbidden to request probabilistic logging; messages always logged. (--no_logging_probabilistically)' + ) + if config.config['lmodule']: + logger.debug( + f'Logging dynamic per-module logging enabled. (--lmodule={config.config["lmodule"]})' + ) + if config.config['logging_captures_prints']: + logger.debug( + 'Logging will capture printed data as logger.info messages. (--logging_captures_prints)' + ) + + +def initialize_logging(logger=None) -> logging.Logger: + """Initialize logging for the program. This must be called if you want + to use any of the functionality provided by this module such as: + + * Ability to set logging level, + * ability to define the logging format, + * ability to tee all logging on stderr, + * ability to tee all logging into a file, + * ability to rotate said file as it grows, + * ability to tee all logging into the system log (syslog) and + define the facility and level used to do so, + * easy automatic pid/tid stamp on logging for debugging threads, + * ability to squelch repeated log messages, + * ability to log probabilistically in code, + * ability to only see log messages from a particular module or + function, + * ability to clear logging handlers added by earlier loaded modules. + + All of these are controlled via commandline arguments to your program, + see the code below for details. + + If you use the + :meth:`bootstrap.initialize` decorator on your program's entry point, + it will call this for you. See :meth:`python_modules.bootstrap.initialize` + for more details. + """ + global LOGGING_INITIALIZED + if LOGGING_INITIALIZED: + return logging.getLogger() + LOGGING_INITIALIZED = True + + if logger is None: + logger = logging.getLogger() + + # --logging_clear_preexisting_handlers removes logging handlers + # that were registered by global statements during imported module + # setup. + preexisting_handlers_count = 0 + assert config.has_been_parsed() + if config.config['logging_clear_preexisting_handlers']: + while logger.hasHandlers(): + logger.removeHandler(logger.handlers[0]) + preexisting_handlers_count += 1 + + # --logging_config_file pulls logging settings from a config file + # skipping the rest of this setup. + if config.config['logging_config_file'] is not None: + fileConfig(config.config['logging_config_file']) + return logger + + handlers: List[logging.Handler] = [] + handler: Optional[logging.Handler] = None + + # Global default logging level (--logging_level); messages below + # this level will be silenced. + default_logging_level = getattr( + logging, config.config['logging_level'].upper(), None + ) + if not isinstance(default_logging_level, int): + raise ValueError(f'Invalid level: {config.config["logging_level"]}') + + # Custom or default --logging_format? + if config.config['logging_format']: + fmt = config.config['logging_format'] + else: + if config.config['logging_syslog']: + fmt = '%(levelname).1s:%(filename)s[%(process)d]: %(message)s' + else: + fmt = '%(levelname).1s:%(asctime)s: %(message)s' + + # --logging_debug_threads and --logging_debug_modules both affect + # the format by prepending information about the pid/tid or + # file/function. + if config.config['logging_debug_threads']: + fmt = f'%(process)d.%(thread)d|{fmt}' + if config.config['logging_debug_modules']: + fmt = f'%(filename)s:%(funcName)s:%(lineno)s|{fmt}' + + # --logging_syslog (optionally with --logging_syslog_facility) + # sets up for logging to use the standard system syslogd as a + # sink. 
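+    # Note: this is skipped on win32/cygwin, which have no local syslogd.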
+ facility_name = None + if config.config['logging_syslog']: + if sys.platform not in ('win32', 'cygwin'): + if config.config['logging_syslog_facility']: + facility_name = 'LOG_' + config.config['logging_syslog_facility'] + facility = SysLogHandler.__dict__.get(facility_name, SysLogHandler.LOG_USER) # type: ignore + assert facility is not None + handler = SysLogHandler(facility=facility, address='/dev/log') + handler.setFormatter( + MillisecondAwareFormatter( + fmt=fmt, + datefmt=config.config['logging_date_format'], + ) + ) + handlers.append(handler) + + # --logging_filename (with friends --logging_filename_count and + # --logging_filename_maxsize) set up logging to a file on the + # filesystem with automatic rotation when it gets too big. + if config.config['logging_filename']: + handler = RotatingFileHandler( + config.config['logging_filename'], + maxBytes=config.config['logging_filename_maxsize'], + backupCount=config.config['logging_filename_count'], + ) + handler.setFormatter( + MillisecondAwareFormatter( + fmt=fmt, + datefmt=config.config['logging_date_format'], + ) + ) + handlers.append(handler) + + # --logging_console is, ahem, logging to the console. + if config.config['logging_console']: + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter( + MillisecondAwareFormatter( + fmt=fmt, + datefmt=config.config['logging_date_format'], + ) + ) + handlers.append(handler) + + if len(handlers) == 0: + handlers.append(logging.NullHandler()) + for handler in handlers: + logger.addHandler(handler) + + # --logging_info_is_print echoes any message to logger.info(x) as + # a print statement on stdout. + if config.config['logging_info_is_print']: + handler = logging.StreamHandler(sys.stdout) + handler.addFilter(OnlyInfoFilter()) + logger.addHandler(handler) + + # --logging_squelch_repeats allows code to request repeat logging + # messages (identical log site and message contents) to be + # silenced. Logging code must request this explicitly, it isn't + # automatic. This option just allows the silencing to happen. + if config.config['logging_squelch_repeats']: + for handler in handlers: + handler.addFilter(SquelchRepeatedMessagesFilter()) + + # --logging_probabilistically allows code to request + # non-deterministic logging where messages have some probability + # of being produced. Logging code must request this explicitly. + # This option just allows the non-deterministic behavior to + # happen. Disabling it will cause every log message to be + # produced. + if config.config['logging_probabilistically']: + for handler in handlers: + handler.addFilter(ProbabilisticFilter()) + + # --lmodule is a way to have a special logging level for just on + # module or one set of modules that is different than the one set + # globally via --logging_level. + for handler in handlers: + handler.addFilter( + DynamicPerScopeLoggingLevelFilter( + default_logging_level, + config.config['lmodule'], + ) + ) + logger.setLevel(0) + logger.propagate = False + + # --logging_captures_prints, if set, will capture and log.info + # anything printed on stdout. + if config.config['logging_captures_prints']: + import builtins + + def print_and_also_log(*arg, **kwarg): + f = kwarg.get('file', None) + if f == sys.stderr: + logger.warning(*arg) + else: + logger.info(*arg) + BUILT_IN_PRINT(*arg, **kwarg) + + builtins.print = print_and_also_log + + # At this point the logger is ready, handlers are set up, + # etc... so log about the logging configuration. 
+ log_about_logging( + logger, + default_logging_level, + preexisting_handlers_count, + fmt, + facility_name, + ) + return logger + + +def get_logger(name: str = ""): + """Get the global logger""" + logger = logging.getLogger(name) + return initialize_logging(logger) + + +def tprint(*args, **kwargs) -> None: + """Legacy function for printing a message augmented with thread id + still needed by some code. Please use --logging_debug_threads in + new code. + """ + if config.config['logging_debug_threads']: + from pyutils.parallelize.thread_utils import current_thread_id + + print(f'{current_thread_id()}', end="") + print(*args, **kwargs) + else: + pass + + +def dprint(*args, **kwargs) -> None: + """Legacy function used to print to stderr still needed by some code. + Please just use normal logging with --logging_console which + accomplishes the same thing in new code. + """ + print(*args, file=sys.stderr, **kwargs) + + +class OutputMultiplexer(object): + """A class that broadcasts printed messages to several sinks + (including various logging levels, different files, different file + handles, the house log, etc...). See also + :class:`OutputMultiplexerContext` for an easy usage pattern. + """ + + class Destination(enum.IntEnum): + """Bits in the destination_bitv bitvector. Used to indicate the + output destination.""" + + # fmt: off + LOG_DEBUG = 0x01 # ⎫ + LOG_INFO = 0x02 # ⎪ + LOG_WARNING = 0x04 # ⎬ Must provide logger to the c'tor. + LOG_ERROR = 0x08 # ⎪ + LOG_CRITICAL = 0x10 # ⎭ + FILENAMES = 0x20 # Must provide a filename to the c'tor. + FILEHANDLES = 0x40 # Must provide a handle to the c'tor. + HLOG = 0x80 + ALL_LOG_DESTINATIONS = ( + LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL + ) + ALL_OUTPUT_DESTINATIONS = 0x8F + # fmt: on + + def __init__( + self, + destination_bitv: int, + *, + logger=None, + filenames: Optional[Iterable[str]] = None, + handles: Optional[Iterable[io.TextIOWrapper]] = None, + ): + """ + Constructs the OutputMultiplexer instance. + + Args: + destination_bitv: a bitvector where each bit represents an + output destination. Multiple bits may be set. + logger: if LOG_* bits are set, you must pass a logger here. + filenames: if FILENAMES bit is set, this should be a list of + files you'd like to output into. This code handles opening + and closing said files. + handles: if FILEHANDLES bit is set, this should be a list of + already opened filehandles you'd like to output into. The + handles will remain open after the scope of the multiplexer. 
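+
+        Raises:
+            ValueError: if the FILENAMES bit is set but no filenames
+                are provided, or the FILEHANDLES bit is set but no
+                handles are provided.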
+ """ + if logger is None: + logger = logging.getLogger(None) + self.logger = logger + + self.f: Optional[List[Any]] = None + if filenames is not None: + self.f = [open(filename, 'wb', buffering=0) for filename in filenames] + else: + if destination_bitv & OutputMultiplexer.Destination.FILENAMES: + raise ValueError("Filenames argument is required if bitv & FILENAMES") + self.f = None + + self.h: Optional[List[Any]] = None + if handles is not None: + self.h = list(handles) + else: + if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES: + raise ValueError("Handle argument is required if bitv & FILEHANDLES") + self.h = None + + self.set_destination_bitv(destination_bitv) + + def get_destination_bitv(self): + """Where are we outputting?""" + return self.destination_bitv + + def set_destination_bitv(self, destination_bitv: int): + """Change the output destination_bitv to the one provided.""" + if destination_bitv & self.Destination.FILENAMES and self.f is None: + raise ValueError("Filename argument is required if bitv & FILENAMES") + if destination_bitv & self.Destination.FILEHANDLES and self.h is None: + raise ValueError("Handle argument is required if bitv & FILEHANDLES") + self.destination_bitv = destination_bitv + + def print(self, *args, **kwargs): + """Produce some output to all sinks.""" + from pyutils.string_utils import sprintf, strip_escape_sequences + + end = kwargs.pop("end", None) + if end is not None: + if not isinstance(end, str): + raise TypeError("end must be None or a string") + sep = kwargs.pop("sep", None) + if sep is not None: + if not isinstance(sep, str): + raise TypeError("sep must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + buf = sprintf(*args, end="", sep=sep) + if sep is None: + sep = " " + if end is None: + end = "\n" + if end == '\n': + buf += '\n' + if self.destination_bitv & self.Destination.FILENAMES and self.f is not None: + for _ in self.f: + _.write(buf.encode('utf-8')) + _.flush() + + if self.destination_bitv & self.Destination.FILEHANDLES and self.h is not None: + for _ in self.h: + _.write(buf) + _.flush() + + buf = strip_escape_sequences(buf) + if self.logger is not None: + if self.destination_bitv & self.Destination.LOG_DEBUG: + self.logger.debug(buf) + if self.destination_bitv & self.Destination.LOG_INFO: + self.logger.info(buf) + if self.destination_bitv & self.Destination.LOG_WARNING: + self.logger.warning(buf) + if self.destination_bitv & self.Destination.LOG_ERROR: + self.logger.error(buf) + if self.destination_bitv & self.Destination.LOG_CRITICAL: + self.logger.critical(buf) + if self.destination_bitv & self.Destination.HLOG: + hlog(buf) + + def close(self): + """Close all open files.""" + if self.f is not None: + for _ in self.f: + _.close() + + +class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator): + """ + A context that uses an :class:`OutputMultiplexer`. 
e.g.:: + + with OutputMultiplexerContext( + OutputMultiplexer.LOG_INFO | + OutputMultiplexer.LOG_DEBUG | + OutputMultiplexer.FILENAMES | + OutputMultiplexer.FILEHANDLES, + filenames = [ '/tmp/foo.log', '/var/log/bar.log' ], + handles = [ f, g ] + ) as mplex: + mplex.print("This is a log message!") + """ + + def __init__( + self, + destination_bitv: OutputMultiplexer.Destination, + *, + logger=None, + filenames=None, + handles=None, + ): + super().__init__( + destination_bitv, + logger=logger, + filenames=filenames, + handles=handles, + ) + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback) -> bool: + super().close() + if etype is not None: + return False + return True + + +def hlog(message: str) -> None: + """Write a message to the house log (syslog facility local7 priority + info) by calling /usr/bin/logger. This is pretty hacky but used + by a bunch of code. Another way to do this would be to use + :code:`--logging_syslog` and :code:`--logging_syslog_facility` but + I can't actually say that's easier. + """ + message = message.replace("'", "'\"'\"'") + os.system(f"/usr/bin/logger -p local7.info -- '{message}'") + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/math_utils.py b/src/pyutils/math_utils.py new file mode 100644 index 0000000..97bb635 --- /dev/null +++ b/src/pyutils/math_utils.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Mathematical helpers.""" + +import collections +import functools +import math +from heapq import heappop, heappush +from typing import Dict, List, Optional, Tuple + +from pyutils import dict_utils + + +class NumericPopulation(object): + """A numeric population with some statistics such as median, mean, pN, + stdev, etc... + + >>> pop = NumericPopulation() + >>> pop.add_number(1) + >>> pop.add_number(10) + >>> pop.add_number(3) + >>> pop.get_median() + 3 + >>> pop.add_number(7) + >>> pop.add_number(5) + >>> pop.get_median() + 5 + >>> pop.get_mean() + 5.2 + >>> round(pop.get_stdev(), 1) + 1.4 + >>> pop.get_percentile(20) + 3 + >>> pop.get_percentile(60) + 7 + """ + + def __init__(self): + self.lowers, self.highers = [], [] + self.aggregate = 0.0 + self.sorted_copy: Optional[List[float]] = None + self.maximum = None + self.minimum = None + + def add_number(self, number: float): + """Adds a number to the population. 
Runtime complexity of this + operation is :math:`O(2 log_2 n)`""" + + if not self.highers or number > self.highers[0]: + heappush(self.highers, number) + else: + heappush(self.lowers, -number) # for lowers we need a max heap + self.aggregate += number + self._rebalance() + if not self.maximum or number > self.maximum: + self.maximum = number + if not self.minimum or number < self.minimum: + self.minimum = number + + def _rebalance(self): + if len(self.lowers) - len(self.highers) > 1: + heappush(self.highers, -heappop(self.lowers)) + elif len(self.highers) - len(self.lowers) > 1: + heappush(self.lowers, -heappop(self.highers)) + + def get_median(self) -> float: + """Returns the approximate median (p50) so far in O(1) time.""" + + if len(self.lowers) == len(self.highers): + return -self.lowers[0] + elif len(self.lowers) > len(self.highers): + return -self.lowers[0] + else: + return self.highers[0] + + def get_mean(self) -> float: + """Returns the mean (arithmetic mean) so far in O(1) time.""" + + count = len(self.lowers) + len(self.highers) + return self.aggregate / count + + def get_mode(self) -> Tuple[float, int]: + """Returns the mode (most common member in the population) + in O(n) time.""" + + count: Dict[float, int] = collections.defaultdict(int) + for n in self.lowers: + count[-n] += 1 + for n in self.highers: + count[n] += 1 + return dict_utils.item_with_max_value(count) + + def get_stdev(self) -> float: + """Returns the stdev so far in O(n) time.""" + + mean = self.get_mean() + variance = 0.0 + for n in self.lowers: + n = -n + variance += (n - mean) ** 2 + for n in self.highers: + variance += (n - mean) ** 2 + count = len(self.lowers) + len(self.highers) + return math.sqrt(variance) / count + + def _create_sorted_copy_if_needed(self, count: int): + if not self.sorted_copy or count != len(self.sorted_copy): + self.sorted_copy = [] + for x in self.lowers: + self.sorted_copy.append(-x) + for x in self.highers: + self.sorted_copy.append(x) + self.sorted_copy = sorted(self.sorted_copy) + + def get_percentile(self, n: float) -> float: + """Returns the number at approximately pn% (i.e. the nth percentile) + of the distribution in O(n log n) time. Not thread-safe; + does caching across multiple calls without an invocation to + add_number for perf reasons. + """ + if n == 50: + return self.get_median() + count = len(self.lowers) + len(self.highers) + self._create_sorted_copy_if_needed(count) + assert self.sorted_copy + index = round(count * (n / 100.0)) + index = max(0, index) + index = min(count - 1, index) + return self.sorted_copy[index] + + +def gcd_floats(a: float, b: float) -> float: + """Returns the greatest common divisor of a and b.""" + if a < b: + return gcd_floats(b, a) + + # base case + if abs(b) < 0.001: + return a + return gcd_floats(b, a - math.floor(a / b) * b) + + +def gcd_float_sequence(lst: List[float]) -> float: + """Returns the greatest common divisor of a list of floats.""" + if len(lst) <= 0: + raise ValueError("Need at least one number") + elif len(lst) == 1: + return lst[0] + assert len(lst) >= 2 + gcd = gcd_floats(lst[0], lst[1]) + for i in range(2, len(lst)): + gcd = gcd_floats(gcd, lst[i]) + return gcd + + +def truncate_float(n: float, decimals: int = 2): + """Truncate a float to a particular number of decimals. + + >>> truncate_float(3.1415927, 3) + 3.141 + + """ + assert 0 < decimals < 10 + multiplier = 10**decimals + return int(n * multiplier) / multiplier + + +def percentage_to_multiplier(percent: float) -> float: + """Given a percentage (e.g. 
155%), return a factor needed to scale a + number by that percentage. + + >>> percentage_to_multiplier(155) + 2.55 + >>> percentage_to_multiplier(45) + 1.45 + >>> percentage_to_multiplier(-25) + 0.75 + """ + multiplier = percent / 100 + multiplier += 1.0 + return multiplier + + +def multiplier_to_percent(multiplier: float) -> float: + """Convert a multiplicative factor into a percent change. + + >>> multiplier_to_percent(0.75) + -25.0 + >>> multiplier_to_percent(1.0) + 0.0 + >>> multiplier_to_percent(1.99) + 99.0 + """ + percent = multiplier + if percent > 0.0: + percent -= 1.0 + else: + percent = 1.0 - percent + percent *= 100.0 + return percent + + +@functools.lru_cache(maxsize=1024, typed=True) +def is_prime(n: int) -> bool: + """ + Returns True if n is prime and False otherwise. Obviously(?) very slow for + very large input numbers. + + >>> is_prime(13) + True + >>> is_prime(22) + False + >>> is_prime(51602981) + True + """ + if not isinstance(n, int): + raise TypeError("argument passed to is_prime is not of 'int' type") + + # Corner cases + if n <= 1: + return False + if n <= 3: + return True + + # This is checked so that we can skip middle five numbers in below + # loop + if n % 2 == 0 or n % 3 == 0: + return False + + i = 5 + while i * i <= n: + if n % i == 0 or n % (i + 2) == 0: + return False + i = i + 6 + return True + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/misc_utils.py b/src/pyutils/misc_utils.py new file mode 100644 index 0000000..669b3ef --- /dev/null +++ b/src/pyutils/misc_utils.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Miscellaneous utilities.""" + +import os +import sys + + +def is_running_as_root() -> bool: + """Returns True if running as root. + + >>> is_running_as_root() + False + """ + return os.geteuid() == 0 + + +def debugger_is_attached() -> bool: + """Return if the debugger is attached""" + + gettrace = getattr(sys, 'gettrace', lambda: None) + return gettrace() is not None + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/parallelize/__init__.py b/src/pyutils/parallelize/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/parallelize/deferred_operand.py b/src/pyutils/parallelize/deferred_operand.py new file mode 100644 index 0000000..9edbb9e --- /dev/null +++ b/src/pyutils/parallelize/deferred_operand.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""This is a helper class that tries to define every __dunder__ method +so as to defer that evaluation of an object as long as possible. It +is used by smart_future.py as a base class. + +""" + +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +T = TypeVar('T') + + +class DeferredOperand(ABC, Generic[T]): + """A wrapper around an operand whose value is deferred until it is + needed (i.e. accessed). See the subclass :class:`SmartFuture` for + an example usage and/or a more useful patten. 
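+
+    A trivial sketch of the idea (Constant is a made-up subclass for
+    illustration)::
+
+        class Constant(DeferredOperand):
+            def __init__(self, value):
+                self.value = value
+
+            def _resolve(self, timeout=None):
+                return self.value
+
+        # Operands resolve lazily, at the moment they are used:
+        assert Constant(5) + Constant(5) == 10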
+ """ + + @abstractmethod + def _resolve(self, timeout=None) -> T: + pass + + @staticmethod + def resolve(x: Any) -> Any: + while isinstance(x, DeferredOperand): + x = x._resolve() + return x + + def __lt__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) < DeferredOperand.resolve(other) + + def __le__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) <= DeferredOperand.resolve(other) + + def __eq__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) == DeferredOperand.resolve(other) + + def __ne__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) != DeferredOperand.resolve(other) + + def __gt__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) > DeferredOperand.resolve(other) + + def __ge__(self, other: Any) -> bool: + return DeferredOperand.resolve(self) >= DeferredOperand.resolve(other) + + def __not__(self) -> bool: + return not DeferredOperand.resolve(self) + + def bool(self) -> bool: + return DeferredOperand.resolve(self) + + def __add__(self, other: Any) -> T: + return DeferredOperand.resolve(self) + DeferredOperand.resolve(other) + + def __iadd__(self, other: Any) -> T: + return DeferredOperand.resolve(self) + DeferredOperand.resolve(other) + + def __radd__(self, other: Any) -> T: + return DeferredOperand.resolve(self) + DeferredOperand.resolve(other) + + def __sub__(self, other: Any) -> T: + return DeferredOperand.resolve(self) - DeferredOperand.resolve(other) + + def __mul__(self, other: Any) -> T: + return DeferredOperand.resolve(self) * DeferredOperand.resolve(other) + + def __pow__(self, other: Any) -> T: + return DeferredOperand.resolve(self) ** DeferredOperand.resolve(other) + + def __truediv__(self, other: Any) -> Any: + return DeferredOperand.resolve(self) / DeferredOperand.resolve(other) + + def __floordiv__(self, other: Any) -> T: + return DeferredOperand.resolve(self) // DeferredOperand.resolve(other) + + def __contains__(self, other): + return DeferredOperand.resolve(other) in DeferredOperand.resolve(self) + + def and_(self, other): + return DeferredOperand.resolve(self) & DeferredOperand.resolve(other) + + def or_(self, other): + return DeferredOperand.resolve(self) & DeferredOperand.resolve(other) + + def xor(self, other): + return DeferredOperand.resolve(self) & DeferredOperand.resolve(other) + + def invert(self): + return ~(DeferredOperand.resolve(self)) + + def is_(self, other): + return DeferredOperand.resolve(self) is DeferredOperand.resolve(other) + + def is_not(self, other): + return DeferredOperand.resolve(self) is not DeferredOperand.resolve(other) + + def __abs__(self): + return abs(DeferredOperand.resolve(self)) + + def setitem(self, k, v): + DeferredOperand.resolve(self)[DeferredOperand.resolve(k)] = v + + def delitem(self, k): + del DeferredOperand.resolve(self)[DeferredOperand.resolve(k)] + + def getitem(self, k): + return DeferredOperand.resolve(self)[DeferredOperand.resolve(k)] + + def lshift(self, other): + return DeferredOperand.resolve(self) << DeferredOperand.resolve(other) + + def rshift(self, other): + return DeferredOperand.resolve(self) >> DeferredOperand.resolve(other) + + def mod(self, other): + return DeferredOperand.resolve(self) % DeferredOperand.resolve(other) + + def matmul(self, other): + return DeferredOperand.resolve(self) @ DeferredOperand.resolve(other) + + def neg(self): + return -(DeferredOperand.resolve(self)) + + def pos(self): + return +(DeferredOperand.resolve(self)) + + def truth(self): + return DeferredOperand.resolve(self) + + def 
__hash__(self): + return DeferredOperand.resolve(self).__hash__() + + def __call__(self): + return DeferredOperand.resolve(self)() + + def __iter__(self): + return DeferredOperand.resolve(self).__iter__() + + def __repr__(self) -> str: + return DeferredOperand.resolve(self).__repr__() + + def __bytes__(self) -> bytes: + return DeferredOperand.resolve(self).__bytes__() + + def __int__(self) -> int: + return int(DeferredOperand.resolve(self)) + + def __float__(self) -> float: + return float(DeferredOperand.resolve(self)) + + def __getattr__(self, method_name): + def method(*args, **kwargs): + return getattr(DeferredOperand.resolve(self), method_name)(*args, **kwargs) + + return method diff --git a/src/pyutils/parallelize/executors.py b/src/pyutils/parallelize/executors.py new file mode 100644 index 0000000..7bd44ca --- /dev/null +++ b/src/pyutils/parallelize/executors.py @@ -0,0 +1,1540 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# © Copyright 2021-2022, Scott Gasch + +"""Defines three executors: a thread executor for doing work using a +threadpool, a process executor for doing work in other processes on +the same machine and a remote executor for farming out work to other +machines. + +Also defines DefaultExecutors which is a container for references to +global executors / worker pools with automatic shutdown semantics.""" + +from __future__ import annotations + +import concurrent.futures as fut +import logging +import os +import platform +import random +import subprocess +import threading +import time +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, fields +from typing import Any, Callable, Dict, List, Optional, Set + +import cloudpickle # type: ignore +import numpy +from overrides import overrides + +import pyutils.typez.histogram as hist +from pyutils import argparse_utils, config, persistent, string_utils +from pyutils.ansi import bg, fg, reset, underline +from pyutils.decorator_utils import singleton +from pyutils.exec_utils import cmd_exitcode, cmd_in_background, run_silently +from pyutils.parallelize.thread_utils import background_thread + +logger = logging.getLogger(__name__) + +parser = config.add_commandline_args( + f"Executors ({__file__})", "Args related to processing executors." 
+) +parser.add_argument( + '--executors_threadpool_size', + type=int, + metavar='#THREADS', + help='Number of threads in the default threadpool, leave unset for default', + default=None, +) +parser.add_argument( + '--executors_processpool_size', + type=int, + metavar='#PROCESSES', + help='Number of processes in the default processpool, leave unset for default', + default=None, +) +parser.add_argument( + '--executors_schedule_remote_backups', + default=True, + action=argparse_utils.ActionNoYes, + help='Should we schedule duplicative backup work if a remote bundle is slow', +) +parser.add_argument( + '--executors_max_bundle_failures', + type=int, + default=3, + metavar='#FAILURES', + help='Maximum number of failures before giving up on a bundle', +) +parser.add_argument( + '--remote_worker_records_file', + type=str, + metavar='FILENAME', + help='Path of the remote worker records file (JSON)', + default=f'{os.environ.get("HOME", ".")}/.remote_worker_records', +) + + +SSH = '/usr/bin/ssh -oForwardX11=no' +SCP = '/usr/bin/scp -C' + + +def _make_cloud_pickle(fun, *args, **kwargs): + """Internal helper to create cloud pickles.""" + logger.debug("Making cloudpickled bundle at %s", fun.__name__) + return cloudpickle.dumps((fun, args, kwargs)) + + +class BaseExecutor(ABC): + """The base executor interface definition. The interface for + :class:`ProcessExecutor`, :class:`RemoteExecutor`, and + :class:`ThreadExecutor`. + """ + + def __init__(self, *, title=''): + self.title = title + self.histogram = hist.SimpleHistogram( + hist.SimpleHistogram.n_evenly_spaced_buckets(int(0), int(500), 50) + ) + self.task_count = 0 + + @abstractmethod + def submit(self, function: Callable, *args, **kwargs) -> fut.Future: + pass + + @abstractmethod + def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None: + pass + + def shutdown_if_idle(self, *, quiet: bool = False) -> bool: + """Shutdown the executor and return True if the executor is idle + (i.e. there are no pending or active tasks). Return False + otherwise. Note: this should only be called by the launcher + process. + + """ + if self.task_count == 0: + self.shutdown(wait=True, quiet=quiet) + return True + return False + + def adjust_task_count(self, delta: int) -> None: + """Change the task count. Note: do not call this method from a + worker, it should only be called by the launcher process / + thread / machine. + + """ + self.task_count += delta + logger.debug('Adjusted task count by %d to %d.', delta, self.task_count) + + def get_task_count(self) -> int: + """Change the task count. Note: do not call this method from a + worker, it should only be called by the launcher process / + thread / machine. + + """ + return self.task_count + + +class ThreadExecutor(BaseExecutor): + """A threadpool executor. This executor uses python threads to + schedule tasks. Note that, at least as of python3.10, because of + the global lock in the interpreter itself, these do not + parallelize very well so this class is useful mostly for non-CPU + intensive tasks. + + See also :class:`ProcessExecutor` and :class:`RemoteExecutor`. 
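+
+    A quick usage sketch (here we shut the pool down by hand; the
+    :class:`DefaultExecutors` container below will do this for you
+    automatically)::
+
+        executor = ThreadExecutor(max_workers=4)
+        try:
+            future = executor.submit(lambda x: x * 2, 21)
+            assert future.result() == 42
+        finally:
+            executor.shutdown(wait=True, quiet=True)
+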
+ """ + + def __init__(self, max_workers: Optional[int] = None): + super().__init__() + workers = None + if max_workers is not None: + workers = max_workers + elif 'executors_threadpool_size' in config.config: + workers = config.config['executors_threadpool_size'] + if workers is not None: + logger.debug('Creating threadpool executor with %d workers', workers) + else: + logger.debug('Creating a default sized threadpool executor') + self._thread_pool_executor = fut.ThreadPoolExecutor( + max_workers=workers, thread_name_prefix="thread_executor_helper" + ) + self.already_shutdown = False + + # This is run on a different thread; do not adjust task count here. + @staticmethod + def run_local_bundle(fun, *args, **kwargs): + logger.debug("Running local bundle at %s", fun.__name__) + result = fun(*args, **kwargs) + return result + + @overrides + def submit(self, function: Callable, *args, **kwargs) -> fut.Future: + if self.already_shutdown: + raise Exception('Submitted work after shutdown.') + self.adjust_task_count(+1) + newargs = [] + newargs.append(function) + for arg in args: + newargs.append(arg) + start = time.time() + result = self._thread_pool_executor.submit( + ThreadExecutor.run_local_bundle, *newargs, **kwargs + ) + result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start)) + result.add_done_callback(lambda _: self.adjust_task_count(-1)) + return result + + @overrides + def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None: + if not self.already_shutdown: + logger.debug('Shutting down threadpool executor %s', self.title) + self._thread_pool_executor.shutdown(wait) + if not quiet: + print(self.histogram.__repr__(label_formatter='%ds')) + self.already_shutdown = True + + +class ProcessExecutor(BaseExecutor): + """An executor which runs tasks in child processes. + + See also :class:`ThreadExecutor` and :class:`RemoteExecutor`. + """ + + def __init__(self, max_workers=None): + super().__init__() + workers = None + if max_workers is not None: + workers = max_workers + elif 'executors_processpool_size' in config.config: + workers = config.config['executors_processpool_size'] + if workers is not None: + logger.debug('Creating processpool executor with %d workers.', workers) + else: + logger.debug('Creating a default sized processpool executor') + self._process_executor = fut.ProcessPoolExecutor( + max_workers=workers, + ) + self.already_shutdown = False + + # This is run in another process; do not adjust task count here. 
+    @staticmethod
+    def run_cloud_pickle(pickle):
+        fun, args, kwargs = cloudpickle.loads(pickle)
+        logger.debug("Running pickled bundle at %s", fun.__name__)
+        result = fun(*args, **kwargs)
+        return result
+
+    @overrides
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
+        start = time.time()
+        self.adjust_task_count(+1)
+        pickle = _make_cloud_pickle(function, *args, **kwargs)
+        result = self._process_executor.submit(ProcessExecutor.run_cloud_pickle, pickle)
+        result.add_done_callback(lambda _: self.histogram.add_item(time.time() - start))
+        result.add_done_callback(lambda _: self.adjust_task_count(-1))
+        return result
+
+    @overrides
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        if not self.already_shutdown:
+            logger.debug('Shutting down processpool executor %s', self.title)
+            self._process_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state['_process_executor'] = None
+        return state
+
+
+class RemoteExecutorException(Exception):
+    """Thrown when a bundle cannot be executed despite several retries."""
+
+    pass
+
+
+@dataclass
+class RemoteWorkerRecord:
+    """A record of info about a remote worker."""
+
+    username: str
+    """Username we can ssh into on this machine to run work."""
+
+    machine: str
+    """Machine address / name."""
+
+    weight: int
+    """Relative probability for the weighted policy to select this
+    machine for scheduling work."""
+
+    count: int
+    """If this machine is selected, what is the maximum number of tasks
+    that it can handle?"""
+
+    def __hash__(self):
+        return hash((self.username, self.machine))
+
+    def __repr__(self):
+        return f'{self.username}@{self.machine}'
+
+
+@dataclass
+class BundleDetails:
+    """All info necessary to define some unit of work that needs to be
+    done, where it is being run, its state, whether it is an original
+    bundle or a backup bundle, how many times it has failed, etc...
+    """
+
+    pickled_code: bytes
+    """The code to run, cloud pickled"""
+
+    uuid: str
+    """A unique identifier"""
+
+    function_name: str
+    """The name of the function we pickled"""
+
+    worker: Optional[RemoteWorkerRecord]
+    """The remote worker running this bundle or None if none (yet)"""
+
+    username: Optional[str]
+    """The remote username running this bundle or None if none (yet)"""
+
+    machine: Optional[str]
+    """The remote machine running this bundle or None if none (yet)"""
+
+    hostname: str
+    """The controller machine"""
+
+    code_file: str
+    """A unique filename to hold the work to be done"""
+
+    result_file: str
+    """Where the results should be placed / read from"""
+
+    pid: int
+    """The process id of the local subprocess watching the ssh connection
+    to the remote machine"""
+
+    start_ts: float
+    """Starting time"""
+
+    end_ts: float
+    """Ending time"""
+
+    slower_than_local_p95: bool
+    """Currently slower than 95% of other bundles on the remote host"""
+
+    slower_than_global_p95: bool
+    """Currently slower than 95% of other bundles globally"""
+
+    src_bundle: Optional[BundleDetails]
+    """If this is a backup bundle, this points to the original bundle
+    that it's backing up.  None otherwise."""
+
+    is_cancelled: threading.Event
+    """An event that can be signaled to indicate this bundle is cancelled.
+ This is set when another copy (backup or original) of this work has + completed successfully elsewhere.""" + + was_cancelled: bool + """True if this bundle was cancelled, False if it finished normally""" + + backup_bundles: Optional[List[BundleDetails]] + """If we've created backups of this bundle, this is the list of them""" + + failure_count: int + """How many times has this bundle failed already?""" + + def __repr__(self): + uuid = self.uuid + if uuid[-9:-2] == '_backup': + uuid = uuid[:-9] + suffix = f'{uuid[-6:]}_b{self.uuid[-1:]}' + else: + suffix = uuid[-6:] + + # We colorize the uuid based on some bits from it to make them + # stand out in the logging and help a reader correlate log messages + # related to the same bundle. + colorz = [ + fg('violet red'), + fg('red'), + fg('orange'), + fg('peach orange'), + fg('yellow'), + fg('marigold yellow'), + fg('green yellow'), + fg('tea green'), + fg('cornflower blue'), + fg('turquoise blue'), + fg('tropical blue'), + fg('lavender purple'), + fg('medium purple'), + ] + c = colorz[int(uuid[-2:], 16) % len(colorz)] + function_name = ( + self.function_name if self.function_name is not None else 'nofname' + ) + machine = self.machine if self.machine is not None else 'nomachine' + return f'{c}{suffix}/{function_name}/{machine}{reset()}' + + +class RemoteExecutorStatus: + """A status 'scoreboard' for a remote executor tracking various + metrics and able to render a periodic dump of global state. + """ + + def __init__(self, total_worker_count: int) -> None: + """C'tor. + + Args: + total_worker_count: number of workers in the pool + + """ + self.worker_count: int = total_worker_count + self.known_workers: Set[RemoteWorkerRecord] = set() + self.start_time: float = time.time() + self.start_per_bundle: Dict[str, Optional[float]] = defaultdict(float) + self.end_per_bundle: Dict[str, float] = defaultdict(float) + self.finished_bundle_timings_per_worker: Dict[ + RemoteWorkerRecord, List[float] + ] = {} + self.in_flight_bundles_by_worker: Dict[RemoteWorkerRecord, Set[str]] = {} + self.bundle_details_by_uuid: Dict[str, BundleDetails] = {} + self.finished_bundle_timings: List[float] = [] + self.last_periodic_dump: Optional[float] = None + self.total_bundles_submitted: int = 0 + + # Protects reads and modification using self. Also used + # as a memory fence for modifications to bundle. + self.lock: threading.Lock = threading.Lock() + + def record_acquire_worker(self, worker: RemoteWorkerRecord, uuid: str) -> None: + """Record that bundle with uuid is assigned to a particular worker. 
+ + Args: + worker: the record of the worker to which uuid is assigned + uuid: the uuid of a bundle that has been assigned to a worker + """ + with self.lock: + self.record_acquire_worker_already_locked(worker, uuid) + + def record_acquire_worker_already_locked( + self, worker: RemoteWorkerRecord, uuid: str + ) -> None: + """Same as above but an entry point that doesn't acquire the lock + for codepaths where it's already held.""" + assert self.lock.locked() + self.known_workers.add(worker) + self.start_per_bundle[uuid] = None + x = self.in_flight_bundles_by_worker.get(worker, set()) + x.add(uuid) + self.in_flight_bundles_by_worker[worker] = x + + def record_bundle_details(self, details: BundleDetails) -> None: + """Register the details about a bundle of work.""" + with self.lock: + self.record_bundle_details_already_locked(details) + + def record_bundle_details_already_locked(self, details: BundleDetails) -> None: + """Same as above but for codepaths that already hold the lock.""" + assert self.lock.locked() + self.bundle_details_by_uuid[details.uuid] = details + + def record_release_worker( + self, + worker: RemoteWorkerRecord, + uuid: str, + was_cancelled: bool, + ) -> None: + """Record that a bundle has released a worker.""" + with self.lock: + self.record_release_worker_already_locked(worker, uuid, was_cancelled) + + def record_release_worker_already_locked( + self, + worker: RemoteWorkerRecord, + uuid: str, + was_cancelled: bool, + ) -> None: + """Same as above but for codepaths that already hold the lock.""" + assert self.lock.locked() + ts = time.time() + self.end_per_bundle[uuid] = ts + self.in_flight_bundles_by_worker[worker].remove(uuid) + if not was_cancelled: + start = self.start_per_bundle[uuid] + assert start is not None + bundle_latency = ts - start + x = self.finished_bundle_timings_per_worker.get(worker, []) + x.append(bundle_latency) + self.finished_bundle_timings_per_worker[worker] = x + self.finished_bundle_timings.append(bundle_latency) + + def record_processing_began(self, uuid: str): + """Record when work on a bundle begins.""" + with self.lock: + self.start_per_bundle[uuid] = time.time() + + def total_in_flight(self) -> int: + """How many bundles are in flight currently?""" + assert self.lock.locked() + total_in_flight = 0 + for worker in self.known_workers: + total_in_flight += len(self.in_flight_bundles_by_worker[worker]) + return total_in_flight + + def total_idle(self) -> int: + """How many idle workers are there currently?""" + assert self.lock.locked() + return self.worker_count - self.total_in_flight() + + def __repr__(self): + assert self.lock.locked() + ts = time.time() + total_finished = len(self.finished_bundle_timings) + total_in_flight = self.total_in_flight() + ret = f'\n\n{underline()}Remote Executor Pool Status{reset()}: ' + qall = None + if len(self.finished_bundle_timings) > 1: + qall = numpy.quantile(self.finished_bundle_timings, [0.5, 0.95]) + ret += ( + f'⏱=∀p50:{qall[0]:.1f}s, ∀p95:{qall[1]:.1f}s, total={ts-self.start_time:.1f}s, ' + f'✅={total_finished}/{self.total_bundles_submitted}, ' + f'💻n={total_in_flight}/{self.worker_count}\n' + ) + else: + ret += ( + f'⏱={ts-self.start_time:.1f}s, ' + f'✅={total_finished}/{self.total_bundles_submitted}, ' + f'💻n={total_in_flight}/{self.worker_count}\n' + ) + + for worker in self.known_workers: + ret += f' {fg("lightning yellow")}{worker.machine}{reset()}: ' + timings = self.finished_bundle_timings_per_worker.get(worker, []) + count = len(timings) + qworker = None + if count > 1: + qworker = 
numpy.quantile(timings, [0.5, 0.95])
+                ret += f' 💻p50: {qworker[0]:.1f}s, 💻p95: {qworker[1]:.1f}s\n'
+            else:
+                ret += '\n'
+            if count > 0:
+                ret += f'    ...finished {count} total bundle(s) so far\n'
+            in_flight = len(self.in_flight_bundles_by_worker[worker])
+            if in_flight > 0:
+                ret += f'    ...{in_flight} bundles currently in flight:\n'
+                for bundle_uuid in self.in_flight_bundles_by_worker[worker]:
+                    details = self.bundle_details_by_uuid.get(bundle_uuid, None)
+                    pid = str(details.pid) if (details and details.pid != 0) else "TBD"
+                    if self.start_per_bundle[bundle_uuid] is not None:
+                        sec = ts - self.start_per_bundle[bundle_uuid]
+                        ret += f'       (pid={pid}): {details} for {sec:.1f}s so far '
+                    else:
+                        ret += f'       {details} setting up / copying data...'
+                        sec = 0.0
+
+                    if qworker is not None:
+                        if sec > qworker[1]:
+                            ret += f'{bg("red")}>💻p95{reset()} '
+                            if details is not None:
+                                details.slower_than_local_p95 = True
+                        else:
+                            if details is not None:
+                                details.slower_than_local_p95 = False
+
+                    if qall is not None:
+                        if sec > qall[1]:
+                            ret += f'{bg("red")}>∀p95{reset()} '
+                            if details is not None:
+                                details.slower_than_global_p95 = True
+                        else:
+                            if details is not None:
+                                details.slower_than_global_p95 = False
+                    ret += '\n'
+        return ret
+
+    def periodic_dump(self, total_bundles_submitted: int) -> None:
+        assert self.lock.locked()
+        self.total_bundles_submitted = total_bundles_submitted
+        ts = time.time()
+        if self.last_periodic_dump is None or ts - self.last_periodic_dump > 5.0:
+            print(self)
+            self.last_periodic_dump = ts
+
+
+class RemoteWorkerSelectionPolicy(ABC):
+    """A base class for policies that select remote workers."""
+
+    def __init__(self):
+        self.workers: Optional[List[RemoteWorkerRecord]] = None
+
+    def register_worker_pool(self, workers: List[RemoteWorkerRecord]):
+        self.workers = workers
+
+    @abstractmethod
+    def is_worker_available(self) -> bool:
+        pass
+
+    @abstractmethod
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
+        pass
+
+
+class WeightedRandomRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
+    """A remote worker selector that uses weighted RNG."""
+
+    @overrides
+    def is_worker_available(self) -> bool:
+        if self.workers:
+            for worker in self.workers:
+                if worker.count > 0:
+                    return True
+        return False
+
+    @overrides
+    def acquire_worker(self, machine_to_avoid=None) -> Optional[RemoteWorkerRecord]:
+        grabbag = []
+        if self.workers:
+            for worker in self.workers:
+                if worker.machine != machine_to_avoid:
+                    if worker.count > 0:
+                        for _ in range(worker.count * worker.weight):
+                            grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.debug(
+                'There are no available workers that avoid %s', machine_to_avoid
+            )
+            if self.workers:
+                for worker in self.workers:
+                    if worker.count > 0:
+                        for _ in range(worker.count * worker.weight):
+                            grabbag.append(worker)
+
+        if len(grabbag) == 0:
+            logger.warning('There are no available workers?!')
+            return None
+
+        worker = random.sample(grabbag, 1)[0]
+        assert worker.count > 0
+        worker.count -= 1
+        logger.debug('Selected worker %s', worker)
+        return worker
+
+
+class RoundRobinRemoteWorkerSelectionPolicy(RemoteWorkerSelectionPolicy):
+    """A remote worker selector that just round robins."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.index = 0
+
+    @overrides
+    def is_worker_available(self) -> bool:
+        if self.workers:
+            for worker in self.workers:
+                if worker.count > 0:
+                    return True
+        return False
+
+    @overrides
+    def acquire_worker(
+        self, machine_to_avoid: Optional[str] = None
+    ) -> Optional[RemoteWorkerRecord]:
+        if
self.workers: + x = self.index + while True: + worker = self.workers[x] + if worker.count > 0: + worker.count -= 1 + x += 1 + if x >= len(self.workers): + x = 0 + self.index = x + logger.debug('Selected worker %s', worker) + return worker + x += 1 + if x >= len(self.workers): + x = 0 + if x == self.index: + logger.warning('Unexpectedly could not find a worker, retrying...') + return None + return None + + +class RemoteExecutor(BaseExecutor): + """An executor that uses processes on remote machines to do work. This + works by creating "bundles" of work with pickled code in each to be + executed. Each bundle is assigned a remote worker based on some policy + heuristics. Once assigned to a remote worker, a local subprocess is + created. It copies the pickled code to the remote machine via ssh/scp + and then starts up work on the remote machine again using ssh. When + the work is complete it copies the results back to the local machine. + + So there is essentially one "controller" machine (which may also be + in the remote executor pool and therefore do task work in addition to + controlling) and N worker machines. This code runs on the controller + whereas on the worker machines we invoke pickled user code via a + shim in :file:`remote_worker.py`. + + Some redundancy and safety provisions are made; e.g. slower than + expected tasks have redundant backups created and if a task fails + repeatedly we consider it poisoned and give up on it. + + .. warning:: + + The network overhead / latency of copying work from the + controller machine to the remote workers is relatively high. + This executor probably only makes sense to use with + computationally expensive tasks such as jobs that will execute + for ~30 seconds or longer. + + See also :class:`ProcessExecutor` and :class:`ThreadExecutor`. + """ + + def __init__( + self, + workers: List[RemoteWorkerRecord], + policy: RemoteWorkerSelectionPolicy, + ) -> None: + """C'tor. + + Args: + workers: A list of remote workers we can call on to do tasks. + policy: A policy for selecting remote workers for tasks. + """ + + super().__init__() + self.workers = workers + self.policy = policy + self.worker_count = 0 + for worker in self.workers: + self.worker_count += worker.count + if self.worker_count <= 0: + msg = f"We need somewhere to schedule work; count was {self.worker_count}" + logger.critical(msg) + raise RemoteExecutorException(msg) + self.policy.register_worker_pool(self.workers) + self.cv = threading.Condition() + logger.debug( + 'Creating %d local threads, one per remote worker.', self.worker_count + ) + self._helper_executor = fut.ThreadPoolExecutor( + thread_name_prefix="remote_executor_helper", + max_workers=self.worker_count, + ) + self.status = RemoteExecutorStatus(self.worker_count) + self.total_bundles_submitted = 0 + self.backup_lock = threading.Lock() + self.last_backup = None + ( + self.heartbeat_thread, + self.heartbeat_stop_event, + ) = self._run_periodic_heartbeat() + self.already_shutdown = False + + @background_thread + def _run_periodic_heartbeat(self, stop_event: threading.Event) -> None: + """ + We create a background thread to invoke :meth:`_heartbeat` regularly + while we are scheduling work. It does some accounting such as + looking for slow bundles to tag for backup creation, checking for + unexpected failures, and printing a fancy message on stdout. 
+ """ + while not stop_event.is_set(): + time.sleep(5.0) + logger.debug('Running periodic heartbeat code...') + self._heartbeat() + logger.debug('Periodic heartbeat thread shutting down.') + + def _heartbeat(self) -> None: + # Note: this is invoked on a background thread, not an + # executor thread. Be careful what you do with it b/c it + # needs to get back and dump status again periodically. + with self.status.lock: + self.status.periodic_dump(self.total_bundles_submitted) + + # Look for bundles to reschedule via executor.submit + if config.config['executors_schedule_remote_backups']: + self._maybe_schedule_backup_bundles() + + def _maybe_schedule_backup_bundles(self): + """Maybe schedule backup bundles if we see a very slow bundle.""" + + assert self.status.lock.locked() + num_done = len(self.status.finished_bundle_timings) + num_idle_workers = self.worker_count - self.task_count + now = time.time() + if ( + num_done >= 2 + and num_idle_workers > 0 + and (self.last_backup is None or (now - self.last_backup > 9.0)) + and self.backup_lock.acquire(blocking=False) + ): + try: + assert self.backup_lock.locked() + + bundle_to_backup = None + best_score = None + for ( + worker, + bundle_uuids, + ) in self.status.in_flight_bundles_by_worker.items(): + + # Prefer to schedule backups of bundles running on + # slower machines. + base_score = 0 + for record in self.workers: + if worker.machine == record.machine: + base_score = float(record.weight) + base_score = 1.0 / base_score + base_score *= 200.0 + base_score = int(base_score) + break + + for uuid in bundle_uuids: + bundle = self.status.bundle_details_by_uuid.get(uuid, None) + if ( + bundle is not None + and bundle.src_bundle is None + and bundle.backup_bundles is not None + ): + score = base_score + + # Schedule backups of bundles running + # longer; especially those that are + # unexpectedly slow. + start_ts = self.status.start_per_bundle[uuid] + if start_ts is not None: + runtime = now - start_ts + score += runtime + logger.debug( + 'score[%s] => %.1f # latency boost', bundle, score + ) + + if bundle.slower_than_local_p95: + score += runtime / 2 + logger.debug( + 'score[%s] => %.1f # >worker p95', + bundle, + score, + ) + + if bundle.slower_than_global_p95: + score += runtime / 4 + logger.debug( + 'score[%s] => %.1f # >global p95', + bundle, + score, + ) + + # Prefer backups of bundles that don't + # have backups already. + backup_count = len(bundle.backup_bundles) + if backup_count == 0: + score *= 2 + elif backup_count == 1: + score /= 2 + elif backup_count == 2: + score /= 8 + else: + score = 0 + logger.debug( + 'score[%s] => %.1f # {backup_count} dup backup factor', + bundle, + score, + ) + + if score != 0 and ( + best_score is None or score > best_score + ): + bundle_to_backup = bundle + assert bundle is not None + assert bundle.backup_bundles is not None + assert bundle.src_bundle is None + best_score = score + + # Note: this is all still happening on the heartbeat + # runner thread. That's ok because + # _schedule_backup_for_bundle uses the executor to + # submit the bundle again which will cause it to be + # picked up by a worker thread and allow this thread + # to return to run future heartbeats. 
+                if bundle_to_backup is not None:
+                    self.last_backup = now
+                    logger.info(
+                        '=====> SCHEDULING BACKUP %s (score=%.1f) <=====',
+                        bundle_to_backup,
+                        best_score,
+                    )
+                    self._schedule_backup_for_bundle(bundle_to_backup)
+            finally:
+                self.backup_lock.release()
+
+    def _is_worker_available(self) -> bool:
+        """Is there a worker available currently?"""
+        return self.policy.is_worker_available()
+
+    def _acquire_worker(
+        self, machine_to_avoid: Optional[str] = None
+    ) -> Optional[RemoteWorkerRecord]:
+        """Try to acquire a worker."""
+        return self.policy.acquire_worker(machine_to_avoid)
+
+    def _find_available_worker_or_block(
+        self, machine_to_avoid: Optional[str] = None
+    ) -> RemoteWorkerRecord:
+        """Find a worker or block until one becomes available."""
+        with self.cv:
+            while not self._is_worker_available():
+                self.cv.wait()
+            worker = self._acquire_worker(machine_to_avoid)
+            if worker is not None:
+                return worker
+        msg = "We should never reach this point in the code"
+        logger.critical(msg)
+        raise Exception(msg)
+
+    def _release_worker(self, bundle: BundleDetails, *, was_cancelled=True) -> None:
+        """Release a previously acquired worker."""
+        worker = bundle.worker
+        assert worker is not None
+        logger.debug('Released worker %s', worker)
+        self.status.record_release_worker(
+            worker,
+            bundle.uuid,
+            was_cancelled,
+        )
+        with self.cv:
+            worker.count += 1
+            self.cv.notify()
+        self.adjust_task_count(-1)
+
+    def _check_if_cancelled(self, bundle: BundleDetails) -> bool:
+        """See if a particular bundle is cancelled.  Do not block."""
+        with self.status.lock:
+            if bundle.is_cancelled.wait(timeout=0.0):
+                logger.debug('Bundle %s is cancelled, bail out.', bundle.uuid)
+                bundle.was_cancelled = True
+                return True
+            return False
+
+    def _launch(self, bundle: BundleDetails, override_avoid_machine=None) -> Any:
+        """Find a worker for bundle or block until one is available."""
+
+        self.adjust_task_count(+1)
+        uuid = bundle.uuid
+        hostname = bundle.hostname
+        avoid_machine = override_avoid_machine
+        is_original = bundle.src_bundle is None
+
+        # Try not to schedule a backup on the same host as the original.
+        if avoid_machine is None and bundle.src_bundle is not None:
+            avoid_machine = bundle.src_bundle.machine
+        worker = None
+        while worker is None:
+            worker = self._find_available_worker_or_block(avoid_machine)
+        assert worker is not None
+
+        # Ok, found a worker.
+        bundle.worker = worker
+        machine = bundle.machine = worker.machine
+        username = bundle.username = worker.username
+        self.status.record_acquire_worker(worker, uuid)
+        logger.debug('%s: Running bundle on %s...', bundle, worker)
+
+        # Before we do any work, make sure the bundle is still viable.
+        # It may have been some time between when it was submitted and
+        # now due to lack of worker availability and someone else may
+        # have already finished it.
+        if self._check_if_cancelled(bundle):
+            try:
+                return self._process_work_result(bundle)
+            except Exception as e:
+                logger.warning(
+                    '%s: bundle says it\'s cancelled upfront but no results?!', bundle
+                )
+                self._release_worker(bundle)
+                if is_original:
+                    # Weird.  We are the original owner of this
+                    # bundle.  For it to have been cancelled, a backup
+                    # must have already started and completed before
+                    # we even got started.  Moreover, the backup says
+                    # it is done but we can't find the results it
+                    # should have copied over.  Reschedule the whole
+                    # thing.
+                    logger.exception(e)
+                    logger.error(
+                        '%s: We are the original owner thread and yet there are '
+                        'no results for this bundle.  This is unexpected and bad.',
+                        bundle,
+                    )
+                    return self._emergency_retry_nasty_bundle(bundle)
+                else:
+                    # We're a backup and our bundle is cancelled
+                    # before we even got started.  Do nothing and let
+                    # the original bundle's thread worry about either
+                    # finding the results or complaining about it.
+                    return None
+
+        # Send input code / data to worker machine if it's not local.
+        if hostname not in machine:
+            try:
+                cmd = (
+                    f'{SCP} {bundle.code_file} {username}@{machine}:{bundle.code_file}'
+                )
+                start_ts = time.time()
+                logger.info("%s: Copying work to %s via %s.", bundle, worker, cmd)
+                run_silently(cmd)
+                xfer_latency = time.time() - start_ts
+                logger.debug(
+                    "%s: Copying to %s took %.1fs.", bundle, worker, xfer_latency
+                )
+            except Exception as e:
+                self._release_worker(bundle)
+                if is_original:
+                    # Weird.  We tried to copy the code to the worker
+                    # and it failed...  And we're the original bundle.
+                    # We have to retry.
+                    logger.exception(e)
+                    logger.error(
+                        "%s: Failed to send instructions to the worker machine?! "
+                        "This is not expected; we\'re the original bundle so this shouldn\'t "
+                        "be a race condition.  Attempting an emergency retry...",
+                        bundle,
+                    )
+                    return self._emergency_retry_nasty_bundle(bundle)
+                else:
+                    # This is actually expected; we're a backup.
+                    # There's a race condition where someone else
+                    # already finished the work and removed the source
+                    # code_file before we could copy it.  Ignore.
+                    logger.warning(
+                        '%s: Failed to send instructions to the worker machine... '
+                        'We\'re a backup and this may be caused by the original (or '
+                        'some other backup) already finishing this work.  Ignoring.',
+                        bundle,
+                    )
+                    return None
+
+        # Kick off the work.  Note that if this fails we let
+        # _wait_for_process deal with it.
+        self.status.record_processing_began(uuid)
+        cmd = (
+            f'{SSH} {bundle.username}@{bundle.machine} '
+            f'"source py39-venv/bin/activate &&'
+            f' /home/scott/lib/python_modules/remote_worker.py'
+            f' --code_file {bundle.code_file} --result_file {bundle.result_file}"'
+        )
+        logger.debug(
+            '%s: Executing %s in the background to kick off work...', bundle, cmd
+        )
+        p = cmd_in_background(cmd, silent=True)
+        bundle.pid = p.pid
+        logger.debug(
+            '%s: Local ssh process pid=%d; remote worker is %s.', bundle, p.pid, machine
+        )
+        return self._wait_for_process(p, bundle, 0)
+
+    def _wait_for_process(
+        self, p: Optional[subprocess.Popen], bundle: BundleDetails, depth: int
+    ) -> Any:
+        """At this point we've copied the bundle's pickled code to the remote
+        worker and started an ssh process that should be invoking the
+        remote worker to have it execute the user's code.  See how
+        that's going and wait for it to complete or fail.  Note that
+        this code is recursive: there are codepaths where we decide to
+        stop waiting for an ssh process (because another backup seems
+        to have finished) but then fail to fetch or parse the results
+        from that backup and thus call ourselves to continue waiting
+        on an active ssh process.  This is the purpose of the depth
+        argument: to curtail potential infinite recursion by giving up
+        eventually.
+
+        Args:
+            p: the Popen record of the ssh job
+            bundle: the bundle of work being executed remotely
+            depth: how many retries we've made so far.  Starts at zero.
+
+        """
+
+        machine = bundle.machine
+        assert p is not None
+        pid = p.pid  # pid of the ssh process
+        if depth > 3:
+            logger.error(
+                "I've gotten repeated errors waiting on this bundle; giving up on pid=%d",
+                pid,
+            )
+            p.terminate()
+            self._release_worker(bundle)
+            return self._emergency_retry_nasty_bundle(bundle)
+
+        # Spin until either the ssh job we scheduled finishes the
+        # bundle or some backup worker signals that they finished it
+        # before we could.
+        while True:
+            try:
+                p.wait(timeout=0.25)
+            except subprocess.TimeoutExpired:
+                if self._check_if_cancelled(bundle):
+                    logger.info(
+                        '%s: looks like another worker finished bundle...', bundle
+                    )
+                    break
+            else:
+                logger.info("%s: pid %d (%s) is finished!", bundle, pid, machine)
+                p = None
+                break
+
+        # If we get here we believe the bundle is done; either the ssh
+        # subprocess finished (hopefully successfully) or we noticed
+        # that some other worker seems to have completed the bundle
+        # before us and we're bailing out.
+        try:
+            ret = self._process_work_result(bundle)
+            if ret is not None and p is not None:
+                p.terminate()
+            return ret
+
+        # Something went wrong; e.g. we could not copy the results
+        # back, clean up after ourselves on the remote machine, or
+        # unpickle the results we got from the remote machine.  If we
+        # still have an active ssh subprocess, keep waiting on it.
+        # Otherwise, time for an emergency reschedule.
+        except Exception as e:
+            logger.exception(e)
+            logger.error('%s: Something unexpected just happened...', bundle)
+            if p is not None:
+                logger.warning(
+                    "%s: Failed to wrap up \"done\" bundle, re-waiting on active ssh.",
+                    bundle,
+                )
+                return self._wait_for_process(p, bundle, depth + 1)
+            else:
+                self._release_worker(bundle)
+                return self._emergency_retry_nasty_bundle(bundle)
+
+    def _process_work_result(self, bundle: BundleDetails) -> Any:
+        """A bundle seems to be completed.  Check on the results."""
+
+        with self.status.lock:
+            is_original = bundle.src_bundle is None
+            was_cancelled = bundle.was_cancelled
+            username = bundle.username
+            machine = bundle.machine
+            result_file = bundle.result_file
+            code_file = bundle.code_file
+
+            # Whether original or backup, if we finished first we must
+            # fetch the results if the computation happened on a
+            # remote machine.
+            bundle.end_ts = time.time()
+            if not was_cancelled:
+                assert bundle.machine is not None
+                if bundle.hostname not in bundle.machine:
+                    cmd = f'{SCP} {username}@{machine}:{result_file} {result_file} 2>/dev/null'
+                    logger.info(
+                        "%s: Fetching results back from %s@%s via %s",
+                        bundle,
+                        username,
+                        machine,
+                        cmd,
+                    )
+
+                    # If either of these throw they are handled in
+                    # _wait_for_process.
+                    attempts = 0
+                    while True:
+                        try:
+                            run_silently(cmd)
+                        except Exception as e:
+                            attempts += 1
+                            if attempts >= 3:
+                                raise e
+                        else:
+                            break
+
+                    # Cleanup remote /tmp files.
+                    run_silently(
+                        f'{SSH} {username}@{machine}'
+                        f' "/bin/rm -f {code_file} {result_file}"'
+                    )
+                    logger.debug(
+                        'Fetching results back took %.2fs', time.time() - bundle.end_ts
+                    )
+                dur = bundle.end_ts - bundle.start_ts
+                self.histogram.add_item(dur)
+
+        # Only the original worker should unpickle the file contents
+        # though since it's the only one whose result matters.  The
+        # original is also the only job that may delete result_file
+        # from disk.  Note that the original may have been cancelled
+        # if one of the backups finished first; it still must read the
+        # result from disk.  It still does that here with is_cancelled
+        # set.
+ if is_original: + logger.debug("%s: Unpickling %s.", bundle, result_file) + try: + with open(result_file, 'rb') as rb: + serialized = rb.read() + result = cloudpickle.loads(serialized) + except Exception as e: + logger.exception(e) + logger.error('Failed to load %s... this is bad news.', result_file) + self._release_worker(bundle) + + # Re-raise the exception; the code in _wait_for_process may + # decide to _emergency_retry_nasty_bundle here. + raise e + logger.debug('Removing local (master) %s and %s.', code_file, result_file) + os.remove(result_file) + os.remove(code_file) + + # Notify any backups that the original is done so they + # should stop ASAP. Do this whether or not we + # finished first since there could be more than one + # backup. + if bundle.backup_bundles is not None: + for backup in bundle.backup_bundles: + logger.debug( + '%s: Notifying backup %s that it\'s cancelled', + bundle, + backup.uuid, + ) + backup.is_cancelled.set() + + # This is a backup job and, by now, we have already fetched + # the bundle results. + else: + # Backup results don't matter, they just need to leave the + # result file in the right place for their originals to + # read/unpickle later. + result = None + + # Tell the original to stop if we finished first. + if not was_cancelled: + orig_bundle = bundle.src_bundle + assert orig_bundle is not None + logger.debug( + '%s: Notifying original %s we beat them to it.', + bundle, + orig_bundle.uuid, + ) + orig_bundle.is_cancelled.set() + self._release_worker(bundle, was_cancelled=was_cancelled) + return result + + def _create_original_bundle(self, pickle, function_name: str): + """Creates a bundle that is not a backup of any other bundle but + rather represents a user task. + """ + + uuid = string_utils.generate_uuid(omit_dashes=True) + code_file = f'/tmp/{uuid}.code.bin' + result_file = f'/tmp/{uuid}.result.bin' + + logger.debug('Writing pickled code to %s', code_file) + with open(code_file, 'wb') as wb: + wb.write(pickle) + + bundle = BundleDetails( + pickled_code=pickle, + uuid=uuid, + function_name=function_name, + worker=None, + username=None, + machine=None, + hostname=platform.node(), + code_file=code_file, + result_file=result_file, + pid=0, + start_ts=time.time(), + end_ts=0.0, + slower_than_local_p95=False, + slower_than_global_p95=False, + src_bundle=None, + is_cancelled=threading.Event(), + was_cancelled=False, + backup_bundles=[], + failure_count=0, + ) + self.status.record_bundle_details(bundle) + logger.debug('%s: Created an original bundle', bundle) + return bundle + + def _create_backup_bundle(self, src_bundle: BundleDetails): + """Creates a bundle that is a backup of another bundle that is + running too slowly.""" + + assert self.status.lock.locked() + assert src_bundle.backup_bundles is not None + n = len(src_bundle.backup_bundles) + uuid = src_bundle.uuid + f'_backup#{n}' + + backup_bundle = BundleDetails( + pickled_code=src_bundle.pickled_code, + uuid=uuid, + function_name=src_bundle.function_name, + worker=None, + username=None, + machine=None, + hostname=src_bundle.hostname, + code_file=src_bundle.code_file, + result_file=src_bundle.result_file, + pid=0, + start_ts=time.time(), + end_ts=0.0, + slower_than_local_p95=False, + slower_than_global_p95=False, + src_bundle=src_bundle, + is_cancelled=threading.Event(), + was_cancelled=False, + backup_bundles=None, # backup backups not allowed + failure_count=0, + ) + src_bundle.backup_bundles.append(backup_bundle) + self.status.record_bundle_details_already_locked(backup_bundle) + 
logger.debug('%s: Created a backup bundle', backup_bundle)
+        return backup_bundle
+
+    def _schedule_backup_for_bundle(self, src_bundle: BundleDetails):
+        """Schedule a backup of src_bundle."""
+
+        assert self.status.lock.locked()
+        assert src_bundle is not None
+        backup_bundle = self._create_backup_bundle(src_bundle)
+        logger.debug(
+            '%s/%s: Scheduling backup for execution...',
+            backup_bundle.uuid,
+            backup_bundle.function_name,
+        )
+        self._helper_executor.submit(self._launch, backup_bundle)
+
+        # Results from backups don't matter; if they finish first
+        # they will move the result_file to this machine and let
+        # the original pick them up and unpickle them (and return
+        # a result).
+
+    def _emergency_retry_nasty_bundle(
+        self, bundle: BundleDetails
+    ) -> Optional[fut.Future]:
+        """Something unexpectedly failed with bundle.  Either retry it
+        from the beginning or throw in the towel and give up on it."""
+
+        is_original = bundle.src_bundle is None
+        bundle.worker = None
+        avoid_last_machine = bundle.machine
+        bundle.machine = None
+        bundle.username = None
+        bundle.failure_count += 1
+        if is_original:
+            retry_limit = 3
+        else:
+            retry_limit = 2
+
+        if bundle.failure_count > retry_limit:
+            logger.error(
+                '%s: Tried this bundle too many times already (%dx); giving up.',
+                bundle,
+                retry_limit,
+            )
+            if is_original:
+                raise RemoteExecutorException(
+                    f'{bundle}: This bundle can\'t be completed despite several backups and retries',
+                )
+            else:
+                logger.error(
+                    '%s: At least it\'s only a backup; better luck with the others.',
+                    bundle,
+                )
+            return None
+        else:
+            msg = f'>>> Emergency rescheduling {bundle} because of unexpected errors (wtf?!) <<<'
+            logger.warning(msg)
+            warnings.warn(msg)
+            return self._launch(bundle, avoid_last_machine)
+
+    @overrides
+    def submit(self, function: Callable, *args, **kwargs) -> fut.Future:
+        """Submit work to be done.  This is the user entry point of this
+        class."""
+        if self.already_shutdown:
+            raise Exception('Submitted work after shutdown.')
+        pickle = _make_cloud_pickle(function, *args, **kwargs)
+        bundle = self._create_original_bundle(pickle, function.__name__)
+        self.total_bundles_submitted += 1
+        return self._helper_executor.submit(self._launch, bundle)
+
+    @overrides
+    def shutdown(self, *, wait: bool = True, quiet: bool = False) -> None:
+        """Shutdown the executor."""
+        if not self.already_shutdown:
+            logger.debug('Shutting down RemoteExecutor %s', self.title)
+            self.heartbeat_stop_event.set()
+            self.heartbeat_thread.join()
+            self._helper_executor.shutdown(wait)
+            if not quiet:
+                print(self.histogram.__repr__(label_formatter='%ds'))
+            self.already_shutdown = True
+
+
+class RemoteWorkerPoolProvider:
+    @abstractmethod
+    def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+        pass
+
+
+@persistent.persistent_autoloaded_singleton()  # type: ignore
+class ConfigRemoteWorkerPoolProvider(
+    RemoteWorkerPoolProvider, persistent.JsonFileBasedPersistent
+):
+    def __init__(self, json_remote_worker_pool: Dict[str, Any]):
+        self.remote_worker_pool = []
+        for record in json_remote_worker_pool['remote_worker_records']:
+            self.remote_worker_pool.append(
+                self.dataclassFromDict(RemoteWorkerRecord, record)
+            )
+        assert len(self.remote_worker_pool) > 0
+
+    @staticmethod
+    def dataclassFromDict(clsName, argDict: Dict[str, Any]) -> Any:
+        fieldSet = {f.name for f in fields(clsName) if f.init}
+        filteredArgDict = {k: v for k, v in argDict.items() if k in fieldSet}
+        return clsName(**filteredArgDict)
+
+    @overrides
+    def get_remote_workers(self) -> List[RemoteWorkerRecord]:
+        return self.remote_worker_pool
+
+    @overrides
+    def get_persistent_data(self) -> List[RemoteWorkerRecord]:
+        return self.remote_worker_pool
+
+    @staticmethod
+    @overrides
+    def get_filename() -> str:
+        return config.config['remote_worker_records_file']
+
+    @staticmethod
+    @overrides
+    def should_we_load_data(filename: str) -> bool:
+        return True
+
+    @staticmethod
+    @overrides
+    def should_we_save_data(filename: str) -> bool:
+        return False
+
+
+@singleton
+class DefaultExecutors(object):
+    """A container for a default thread, process and remote executor.
+    These are not created until needed and we take care to clean up
+    before process exit automatically for the caller's convenience.
+    Instead of creating your own executor, consider using the one
+    from this pool.  e.g.::
+
+        @par.parallelize(method=par.Method.PROCESS)
+        def do_work(
+            solutions: List[Work],
+            shard_num: int,
+            ...
+        ):
+
+
+        def start_do_work(all_work: List[Work]):
+            shards = []
+            results = []
+            logger.debug('Sharding work into groups of 10.')
+            for subset in list_utils.shard(all_work, 10):
+                shards.append([x for x in subset])
+
+            logger.debug('Kicking off helper pool.')
+            try:
+                for n, shard in enumerate(shards):
+                    results.append(
+                        do_work(
+                            shard, n, shared_cache.get_name(), max_letter_pop_per_word
+                        )
+                    )
+                smart_future.wait_all(results)
+            finally:
+                # Note: if you forget to do this it will clean itself up
+                # during program termination including tearing down any
+                # active ssh connections.
+ executors.DefaultExecutors().process_pool().shutdown() + """ + + def __init__(self): + self.thread_executor: Optional[ThreadExecutor] = None + self.process_executor: Optional[ProcessExecutor] = None + self.remote_executor: Optional[RemoteExecutor] = None + + @staticmethod + def _ping(host) -> bool: + logger.debug('RUN> ping -c 1 %s', host) + try: + x = cmd_exitcode( + f'ping -c 1 {host} >/dev/null 2>/dev/null', timeout_seconds=1.0 + ) + return x == 0 + except Exception: + return False + + def thread_pool(self) -> ThreadExecutor: + if self.thread_executor is None: + self.thread_executor = ThreadExecutor() + return self.thread_executor + + def process_pool(self) -> ProcessExecutor: + if self.process_executor is None: + self.process_executor = ProcessExecutor() + return self.process_executor + + def remote_pool(self) -> RemoteExecutor: + if self.remote_executor is None: + logger.info('Looking for some helper machines...') + provider = ConfigRemoteWorkerPoolProvider() + all_machines = provider.get_remote_workers() + pool = [] + + # Make sure we can ping each machine. + for record in all_machines: + if self._ping(record.machine): + logger.info('%s is alive / responding to pings', record.machine) + pool.append(record) + + # The controller machine has a lot to do; go easy on it. + for record in pool: + if record.machine == platform.node() and record.count > 1: + logger.info('Reducing workload for %s.', record.machine) + record.count = max(int(record.count / 2), 1) + + policy = WeightedRandomRemoteWorkerSelectionPolicy() + policy.register_worker_pool(pool) + self.remote_executor = RemoteExecutor(pool, policy) + return self.remote_executor + + def shutdown(self) -> None: + if self.thread_executor is not None: + self.thread_executor.shutdown(wait=True, quiet=True) + self.thread_executor = None + if self.process_executor is not None: + self.process_executor.shutdown(wait=True, quiet=True) + self.process_executor = None + if self.remote_executor is not None: + self.remote_executor.shutdown(wait=True, quiet=True) + self.remote_executor = None diff --git a/src/pyutils/parallelize/parallelize.py b/src/pyutils/parallelize/parallelize.py new file mode 100644 index 0000000..9824e8a --- /dev/null +++ b/src/pyutils/parallelize/parallelize.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A decorator to help with dead simple parallelization.""" + + +import atexit +import functools +import typing +from enum import Enum + + +class Method(Enum): + """How should we parallelize; by threads, processes or remote workers?""" + + THREAD = 1 + PROCESS = 2 + REMOTE = 3 + + +def parallelize( + _funct: typing.Optional[typing.Callable] = None, *, method: Method = Method.THREAD +) -> typing.Callable: + """This is a decorator that was created to make multi-threading, + multi-processing and remote machine parallelism simple in python. 
+
+    Sample usage::
+
+        @parallelize    # defaults to thread-mode
+        def my_function(a, b, c) -> int:
+            ...do some slow / expensive work, e.g., an http request
+
+        @parallelize(method=Method.PROCESS)
+        def my_other_function(d, e, f) -> str:
+            ...do more really expensive work, e.g., a network read
+
+        @parallelize(method=Method.REMOTE)
+        def my_other_other_function(g, h) -> int:
+            ...this work will be distributed to a remote machine pool
+
+    This decorator will invoke the wrapped function on::
+
+        Method.THREAD (default): a background thread
+        Method.PROCESS: a background process
+        Method.REMOTE: a process on a remote host
+
+    The wrapped function returns immediately with a value that is
+    wrapped in a :class:`SmartFuture`.  This value will block if it is
+    either read directly (via a call to :meth:`_resolve`) or indirectly
+    (by using the result in an expression, printing it, hashing it,
+    passing it as a function argument, etc...).  See comments on
+    :class:`SmartFuture` for details.
+
+    .. warning::
+        You may stack @parallelize'd methods and it will "work".
+        That said, having multiple layers of :code:`Method.PROCESS` or
+        :code:`Method.REMOTE` will prove to be problematic because each process in
+        the stack will use its own independent pool which may overload
+        your machine with processes or your network with remote processes
+        beyond the control mechanisms built into one instance of the pool.
+        Be careful.
+
+    .. note::
+        There is non-trivial overhead of pickling code and
+        copying it over the network when you use :code:`Method.REMOTE`.  There's
+        a smaller but still considerable cost of creating a new process
+        and passing code to/from it when you use :code:`Method.PROCESS`.
+    """
+
+    def wrapper(funct: typing.Callable):
+        @functools.wraps(funct)
+        def inner_wrapper(*args, **kwargs):
+            from pyutils.parallelize import executors, smart_future
+
+            # Look for as-yet-unresolved arguments in _funct's
+            # argument list and resolve them now.
+            newargs = []
+            for arg in args:
+                newargs.append(smart_future.SmartFuture.resolve(arg))
+            newkwargs = {}
+            for kw in kwargs:
+                newkwargs[kw] = smart_future.SmartFuture.resolve(kwargs[kw])
+
+            executor = None
+            if method == Method.PROCESS:
+                executor = executors.DefaultExecutors().process_pool()
+            elif method == Method.THREAD:
+                executor = executors.DefaultExecutors().thread_pool()
+            elif method == Method.REMOTE:
+                executor = executors.DefaultExecutors().remote_pool()
+            assert executor is not None
+            atexit.register(executors.DefaultExecutors().shutdown)
+
+            future = executor.submit(funct, *newargs, **newkwargs)
+
+            # Wrap the future that's returned in a SmartFuture object
+            # so that callers do not need to call .result(); they can
+            # just use it as normal.
+            return smart_future.SmartFuture(future)
+
+        return inner_wrapper
+
+    if _funct is None:
+        return wrapper
+    else:
+        return wrapper(_funct)
diff --git a/src/pyutils/parallelize/smart_future.py b/src/pyutils/parallelize/smart_future.py
new file mode 100644
index 0000000..c29124d
--- /dev/null
+++ b/src/pyutils/parallelize/smart_future.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""A :class:`Future` that can be treated as a substitute for the result
+that it contains and will not block until it is used.  At that point,
+if the underlying value is not yet available, it will block until
+the internal result actually becomes available.
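+
+A sketch of typical (hypothetical) caller code, assuming some function
+f decorated with @parallelize::
+
+    futures = [f(x) for x in range(10)]
+    for smart in wait_any(futures):
+        print(smart)    # each result, printed as it becomes ready
+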
+""" + +from __future__ import annotations + +import concurrent +import concurrent.futures as fut +import logging +from typing import Callable, List, Set, TypeVar + +from overrides import overrides + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. +from pyutils import id_generator +from pyutils.parallelize.deferred_operand import DeferredOperand + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + + +def wait_any( + futures: List[SmartFuture], + *, + callback: Callable = None, + log_exceptions: bool = True, + timeout: float = None, +): + """Await the completion of any of a collection of SmartFutures and + invoke callback each time one completes, repeatedly, until they are + all finished. + + Args: + futures: A collection of SmartFutures to wait on + callback: An optional callback to invoke whenever one of the + futures completes + log_exceptions: Should we log (warning + exception) any + underlying exceptions raised during future processing or + silently ignore then? + timeout: invoke callback with a periodicity of timeout while + awaiting futures + """ + + real_futures = [] + smart_future_by_real_future = {} + completed_futures: Set[fut.Future] = set() + for x in futures: + assert isinstance(x, SmartFuture) + real_futures.append(x.wrapped_future) + smart_future_by_real_future[x.wrapped_future] = x + + while len(completed_futures) != len(real_futures): + try: + newly_completed_futures = concurrent.futures.as_completed( + real_futures, timeout=timeout + ) + for f in newly_completed_futures: + if callback is not None: + callback() + completed_futures.add(f) + if log_exceptions and not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning( + 'Future 0x%x raised an unhandled exception and exited.', + id(f), + ) + logger.exception(exception) + raise exception + yield smart_future_by_real_future[f] + except TimeoutError: + if callback is not None: + callback() + if callback is not None: + callback() + + +def wait_all( + futures: List[SmartFuture], + *, + log_exceptions: bool = True, +) -> None: + """Wait for all of the SmartFutures in the collection to finish before + returning. + + Args: + futures: A collection of futures that we're waiting for + log_exceptions: Should we log (warning + exception) any + underlying exceptions raised during future processing or + silently ignore then? + """ + + real_futures = [] + for x in futures: + assert isinstance(x, SmartFuture) + real_futures.append(x.wrapped_future) + + (done, not_done) = concurrent.futures.wait( + real_futures, timeout=None, return_when=concurrent.futures.ALL_COMPLETED + ) + if log_exceptions: + for f in real_futures: + if not f.cancelled(): + exception = f.exception() + if exception is not None: + logger.warning( + 'Future 0x%x raised an unhandled exception and exited.', id(f) + ) + logger.exception(exception) + raise exception + assert len(done) == len(real_futures) + assert len(not_done) == 0 + + +class SmartFuture(DeferredOperand): + """This is a SmartFuture, a class that wraps a normal :class:`Future` + and can then be used, mostly, like a normal (non-Future) + identifier of the type of that SmartFuture's result. + + Using a FutureWrapper in expressions will block and wait until + the result of the deferred operation is known. 
+ """ + + def __init__(self, wrapped_future: fut.Future) -> None: + assert isinstance(wrapped_future, fut.Future) + self.wrapped_future = wrapped_future + self.id = id_generator.get("smart_future_id") + + def get_id(self) -> int: + return self.id + + def is_ready(self) -> bool: + return self.wrapped_future.done() + + # You shouldn't have to call this; instead, have a look at defining a + # method on DeferredOperand base class. + @overrides + def _resolve(self, timeout=None) -> T: + return self.wrapped_future.result(timeout) diff --git a/src/pyutils/parallelize/thread_utils.py b/src/pyutils/parallelize/thread_utils.py new file mode 100644 index 0000000..aaef13b --- /dev/null +++ b/src/pyutils/parallelize/thread_utils.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Utilities for dealing with threads + threading.""" + +import functools +import logging +import os +import threading +from typing import Any, Callable, Optional, Tuple + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +logger = logging.getLogger(__name__) + + +def current_thread_id() -> str: + """ + Returns: + a string composed of the parent process' id, the current + process' id and the current thread identifier. The former two are + numbers (pids) whereas the latter is a thread id passed during thread + creation time. + + >>> ret = current_thread_id() + >>> (ppid, pid, tid) = ret.split('/') + >>> ppid.isnumeric() + True + >>> pid.isnumeric() + True + + """ + ppid = os.getppid() + pid = os.getpid() + tid = threading.current_thread().name + return f'{ppid}/{pid}/{tid}:' + + +def is_current_thread_main_thread() -> bool: + """ + Returns: + True is the current (calling) thread is the process' main + thread and False otherwise. + + >>> is_current_thread_main_thread() + True + + >>> result = None + >>> def thunk(): + ... global result + ... result = is_current_thread_main_thread() + + >>> thunk() + >>> result + True + + >>> import threading + >>> thread = threading.Thread(target=thunk) + >>> thread.start() + >>> thread.join() + >>> result + False + + """ + return threading.current_thread() is threading.main_thread() + + +def background_thread( + _funct: Optional[Callable[..., Any]], +) -> Callable[..., Tuple[threading.Thread, threading.Event]]: + """A function decorator to create a background thread. + + Usage:: + + @background_thread + def random(a: int, b: str, stop_event: threading.Event) -> None: + while True: + print(f"Hi there {b}: {a}!") + time.sleep(10.0) + if stop_event.is_set(): + return + + def main() -> None: + (thread, event) = random(22, "dude") + print("back!") + time.sleep(30.0) + event.set() + thread.join() + + .. warning:: + + In addition to any other arguments the function has, it must + take a stop_event as the last unnamed argument which it should + periodically check. If the event is set, it means the thread has + been requested to terminate ASAP. 
+ """ + + def wrapper(funct: Callable): + @functools.wraps(funct) + def inner_wrapper(*a, **kwa) -> Tuple[threading.Thread, threading.Event]: + should_terminate = threading.Event() + should_terminate.clear() + newargs = (*a, should_terminate) + thread = threading.Thread( + target=funct, + args=newargs, + kwargs=kwa, + ) + thread.start() + logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident) + return (thread, should_terminate) + + return inner_wrapper + + if _funct is None: + return wrapper # type: ignore + else: + return wrapper(_funct) + + +class ThreadWithReturnValue(threading.Thread): + """A thread whose return value is plumbed back out as the return + value of :meth:`join`. + """ + + def __init__( + self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None + ): + threading.Thread.__init__( + self, group=None, target=target, name=None, args=args, kwargs=kwargs + ) + self._target = target + self._return = None + + def run(self): + if self._target is not None: + self._return = self._target(*self._args, **self._kwargs) + + def join(self, *args): + threading.Thread.join(self, *args) + return self._return + + +def periodically_invoke( + period_sec: float, + stop_after: Optional[int], +): + """ + Periodically invoke the decorated function. + + Args: + period_sec: the delay period in seconds between invocations + stop_after: total number of invocations to make or, if None, + call forever + + Returns: + a :class:`Thread` object and an :class:`Event` that, when + signaled, will stop the invocations. + + .. note:: + It is possible to be invoked one time after the :class:`Event` + is set. This event can be used to stop infinite + invocation style or finite invocation style decorations. + + Usage:: + + @periodically_invoke(period_sec=0.5, stop_after=None) + def there(name: str, age: int) -> None: + print(f" ...there {name}, {age}") + + @periodically_invoke(period_sec=1.0, stop_after=3) + def hello(name: str) -> None: + print(f"Hello, {name}") + """ + + def decorator_repeat(func): + def helper_thread(should_terminate, *args, **kwargs) -> None: + if stop_after is None: + while True: + func(*args, **kwargs) + res = should_terminate.wait(period_sec) + if res: + return + else: + for _ in range(stop_after): + func(*args, **kwargs) + res = should_terminate.wait(period_sec) + if res: + return + return + + @functools.wraps(func) + def wrapper_repeat(*args, **kwargs): + should_terminate = threading.Event() + should_terminate.clear() + newargs = (should_terminate, *args) + thread = threading.Thread(target=helper_thread, args=newargs, kwargs=kwargs) + thread.start() + logger.debug('Started thread "%s" tid=%d', thread.name, thread.ident) + return (thread, should_terminate) + + return wrapper_repeat + + return decorator_repeat + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/persistent.py b/src/pyutils/persistent.py new file mode 100644 index 0000000..2b03ea6 --- /dev/null +++ b/src/pyutils/persistent.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A :class:`Persistent` is just a class with a load and save method. 
This +module defines the :class:`Persistent` base and a decorator that can be used to +create a persistent singleton that autoloads and autosaves.""" + +import atexit +import datetime +import enum +import functools +import logging +import re +from abc import ABC, abstractmethod +from typing import Any, Optional + +from overrides import overrides + +from pyutils.files import file_utils + +logger = logging.getLogger(__name__) + + +class Persistent(ABC): + """ + A base class of an object with a load/save method. Classes that are + decorated with :code:`@persistent_autoloaded_singleton` should subclass + this and implement their :meth:`save` and :meth:`load` methods. + """ + + @abstractmethod + def save(self) -> bool: + """ + Save this thing somewhere that you'll remember when someone calls + :meth:`load` later on in a way that makes sense to your code. + """ + pass + + @classmethod + @abstractmethod + def load(cls) -> Any: + """Load this thing from somewhere and give back an instance which + will become the global singleton and which may (see + below) be saved (via :meth:`save`) at program exit time. + + Oh, in case this is handy, here's a reminder how to write a + factory method that doesn't call the c'tor in python:: + + @classmethod + def load_from_somewhere(cls, somewhere): + # Note: __new__ does not call __init__. + obj = cls.__new__(cls) + + # Don't forget to call any polymorphic base class initializers + super(MyClass, obj).__init__() + + # Load the piece(s) of obj that you want to from somewhere. + obj._state = load_from_somewhere(somewhere) + return obj + """ + pass + + +class FileBasedPersistent(Persistent): + """A Persistent that uses a file to save/load data and knows the conditions + under which the state should be saved/loaded.""" + + @staticmethod + @abstractmethod + def get_filename() -> str: + """Since this class saves/loads to/from a file, what's its full path?""" + pass + + @staticmethod + @abstractmethod + def should_we_save_data(filename: str) -> bool: + pass + + @staticmethod + @abstractmethod + def should_we_load_data(filename: str) -> bool: + pass + + @abstractmethod + def get_persistent_data(self) -> Any: + pass + + +class PicklingFileBasedPersistent(FileBasedPersistent): + @classmethod + @overrides + def load(cls) -> Optional[Any]: + filename = cls.get_filename() + if cls.should_we_load_data(filename): + logger.debug('Attempting to load state from %s', filename) + assert file_utils.file_is_readable(filename) + + import pickle + + try: + with open(filename, 'rb') as rf: + data = pickle.load(rf) + return cls(data) + + except Exception as e: + raise Exception(f'Failed to load {filename}.') from e + return None + + @overrides + def save(self) -> bool: + filename = self.get_filename() + if self.should_we_save_data(filename): + logger.debug('Trying to save state in %s', filename) + try: + import pickle + + with open(filename, 'wb') as wf: + pickle.dump(self.get_persistent_data(), wf, pickle.HIGHEST_PROTOCOL) + return True + except Exception as e: + raise Exception(f'Failed to save to {filename}.') from e + return False + + +class JsonFileBasedPersistent(FileBasedPersistent): + @classmethod + @overrides + def load(cls) -> Any: + filename = cls.get_filename() + if cls.should_we_load_data(filename): + logger.debug('Trying to load state from %s', filename) + import json + + try: + with open(filename, 'r') as rf: + lines = rf.readlines() + + # This is probably bad... but I like comments + # in config files and JSON doesn't support them. 
So + # pre-process the buffer to remove comments thus + # allowing people to add them. + buf = '' + for line in lines: + line = re.sub(r'#.*$', '', line) + buf += line + + json_dict = json.loads(buf) + return cls(json_dict) + + except Exception as e: + logger.exception(e) + raise Exception(f'Failed to load {filename}.') from e + return None + + @overrides + def save(self) -> bool: + filename = self.get_filename() + if self.should_we_save_data(filename): + logger.debug('Trying to save state in %s', filename) + try: + import json + + json_blob = json.dumps(self.get_persistent_data()) + with open(filename, 'w') as wf: + wf.writelines(json_blob) + return True + except Exception as e: + raise Exception(f'Failed to save to {filename}.') from e + return False + + +def was_file_written_today(filename: str) -> bool: + """Convenience wrapper around :meth:`was_file_written_within_n_seconds`. + + Args: + filename: filename to check + + Returns: + True if filename was written today. + + >>> import os + >>> filename = f'/tmp/testing_persistent_py_{os.getpid()}' + >>> os.system(f'touch {filename}') + 0 + >>> was_file_written_today(filename) + True + >>> os.system(f'touch -d 1974-04-15T01:02:03.99 {filename}') + 0 + >>> was_file_written_today(filename) + False + >>> os.system(f'/bin/rm -f {filename}') + 0 + >>> was_file_written_today(filename) + False + """ + + if not file_utils.does_file_exist(filename): + return False + + mtime = file_utils.get_file_mtime_as_datetime(filename) + assert mtime is not None + now = datetime.datetime.now() + return mtime.month == now.month and mtime.day == now.day and mtime.year == now.year + + +def was_file_written_within_n_seconds( + filename: str, + limit_seconds: int, +) -> bool: + """Helper for determining persisted state staleness. + + Args: + filename: the filename to check + limit_seconds: how fresh, in seconds, it must be + + Returns: + True if filename was written within the past limit_seconds + or False otherwise (or on error). + + >>> import os + >>> filename = f'/tmp/testing_persistent_py_{os.getpid()}' + >>> os.system(f'touch {filename}') + 0 + >>> was_file_written_within_n_seconds(filename, 60) + True + >>> import time + >>> time.sleep(2.0) + >>> was_file_written_within_n_seconds(filename, 2) + False + >>> os.system(f'/bin/rm -f {filename}') + 0 + >>> was_file_written_within_n_seconds(filename, 60) + False + """ + + if not file_utils.does_file_exist(filename): + return False + + mtime = file_utils.get_file_mtime_as_datetime(filename) + assert mtime is not None + now = datetime.datetime.now() + return (now - mtime).total_seconds() <= limit_seconds + + +class PersistAtShutdown(enum.Enum): + """ + An enum to describe the conditions under which state is persisted + to disk. This is passed as an argument to the decorator below and + is used to indicate when to call :meth:`save` on a :class:`Persistent` + subclass. + + * NEVER: never call :meth:`save` + * IF_NOT_LOADED: call :meth:`save` as long as we did not successfully + :meth:`load` its state. + * ALWAYS: always call :meth:`save` + """ + + NEVER = (0,) + IF_NOT_LOADED = (1,) + ALWAYS = (2,) + + +class persistent_autoloaded_singleton(object): + """A decorator that can be applied to a :class:`Persistent` subclass + (i.e. a class with :meth:`save` and :meth:`load` methods. The + decorator will intercept attempts to instantiate the class via + it's c'tor and, instead, invoke the class' :meth:`load` to give it a + chance to read state from somewhere persistent (disk, db, + whatever). 
Subsequent calls to construct instances of the wrapped
+    class will return a single, global instance (i.e. the wrapped
+    class is a singleton).
+
+    If :meth:`load` fails (returns None), the c'tor is invoked with the
+    original args as a fallback.
+
+    Based upon the value of the optional :code:`persist_at_shutdown`
+    argument (NEVER, IF_NOT_LOADED, ALWAYS), the :meth:`save` method of
+    the class will be invoked just before program shutdown to give the
+    class a chance to save its state somewhere.
+
+    .. note::
+        The implementations of :meth:`save` and :meth:`load` and where the
+        class persists its state are details left to the :class:`Persistent`
+        implementation.  Essentially this decorator just handles the
+        plumbing of calling your save/load at appropriate times and
+        creates a transparent global singleton whose state can be
+        persisted between runs.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        persist_at_shutdown: PersistAtShutdown = PersistAtShutdown.IF_NOT_LOADED,
+    ):
+        self.persist_at_shutdown = persist_at_shutdown
+        self.instance = None
+
+    def __call__(self, cls: Persistent):
+        @functools.wraps(cls)  # type: ignore
+        def _load(*args, **kwargs):
+
+            # If class has already been loaded, act like a singleton
+            # and return a reference to the one and only instance in
+            # memory.
+            if self.instance is not None:
+                logger.debug(
+                    'Returning already instantiated singleton instance of %s.',
+                    cls.__name__,
+                )
+                return self.instance
+
+            # Otherwise, try to load it from persisted state.
+            was_loaded = False
+            logger.debug('Attempting to load %s from persisted state.', cls.__name__)
+            self.instance = cls.load()
+            if not self.instance:
+                msg = 'Loading from cache failed.'
+                logger.warning(msg)
+                logger.debug('Attempting to instantiate %s directly.', cls.__name__)
+                self.instance = cls(*args, **kwargs)
+            else:
+                logger.debug(
+                    'Class %s was loaded from persisted state successfully.',
+                    cls.__name__,
+                )
+                was_loaded = True
+
+            assert self.instance is not None
+
+            if self.persist_at_shutdown is PersistAtShutdown.ALWAYS or (
+                not was_loaded
+                and self.persist_at_shutdown is PersistAtShutdown.IF_NOT_LOADED
+            ):
+                logger.debug(
+                    'Scheduling a deferred call to save at process shutdown time.'
+                )
+                atexit.register(self.instance.save)
+            return self.instance
+
+        return _load
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
diff --git a/src/pyutils/remote_worker.py b/src/pyutils/remote_worker.py
new file mode 100755
index 0000000..cd6e4d6
--- /dev/null
+++ b/src/pyutils/remote_worker.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""A simple utility to unpickle some code, run it, and pickle the
+results.  Please don't unpickle (or run!) code you do not know.
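+
+An illustrative invocation (a sketch: in practice executors.py drives
+this program over ssh, and the file paths below are hypothetical)::
+
+    python3 -m pyutils.remote_worker \
+        --code_file /tmp/code.bin \
+        --result_file /tmp/result.bin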
+""" + +import logging +import os +import signal +import sys +import threading +import time +from typing import Optional + +import cloudpickle # type: ignore +import psutil # type: ignore + +from pyutils import argparse_utils, bootstrap, config +from pyutils.parallelize.thread_utils import background_thread +from pyutils.stopwatch import Timer + +logger = logging.getLogger(__file__) + +cfg = config.add_commandline_args( + f"Remote Worker ({__file__})", + "Helper to run pickled code remotely and return results", +) +cfg.add_argument( + '--code_file', + type=str, + required=True, + metavar='FILENAME', + help='The location of the bundle of code to execute.', +) +cfg.add_argument( + '--result_file', + type=str, + required=True, + metavar='FILENAME', + help='The location where we should write the computation results.', +) +cfg.add_argument( + '--watch_for_cancel', + action=argparse_utils.ActionNoYes, + default=True, + help='Should we watch for the cancellation of our parent ssh process?', +) + + +@background_thread +def watch_for_cancel(terminate_event: threading.Event) -> None: + logger.debug('Starting up background thread...') + p = psutil.Process(os.getpid()) + while True: + saw_sshd = False + ancestors = p.parents() + for ancestor in ancestors: + name = ancestor.name() + pid = ancestor.pid + logger.debug('Ancestor process %s (pid=%d)', name, pid) + if 'ssh' in name.lower(): + saw_sshd = True + break + if not saw_sshd: + logger.error( + 'Did not see sshd in our ancestors list?! Committing suicide.' + ) + os.system('pstree') + os.kill(os.getpid(), signal.SIGTERM) + time.sleep(5.0) + os.kill(os.getpid(), signal.SIGKILL) + sys.exit(-1) + if terminate_event.is_set(): + return + time.sleep(1.0) + + +def cleanup_and_exit( + thread: Optional[threading.Thread], + stop_thread: Optional[threading.Event], + exit_code: int, +) -> None: + if stop_thread is not None: + stop_thread.set() + assert thread is not None + thread.join() + sys.exit(exit_code) + + +@bootstrap.initialize +def main() -> None: + in_file = config.config['code_file'] + out_file = config.config['result_file'] + + thread = None + stop_thread = None + if config.config['watch_for_cancel']: + (thread, stop_thread) = watch_for_cancel() + + logger.debug('Reading %s.', in_file) + try: + with open(in_file, 'rb') as rb: + serialized = rb.read() + except Exception as e: + logger.exception(e) + logger.critical('Problem reading %s. Aborting.', in_file) + cleanup_and_exit(thread, stop_thread, 1) + + logger.debug('Deserializing %s', in_file) + try: + fun, args, kwargs = cloudpickle.loads(serialized) + except Exception as e: + logger.exception(e) + logger.critical('Problem deserializing %s. Aborting.', in_file) + cleanup_and_exit(thread, stop_thread, 2) + + logger.debug('Invoking user code...') + with Timer() as t: + ret = fun(*args, **kwargs) + logger.debug('User code took %.1fs', t()) + + logger.debug('Serializing results') + try: + serialized = cloudpickle.dumps(ret) + except Exception as e: + logger.exception(e) + logger.critical('Could not serialize result (%s). Aborting.', type(ret)) + cleanup_and_exit(thread, stop_thread, 3) + + logger.debug('Writing %s', out_file) + try: + with open(out_file, 'wb') as wb: + wb.write(serialized) + except Exception as e: + logger.exception(e) + logger.critical('Error writing %s. 
Aborting.', out_file) + cleanup_and_exit(thread, stop_thread, 4) + cleanup_and_exit(thread, stop_thread, 0) + + +if __name__ == '__main__': + main() diff --git a/src/pyutils/search/__init__.py b/src/pyutils/search/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/search/logical_search.py b/src/pyutils/search/logical_search.py new file mode 100644 index 0000000..2b52864 --- /dev/null +++ b/src/pyutils/search/logical_search.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""This is a module concerned with the creation of and searching of a +corpus of documents. The corpus and index are held in memory. +""" + +from __future__ import annotations + +import enum +import sys +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union + + +class ParseError(Exception): + """An error encountered while parsing a logical search expression.""" + + def __init__(self, message: str): + super().__init__() + self.message = message + + +@dataclass +class Document: + """A class representing a searchable document.""" + + docid: str = '' + """A unique identifier for each document -- must be provided + by the caller. See :meth:`python_modules.id_generator.get` or + :meth:`python_modules.string_utils.generate_uuid` for potential + sources.""" + + tags: Set[str] = field(default_factory=set) + """A set of tag strings for this document. May be empty. Tags + are simply text labels that are associated with a document and + may be used to search for it later. + """ + + properties: List[Tuple[str, str]] = field(default_factory=list) + """A list of key->value strings for this document. May be empty. + Properties are more flexible tags that have both a label and a + value. e.g. "category:mystery" or "author:smith".""" + + reference: Optional[Any] = None + """An optional reference to something else for convenience; + interpreted only by caller code, ignored here. + """ + + +class Operation(enum.Enum): + """A logical search query operation.""" + + QUERY = 1 + CONJUNCTION = 2 + DISJUNCTION = 3 + INVERSION = 4 + + @staticmethod + def from_token(token: str): + table = { + "not": Operation.INVERSION, + "and": Operation.CONJUNCTION, + "or": Operation.DISJUNCTION, + } + return table.get(token, None) + + def num_operands(self) -> Optional[int]: + table = { + Operation.INVERSION: 1, + Operation.CONJUNCTION: 2, + Operation.DISJUNCTION: 2, + } + return table.get(self, None) + + +class Corpus(object): + """A collection of searchable documents. The caller can + add documents to it (or edit existing docs) via :meth:`add_doc`, + retrieve a document given its docid via :meth:`get_doc`, and + perform various lookups of documents. The most interesting + lookup is implemented in :meth:`query`. + + >>> c = Corpus() + >>> c.add_doc(Document( + ... docid=1, + ... tags=set(['urgent', 'important']), + ... properties=[ + ... ('author', 'Scott'), + ... ('subject', 'your anniversary') + ... ], + ... reference=None, + ... ) + ... ) + >>> c.add_doc(Document( + ... docid=2, + ... tags=set(['important']), + ... properties=[ + ... ('author', 'Joe'), + ... ('subject', 'your performance at work') + ... ], + ... reference=None, + ... ) + ... ) + >>> c.add_doc(Document( + ... docid=3, + ... tags=set(['urgent']), + ... properties=[ + ... ('author', 'Scott'), + ... ('subject', 'car turning in front of you') + ... ], + ... reference=None, + ... ) + ... 
) + >>> c.query('author:Scott and important') + {1} + >>> c.query('*') + {1, 2, 3} + >>> c.query('*:*') + {1, 2, 3} + >>> c.query('*:Scott') + {1, 3} + """ + + def __init__(self) -> None: + self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set) + self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set) + self.docids_with_property: Dict[str, Set[str]] = defaultdict(set) + self.documents_by_docid: Dict[str, Document] = {} + + def add_doc(self, doc: Document) -> None: + """Add a new Document to the Corpus. Each Document must have a + distinct docid that will serve as its primary identifier. If + the same Document is added multiple times, only the most + recent addition is indexed. If two distinct documents with + the same docid are added, the latter klobbers the former in + the indexes. See :meth:`python_modules.id_generator.get` or + :meth:`python_modules.string_utils.generate_uuid` for potential + sources of docids. + + Each Document may have an optional set of tags which can be + used later in expressions to the query method. These are simple + text labels. + + Each Document may have an optional list of key->value tuples + which can be used later in expressions to the query method. + + Document includes a user-defined "reference" field which is + never interpreted by this module. This is meant to allow easy + mapping between Documents in this corpus and external objects + they may represent. + + Args: + doc: the document to add or edit + """ + + if doc.docid in self.documents_by_docid: + # Handle collisions; assume that we are re-indexing the + # same document so remove it from the indexes before + # adding it back again. + colliding_doc = self.documents_by_docid[doc.docid] + assert colliding_doc.docid == doc.docid + del self.documents_by_docid[doc.docid] + for tag in colliding_doc.tags: + self.docids_by_tag[tag].remove(doc.docid) + for key, value in colliding_doc.properties: + self.docids_by_property[(key, value)].remove(doc.docid) + self.docids_with_property[key].remove(doc.docid) + + # Index the new Document + assert doc.docid not in self.documents_by_docid + self.documents_by_docid[doc.docid] = doc + for tag in doc.tags: + self.docids_by_tag[tag].add(doc.docid) + for key, value in doc.properties: + self.docids_by_property[(key, value)].add(doc.docid) + self.docids_with_property[key].add(doc.docid) + + def get_docids_by_exact_tag(self, tag: str) -> Set[str]: + """Return the set of docids that have a particular tag. + + Args: + tag: the tag for which to search + + Returns: + A set containing docids with the provided tag which + may be empty.""" + return self.docids_by_tag[tag] + + def get_docids_by_searching_tags(self, tag: str) -> Set[str]: + """Return the set of docids with a tag that contains a str. + + Args: + tag: the tag pattern for which to search + + Returns: + A set containing docids with tags that match the pattern + provided. e.g., if the arg was "foo" tags "football", "foobar", + and "food" all match. + """ + ret = set() + for search_tag in self.docids_by_tag: + if tag in search_tag: + for docid in self.docids_by_tag[search_tag]: + ret.add(docid) + return ret + + def get_docids_with_property(self, key: str) -> Set[str]: + """Return the set of docids that have a particular property no matter + what that property's value. + + Args: + key: the key value to search for. + + Returns: + A set of docids that contain the key (no matter what value) + which may be empty. 
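+
+        e.g., using the Corpus built in the class docstring above (a
+        sketch; all three documents there carry an 'author' property)::
+
+            c.get_docids_with_property('author')   # -> {1, 2, 3}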
+ """ + return self.docids_with_property[key] + + def get_docids_by_property(self, key: str, value: str) -> Set[str]: + """Return the set of docids that have a particular property with a + particular value. + + Args: + key: the key to search for + value: the value that key must have in order to match a doc. + + Returns: + A set of docids that contain key with value which may be empty. + """ + return self.docids_by_property[(key, value)] + + def invert_docid_set(self, original: Set[str]) -> Set[str]: + """Invert a set of docids.""" + return {docid for docid in self.documents_by_docid if docid not in original} + + def get_doc(self, docid: str) -> Optional[Document]: + """Given a docid, retrieve the previously added Document. + + Args: + docid: the docid to retrieve + + Returns: + The Document with docid or None to indicate no match. + """ + return self.documents_by_docid.get(docid, None) + + def query(self, query: str) -> Optional[Set[str]]: + """Query the corpus for documents that match a logical expression. + + Args: + query: the logical query expressed using a simple language + that understands conjunction (and operator), disjunction + (or operator) and inversion (not operator) as well as + parenthesis. Here are some legal sample queries:: + + tag1 and tag2 and not tag3 + + (tag1 or tag2) and (tag3 or tag4) + + (tag1 and key2:value2) or (tag2 and key1:value1) + + key:* + + tag1 and key:* + + Returns: + A (potentially empty) set of docids for the matching + (previously added) documents or None on error. + """ + + try: + root = self._parse_query(query) + except ParseError as e: + print(e.message, file=sys.stderr) + return None + return root.eval() + + def _parse_query(self, query: str): + """Internal parse helper; prefer to use query instead.""" + + parens = set(["(", ")"]) + and_or = set(["and", "or"]) + + def operator_precedence(token: str) -> Optional[int]: + table = { + "(": 4, # higher + ")": 4, + "not": 3, + "and": 2, + "or": 1, # lower + } + return table.get(token, None) + + def is_operator(token: str) -> bool: + return operator_precedence(token) is not None + + def lex(query: str): + tokens = query.split() + for token in tokens: + # Handle ( and ) operators stuck to the ends of tokens + # that split() doesn't understand. + if len(token) > 1: + first = token[0] + if first in parens: + tail = token[1:] + yield first + token = tail + last = token[-1] + if last in parens: + head = token[0:-1] + yield head + token = last + yield token + + def evaluate(corpus: Corpus, stack: List[str]): + node_stack: List[Node] = [] + for token in stack: + node = None + if not is_operator(token): + node = Node(corpus, Operation.QUERY, [token]) + else: + args = [] + operation = Operation.from_token(token) + operand_count = operation.num_operands() + if len(node_stack) < operand_count: + raise ParseError( + f"Incorrect number of operations for {operation}" + ) + for _ in range(operation.num_operands()): + args.append(node_stack.pop()) + node = Node(corpus, operation, args) + node_stack.append(node) + return node_stack[0] + + output_stack = [] + operator_stack = [] + for token in lex(query): + if not is_operator(token): + output_stack.append(token) + continue + + # token is an operator... 
+ if token == "(": + operator_stack.append(token) + elif token == ")": + ok = False + while len(operator_stack) > 0: + pop_operator = operator_stack.pop() + if pop_operator != "(": + output_stack.append(pop_operator) + else: + ok = True + break + if not ok: + raise ParseError("Unbalanced parenthesis in query expression") + + # and, or, not + else: + my_precedence = operator_precedence(token) + if my_precedence is None: + raise ParseError(f"Unknown operator: {token}") + while len(operator_stack) > 0: + peek_operator = operator_stack[-1] + if not is_operator(peek_operator) or peek_operator == "(": + break + peek_precedence = operator_precedence(peek_operator) + if peek_precedence is None: + raise ParseError("Internal error") + if ( + (peek_precedence < my_precedence) + or (peek_precedence == my_precedence) + and (peek_operator not in and_or) + ): + break + output_stack.append(operator_stack.pop()) + operator_stack.append(token) + while len(operator_stack) > 0: + token = operator_stack.pop() + if token in parens: + raise ParseError("Unbalanced parenthesis in query expression") + output_stack.append(token) + return evaluate(self, output_stack) + + +class Node(object): + """A query AST node.""" + + def __init__( + self, + corpus: Corpus, + op: Operation, + operands: Sequence[Union[Node, str]], + ): + self.corpus = corpus + self.op = op + self.operands = operands + + def eval(self) -> Set[str]: + """Evaluate this node.""" + + evaled_operands: List[Union[Set[str], str]] = [] + for operand in self.operands: + if isinstance(operand, Node): + evaled_operands.append(operand.eval()) + elif isinstance(operand, str): + evaled_operands.append(operand) + else: + raise ParseError(f"Unexpected operand: {operand}") + + retval = set() + if self.op is Operation.QUERY: + for tag in evaled_operands: + if isinstance(tag, str): + if ":" in tag: + try: + key, value = tag.split(":") + except ValueError as v: + raise ParseError( + f'Invalid key:value syntax at "{tag}"' + ) from v + + if key == '*': + r = set() + for kv, s in self.corpus.docids_by_property.items(): + if value in ('*', kv[1]): + r.update(s) + else: + if value == '*': + r = self.corpus.get_docids_with_property(key) + else: + r = self.corpus.get_docids_by_property(key, value) + else: + if tag == '*': + r = set() + for s in self.corpus.docids_by_tag.values(): + r.update(s) + else: + r = self.corpus.get_docids_by_exact_tag(tag) + retval.update(r) + else: + raise ParseError(f"Unexpected query {tag}") + elif self.op is Operation.DISJUNCTION: + if len(evaled_operands) != 2: + raise ParseError("Operation.DISJUNCTION (or) expects two operands.") + retval.update(evaled_operands[0]) + retval.update(evaled_operands[1]) + elif self.op is Operation.CONJUNCTION: + if len(evaled_operands) != 2: + raise ParseError("Operation.CONJUNCTION (and) expects two operands.") + retval.update(evaled_operands[0]) + retval = retval.intersection(evaled_operands[1]) + elif self.op is Operation.INVERSION: + if len(evaled_operands) != 1: + raise ParseError("Operation.INVERSION (not) expects one operand.") + _ = evaled_operands[0] + if isinstance(_, set): + retval.update(self.corpus.invert_docid_set(_)) + else: + raise ParseError(f"Unexpected negation operand {_} ({type(_)})") + return retval + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/security/__init__.py b/src/pyutils/security/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/security/acl.py b/src/pyutils/security/acl.py new file mode 100644 index 
0000000..d6d5623 --- /dev/null +++ b/src/pyutils/security/acl.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""This module defines various flavors of Access Control Lists.""" + +import enum +import fnmatch +import logging +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional, Sequence, Set + +from overrides import overrides + +# This module is commonly used by others in here and should avoid +# taking any unnecessary dependencies back on them. + +logger = logging.getLogger(__name__) + + +class Order(enum.Enum): + """A helper to express the order of evaluation for allows/denies + in an Access Control List. + """ + + UNDEFINED = 0 + ALLOW_DENY = 1 + DENY_ALLOW = 2 + + +class SimpleACL(ABC): + """A simple Access Control List interface.""" + + def __init__(self, *, order_to_check_allow_deny: Order, default_answer: bool): + if order_to_check_allow_deny not in ( + Order.ALLOW_DENY, + Order.DENY_ALLOW, + ): + raise Exception( + 'order_to_check_allow_deny must be Order.ALLOW_DENY or ' + + 'Order.DENY_ALLOW' + ) + self.order_to_check_allow_deny = order_to_check_allow_deny + self.default_answer = default_answer + + def __call__(self, x: Any) -> bool: + """Returns True if x is allowed, False otherwise.""" + logger.debug('SimpleACL checking %s', x) + if self.order_to_check_allow_deny == Order.ALLOW_DENY: + logger.debug('Checking allowed first...') + if self.check_allowed(x): + logger.debug('%s was allowed explicitly.', x) + return True + logger.debug('Checking denied next...') + if self.check_denied(x): + logger.debug('%s was denied explicitly.', x) + return False + elif self.order_to_check_allow_deny == Order.DENY_ALLOW: + logger.debug('Checking denied first...') + if self.check_denied(x): + logger.debug('%s was denied explicitly.', x) + return False + if self.check_allowed(x): + logger.debug('%s was allowed explicitly.', x) + return True + + logger.debug( + f'{x} was not explicitly allowed or denied; ' + + f'using default answer ({self.default_answer})' + ) + return self.default_answer + + @abstractmethod + def check_allowed(self, x: Any) -> bool: + """Return True if x is explicitly allowed, False otherwise.""" + pass + + @abstractmethod + def check_denied(self, x: Any) -> bool: + """Return True if x is explicitly denied, False otherwise.""" + pass + + +class SetBasedACL(SimpleACL): + """An ACL that allows or denies based on membership in a set.""" + + def __init__( + self, + *, + allow_set: Optional[Set[Any]] = None, + deny_set: Optional[Set[Any]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + super().__init__( + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + self.allow_set = allow_set + self.deny_set = deny_set + + @overrides + def check_allowed(self, x: Any) -> bool: + if self.allow_set is None: + return False + return x in self.allow_set + + @overrides + def check_denied(self, x: Any) -> bool: + if self.deny_set is None: + return False + return x in self.deny_set + + +class AllowListACL(SetBasedACL): + """Convenience subclass for a list that only allows known items. + i.e. a 'allowlist' + """ + + def __init__(self, *, allow_set: Optional[Set[Any]]) -> None: + super().__init__( + allow_set=allow_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=False, + ) + + +class DenyListACL(SetBasedACL): + """Convenience subclass for a list that only disallows known items. + i.e. 
a 'blocklist' + """ + + def __init__(self, *, deny_set: Optional[Set[Any]]) -> None: + super().__init__( + deny_set=deny_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=True, + ) + + +class BlockListACL(SetBasedACL): + """Convenience subclass for a list that only disallows known items. + i.e. a 'blocklist' + """ + + def __init__(self, *, deny_set: Optional[Set[Any]]) -> None: + super().__init__( + deny_set=deny_set, + order_to_check_allow_deny=Order.ALLOW_DENY, + default_answer=True, + ) + + +class PredicateListBasedACL(SimpleACL): + """An ACL that allows or denies by applying predicates.""" + + def __init__( + self, + *, + allow_predicate_list: Sequence[Callable[[Any], bool]] = None, + deny_predicate_list: Sequence[Callable[[Any], bool]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + super().__init__( + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + self.allow_predicate_list = allow_predicate_list + self.deny_predicate_list = deny_predicate_list + + @overrides + def check_allowed(self, x: Any) -> bool: + if self.allow_predicate_list is None: + return False + return any(predicate(x) for predicate in self.allow_predicate_list) + + @overrides + def check_denied(self, x: Any) -> bool: + if self.deny_predicate_list is None: + return False + return any(predicate(x) for predicate in self.deny_predicate_list) + + +class StringWildcardBasedACL(PredicateListBasedACL): + """An ACL that allows or denies based on string glob :code:`(*, ?)` + patterns. + """ + + def __init__( + self, + *, + allowed_patterns: Optional[List[str]] = None, + denied_patterns: Optional[List[str]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + allow_predicates = [] + if allowed_patterns is not None: + for pattern in allowed_patterns: + allow_predicates.append( + lambda x, pattern=pattern: fnmatch.fnmatch(x, pattern) + ) + deny_predicates = None + if denied_patterns is not None: + deny_predicates = [] + for pattern in denied_patterns: + deny_predicates.append( + lambda x, pattern=pattern: fnmatch.fnmatch(x, pattern) + ) + + super().__init__( + allow_predicate_list=allow_predicates, + deny_predicate_list=deny_predicates, + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + + +class StringREBasedACL(PredicateListBasedACL): + """An ACL that allows or denies by applying regexps.""" + + def __init__( + self, + *, + allowed_regexs: Optional[List[re.Pattern]] = None, + denied_regexs: Optional[List[re.Pattern]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + allow_predicates = None + if allowed_regexs is not None: + allow_predicates = [] + for pattern in allowed_regexs: + allow_predicates.append( + lambda x, pattern=pattern: pattern.match(x) is not None + ) + deny_predicates = None + if denied_regexs is not None: + deny_predicates = [] + for pattern in denied_regexs: + deny_predicates.append( + lambda x, pattern=pattern: pattern.match(x) is not None + ) + super().__init__( + allow_predicate_list=allow_predicates, + deny_predicate_list=deny_predicates, + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + + +class AnyCompoundACL(SimpleACL): + """An ACL that allows if any of its subacls allow.""" + + def __init__( + self, + *, + subacls: Optional[List[SimpleACL]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + super().__init__( + 
order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + self.subacls = subacls + + @overrides + def check_allowed(self, x: Any) -> bool: + if self.subacls is None: + return False + return any(acl(x) for acl in self.subacls) + + @overrides + def check_denied(self, x: Any) -> bool: + if self.subacls is None: + return False + return any(not acl(x) for acl in self.subacls) + + +class AllCompoundACL(SimpleACL): + """An ACL that allows if all of its subacls allow.""" + + def __init__( + self, + *, + subacls: Optional[List[SimpleACL]] = None, + order_to_check_allow_deny: Order, + default_answer: bool, + ) -> None: + super().__init__( + order_to_check_allow_deny=order_to_check_allow_deny, + default_answer=default_answer, + ) + self.subacls = subacls + + @overrides + def check_allowed(self, x: Any) -> bool: + if self.subacls is None: + return False + return all(acl(x) for acl in self.subacls) + + @overrides + def check_denied(self, x: Any) -> bool: + if self.subacls is None: + return False + return any(not acl(x) for acl in self.subacls) diff --git a/src/pyutils/state_tracker.py b/src/pyutils/state_tracker.py new file mode 100644 index 0000000..f83f254 --- /dev/null +++ b/src/pyutils/state_tracker.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Several helpers to keep track of internal state via periodic +polling. :class:`StateTracker` expects to be invoked periodically to +maintain state whereas the others (:class:`AutomaticStateTracker` and +:class:`WaitableAutomaticStateTracker`) automatically update themselves +and, optionally, expose an event for client code to wait on state +changes. +""" + +import datetime +import logging +import threading +import time +from abc import ABC, abstractmethod +from typing import Dict, Optional + +import pytz + +from pyutils.parallelize.thread_utils import background_thread + +logger = logging.getLogger(__name__) + + +class StateTracker(ABC): + """A base class that maintains and updates a global state via an + update routine. Instances of this class should be periodically + invoked via the heartbeat() method. This method, in turn, invokes + update() with update_ids according to a schedule / periodicity + provided to the c'tor. + """ + + def __init__(self, update_ids_to_update_secs: Dict[str, float]) -> None: + """The update_ids_to_update_secs dict parameter describes one or more + update types (unique update_ids) and the periodicity(ies), in + seconds, at which it/they should be invoked. + + .. note:: + When more than one update is overdue, they will be + invoked in order by their update_ids so care in choosing these + identifiers may be in order. + + Args: + update_ids_to_update_secs: a dict mapping a user-defined + update_id into a period (number of seconds) with which + we would like this update performed. e.g.:: + + update_ids_to_update_secs = { + 'refresh_local_state': 10.0, + 'refresh_remote_state': 60.0, + } + + This would indicate that every 10s we would like to + refresh local state whereas every 60s we'd like to + refresh remote state. 
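+
+        A minimal driving loop (a sketch; MyStateTracker is a
+        hypothetical subclass, and with this base class the caller owns
+        the heartbeat cadence)::
+
+            tracker = MyStateTracker({'refresh_local_state': 10.0})
+            while True:
+                tracker.heartbeat()
+                time.sleep(1.0)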
+ """ + self.update_ids_to_update_secs = update_ids_to_update_secs + self.last_reminder_ts: Dict[str, Optional[datetime.datetime]] = {} + self.now: Optional[datetime.datetime] = None + for x in update_ids_to_update_secs.keys(): + self.last_reminder_ts[x] = None + + @abstractmethod + def update( + self, + update_id: str, + now: datetime.datetime, + last_invocation: Optional[datetime.datetime], + ) -> None: + """Put whatever you want here to perform your state updates. + + Args: + update_id: the string you passed to the c'tor as a key in + the update_ids_to_update_secs dict. :meth:`update` will + only be invoked on the shoulder, at most, every update_secs + seconds. + + now: the approximate current timestamp at invocation time. + + last_invocation: the last time this operation was invoked + (or None on the first invocation). + """ + pass + + def heartbeat(self, *, force_all_updates_to_run: bool = False) -> None: + """Invoke this method to cause the StateTracker instance to identify + and invoke any overdue updates based on the schedule passed to + the c'tor. In the base :class:`StateTracker` class, this method must + be invoked manually by a thread from external code. Other subclasses + are available that create their own updater threads (see below). + + If more than one type of update (update_id) are overdue, + they will be invoked in order based on their update_ids. + + Setting force_all_updates_to_run will invoke all updates + (ordered by update_id) immediately ignoring whether or not + they are due. + """ + + self.now = datetime.datetime.now(tz=pytz.timezone("US/Pacific")) + for update_id in sorted(self.last_reminder_ts.keys()): + if force_all_updates_to_run: + logger.debug('Forcing all updates to run') + self.update(update_id, self.now, self.last_reminder_ts[update_id]) + self.last_reminder_ts[update_id] = self.now + return + + refresh_secs = self.update_ids_to_update_secs[update_id] + last_run = self.last_reminder_ts[update_id] + if last_run is None: # Never run before + logger.debug('id %s has never been run; running it now', update_id) + self.update(update_id, self.now, self.last_reminder_ts[update_id]) + self.last_reminder_ts[update_id] = self.now + else: + delta = self.now - last_run + if delta.total_seconds() >= refresh_secs: # Is overdue? + logger.debug('id %s is overdue; running it now', update_id) + self.update( + update_id, + self.now, + self.last_reminder_ts[update_id], + ) + self.last_reminder_ts[update_id] = self.now + + +class AutomaticStateTracker(StateTracker): + """Just like :class:`StateTracker` but you don't need to pump the + :meth:`heartbeat` method periodically because we create a background + thread that manages periodic calling. You must call :meth:`shutdown`, + though, in order to terminate the update thread. + """ + + @background_thread + def pace_maker(self, should_terminate: threading.Event) -> None: + """Entry point for a background thread to own calling :meth:`heartbeat` + at regular intervals so that the main thread doesn't need to + do so. + """ + while True: + if should_terminate.is_set(): + logger.debug('pace_maker noticed event; shutting down') + return + self.heartbeat() + logger.debug('pace_maker is sleeping for %.1fs', self.sleep_delay) + time.sleep(self.sleep_delay) + + def __init__( + self, + update_ids_to_update_secs: Dict[str, float], + *, + override_sleep_delay: Optional[float] = None, + ) -> None: + """Construct an AutomaticStateTracker. 
+ + Args: + update_ids_to_update_secs: a dict mapping a user-defined + update_id into a period (number of seconds) with which + we would like this update performed. e.g.:: + + update_ids_to_update_secs = { + 'refresh_local_state': 10.0, + 'refresh_remote_state': 60.0, + } + + This would indicate that every 10s we would like to + refresh local state whereas every 60s we'd like to + refresh remote state. + + override_sleep_delay: By default, this class determines + how long the background thread should sleep between + automatic invocations to :meth:`heartbeat` based on the + period of each update type in update_ids_to_update_secs. + If this argument is non-None, it overrides this computation + and uses this period as the sleep in the background thread. + """ + from pyutils import math_utils + + super().__init__(update_ids_to_update_secs) + if override_sleep_delay is not None: + logger.debug('Overriding sleep delay to %.1f', override_sleep_delay) + self.sleep_delay = override_sleep_delay + else: + periods_list = list(update_ids_to_update_secs.values()) + self.sleep_delay = math_utils.gcd_float_sequence(periods_list) + logger.info('Computed sleep_delay=%.1f', self.sleep_delay) + (thread, stop_event) = self.pace_maker() + self.should_terminate = stop_event + self.updater_thread = thread + + def shutdown(self): + """Terminates the background thread and waits for it to tear down. + This may block for as long as self.sleep_delay. + """ + logger.debug('Setting shutdown event and waiting for background thread.') + self.should_terminate.set() + self.updater_thread.join() + logger.debug('Background thread terminated.') + + +class WaitableAutomaticStateTracker(AutomaticStateTracker): + """This is an AutomaticStateTracker that exposes a wait method which + will block the calling thread until the state changes with an + optional timeout. The caller should check the return value of + wait; it will be true if something changed and false if the wait + simply timed out. If the return value is true, the instance + should be reset() before wait is called again. + + Example usage:: + + detector = waitable_presence.WaitableAutomaticStateSubclass() + while True: + changed = detector.wait(timeout=60 * 5) + if changed: + detector.reset() + # Figure out what changed and react + else: + # Just a timeout; no need to reset. Maybe do something + # else before looping up into wait again. + """ + + def __init__( + self, + update_ids_to_update_secs: Dict[str, float], + *, + override_sleep_delay: Optional[float] = None, + ) -> None: + """Construct an WaitableAutomaticStateTracker. + + Args: + update_ids_to_update_secs: a dict mapping a user-defined + update_id into a period (number of seconds) with which + we would like this update performed. e.g.:: + + update_ids_to_update_secs = { + 'refresh_local_state': 10.0, + 'refresh_remote_state': 60.0, + } + + This would indicate that every 10s we would like to + refresh local state whereas every 60s we'd like to + refresh remote state. + + override_sleep_delay: By default, this class determines + how long the background thread should sleep between + automatic invocations to :meth:`heartbeat` based on the + period of each update type in update_ids_to_update_secs. + If this argument is non-None, it overrides this computation + and uses this period as the sleep in the background thread. 
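+
+        A sketch of how a subclass typically signals a change from its
+        update method (detect_change is a hypothetical helper)::
+
+            class MyDetector(WaitableAutomaticStateTracker):
+                def update(self, update_id, now, last_invocation):
+                    if detect_change():
+                        self.something_changed()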
+ """ + self._something_changed = threading.Event() + super().__init__( + update_ids_to_update_secs, override_sleep_delay=override_sleep_delay + ) + + def something_changed(self): + """Indicate that something has changed.""" + self._something_changed.set() + + def did_something_change(self) -> bool: + """Indicate whether some state has changed in the background.""" + return self._something_changed.is_set() + + def reset(self): + """Call to clear the 'something changed' bit. See usage above.""" + self._something_changed.clear() + + def wait(self, *, timeout=None): + """Wait for something to change or a timeout to lapse. + + Args: + timeout: maximum amount of time to wait. If None, wait + forever (until something changes). + """ + return self._something_changed.wait(timeout=timeout) diff --git a/src/pyutils/stopwatch.py b/src/pyutils/stopwatch.py new file mode 100644 index 0000000..81d9dce --- /dev/null +++ b/src/pyutils/stopwatch.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A simple stopwatch decorator / context for timing things. This was factored out +of decorator utils so that bootstrap.py can keep its imports lighter.""" + +import contextlib +import time +from typing import Callable, Literal + + +class Timer(contextlib.AbstractContextManager): + """ + A stopwatch to time how long something takes (walltime). + + e.g. + + with stopwatch.Timer() as t: + do_the_thing() + + walltime = t() + print(f'That took {walltime} seconds.') + """ + + def __init__(self) -> None: + self.start = 0.0 + self.end = 0.0 + + def __enter__(self) -> Callable[[], float]: + """Returns a functor that, when called, returns the walltime of the + operation in seconds. + """ + self.start = time.perf_counter() + self.end = 0.0 + return lambda: self.end - self.start + + def __exit__(self, *args) -> Literal[False]: + self.end = time.perf_counter() + return False diff --git a/src/pyutils/string_utils.py b/src/pyutils/string_utils.py new file mode 100644 index 0000000..575e64e --- /dev/null +++ b/src/pyutils/string_utils.py @@ -0,0 +1,2392 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""The MIT License (MIT) + +Copyright (c) 2016-2020 Davide Zanotti +Modifications Copyright (c) 2021-2022 Scott Gasch + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +This class is based on: https://github.com/daveoncode/python-string-utils. +See NOTICE in the root of this module for a detailed enumeration of what +work is Davide's and what work was added by Scott. 
+""" + +import base64 +import contextlib # type: ignore +import datetime +import io +import json +import logging +import numbers +import random +import re +import string +import unicodedata +import warnings +from itertools import zip_longest +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, +) +from uuid import uuid4 + +from pyutils import list_utils + +logger = logging.getLogger(__name__) + +NUMBER_RE = re.compile(r"^([+\-]?)((\d+)(\.\d+)?([e|E]\d+)?|\.\d+)$") + +HEX_NUMBER_RE = re.compile(r"^([+|-]?)0[x|X]([0-9A-Fa-f]+)$") + +OCT_NUMBER_RE = re.compile(r"^([+|-]?)0[O|o]([0-7]+)$") + +BIN_NUMBER_RE = re.compile(r"^([+|-]?)0[B|b]([0|1]+)$") + +URLS_RAW_STRING = ( + r"([a-z-]+://)" # scheme + r"([a-z_\d-]+:[a-z_\d-]+@)?" # user:password + r"(www\.)?" # www. + r"((?]*/?>)(.*?())?||)", + re.IGNORECASE | re.MULTILINE | re.DOTALL, +) + +HTML_TAG_ONLY_RE = re.compile( + r"(<([a-z]+:)?[a-z]+[^>]*/?>|||)", + re.IGNORECASE | re.MULTILINE | re.DOTALL, +) + +SPACES_RE = re.compile(r"\s") + +NO_LETTERS_OR_NUMBERS_RE = re.compile(r"[^\w\d]+|_+", re.IGNORECASE | re.UNICODE) + +MARGIN_RE = re.compile(r"^[^\S\r\n]+") + +ESCAPE_SEQUENCE_RE = re.compile(r"\[[^A-Za-z]*[A-Za-z]") + +NUM_SUFFIXES = { + "Pb": (1024**5), + "P": (1024**5), + "Tb": (1024**4), + "T": (1024**4), + "Gb": (1024**3), + "G": (1024**3), + "Mb": (1024**2), + "M": (1024**2), + "Kb": (1024**1), + "K": (1024**1), +} + +units = [ + "zero", + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", +] + +tens = [ + "", + "", + "twenty", + "thirty", + "forty", + "fifty", + "sixty", + "seventy", + "eighty", + "ninety", +] + +scales = ["hundred", "thousand", "million", "billion", "trillion"] + +NUM_WORDS = {} +NUM_WORDS["and"] = (1, 0) +for i, word in enumerate(units): + NUM_WORDS[word] = (1, i) +for i, word in enumerate(tens): + NUM_WORDS[word] = (1, i * 10) +for i, word in enumerate(scales): + NUM_WORDS[word] = (10 ** (i * 3 or 2), 0) +NUM_WORDS['score'] = (20, 0) + + +def is_none_or_empty(in_str: Optional[str]) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the input string is either None or an empty string, + False otherwise. + + >>> is_none_or_empty("") + True + >>> is_none_or_empty(None) + True + >>> is_none_or_empty(" \t ") + True + >>> is_none_or_empty('Test') + False + """ + return in_str is None or len(in_str.strip()) == 0 + + +def is_string(obj: Any) -> bool: + """ + Args: + in_str: the object to test + + Returns: + True if the object is a string and False otherwise. + + >>> is_string('test') + True + >>> is_string(123) + False + >>> is_string(100.3) + False + >>> is_string([1, 2, 3]) + False + """ + return isinstance(obj, str) + + +def is_empty_string(in_str: Any) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string is empty and False otherwise. + """ + return is_empty(in_str) + + +def is_empty(in_str: Any) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string is empty and false otherwise. 
+ + >>> is_empty('') + True + >>> is_empty(' \t\t ') + True + >>> is_empty('test') + False + >>> is_empty(100.88) + False + >>> is_empty([1, 2, 3]) + False + """ + return is_string(in_str) and in_str.strip() == "" + + +def is_full_string(in_str: Any) -> bool: + """ + Args: + in_str: the object to test + + Returns: + True if the object is a string and is not empty ('') and + is not only composed of whitespace. + + >>> is_full_string('test!') + True + >>> is_full_string('') + False + >>> is_full_string(' ') + False + >>> is_full_string(100.999) + False + >>> is_full_string({"a": 1, "b": 2}) + False + """ + return is_string(in_str) and in_str.strip() != "" + + +def is_number(in_str: str) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string contains a valid numberic value and + False otherwise. + + >>> is_number(100.5) + Traceback (most recent call last): + ... + ValueError: 100.5 + >>> is_number("100.5") + True + >>> is_number("test") + False + >>> is_number("99") + True + >>> is_number([1, 2, 3]) + Traceback (most recent call last): + ... + ValueError: [1, 2, 3] + """ + if not is_string(in_str): + raise ValueError(in_str) + return NUMBER_RE.match(in_str) is not None + + +def is_integer_number(in_str: str) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string contains a valid (signed or unsigned, + decimal, hex, or octal, regular or scientific) integral + expression and False otherwise. + + >>> is_integer_number('42') + True + >>> is_integer_number('42.0') + False + """ + return ( + (is_number(in_str) and "." not in in_str) + or is_hexidecimal_integer_number(in_str) + or is_octal_integer_number(in_str) + or is_binary_integer_number(in_str) + ) + + +def is_hexidecimal_integer_number(in_str: str) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string is a hex integer number and False otherwise. + + >>> is_hexidecimal_integer_number('0x12345') + True + >>> is_hexidecimal_integer_number('0x1A3E') + True + >>> is_hexidecimal_integer_number('1234') # Needs 0x + False + >>> is_hexidecimal_integer_number('-0xff') + True + >>> is_hexidecimal_integer_number('test') + False + >>> is_hexidecimal_integer_number(12345) # Not a string + Traceback (most recent call last): + ... + ValueError: 12345 + >>> is_hexidecimal_integer_number(101.4) + Traceback (most recent call last): + ... + ValueError: 101.4 + >>> is_hexidecimal_integer_number(0x1A3E) + Traceback (most recent call last): + ... + ValueError: 6718 + """ + if not is_string(in_str): + raise ValueError(in_str) + return HEX_NUMBER_RE.match(in_str) is not None + + +def is_octal_integer_number(in_str: str) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string is a valid octal integral number and False otherwise. + + >>> is_octal_integer_number('0o777') + True + >>> is_octal_integer_number('-0O115') + True + >>> is_octal_integer_number('0xFF') # Not octal, needs 0o + False + >>> is_octal_integer_number('7777') # Needs 0o + False + >>> is_octal_integer_number('test') + False + """ + if not is_string(in_str): + raise ValueError(in_str) + return OCT_NUMBER_RE.match(in_str) is not None + + +def is_binary_integer_number(in_str: str) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string contains a binary integral number and False otherwise. 
+ + >>> is_binary_integer_number('0b10111') + True + >>> is_binary_integer_number('-0b111') + True + >>> is_binary_integer_number('0B10101') + True + >>> is_binary_integer_number('0b10102') + False + >>> is_binary_integer_number('0xFFF') + False + >>> is_binary_integer_number('test') + False + """ + if not is_string(in_str): + raise ValueError(in_str) + return BIN_NUMBER_RE.match(in_str) is not None + + +def to_int(in_str: str) -> int: + """ + Args: + in_str: the string to convert + + Returns: + The integral value of the string or raises on error. + + >>> to_int('1234') + 1234 + >>> to_int('test') + Traceback (most recent call last): + ... + ValueError: invalid literal for int() with base 10: 'test' + """ + if not is_string(in_str): + raise ValueError(in_str) + if is_binary_integer_number(in_str): + return int(in_str, 2) + if is_octal_integer_number(in_str): + return int(in_str, 8) + if is_hexidecimal_integer_number(in_str): + return int(in_str, 16) + return int(in_str) + + +def number_string_to_integer(in_str: str) -> int: + """Convert a string containing a written-out number into an int. + + >>> number_string_to_integer("one hundred fifty two") + 152 + + >>> number_string_to_integer("ten billion two hundred million fifty four thousand three") + 10200054003 + + >>> number_string_to_integer("four-score and 7") + 87 + + >>> number_string_to_integer("fifty xyzzy three") + Traceback (most recent call last): + ... + ValueError: Unknown word: xyzzy + """ + if type(in_str) == int: + return in_str + + current = result = 0 + in_str = in_str.replace('-', ' ') + for word in in_str.split(): + if word not in NUM_WORDS: + if is_integer_number(word): + current += int(word) + continue + else: + raise ValueError("Unknown word: " + word) + scale, increment = NUM_WORDS[word] + current = current * scale + increment + if scale > 100: + result += current + current = 0 + return result + current + + +def is_decimal_number(in_str: str) -> bool: + """ + Args: + in_str: the string to check + + Returns: + True if the given string represents a decimal or False + otherwise. A decimal may be signed or unsigned or use + a "scientific notation". + + .. note:: + We do not consider integers without a decimal point + to be decimals; they return False (see example). + + >>> is_decimal_number('42.0') + True + >>> is_decimal_number('42') + False + """ + return is_number(in_str) and "." in in_str + + +def strip_escape_sequences(in_str: str) -> str: + """ + Args: + in_str: the string to strip of escape sequences. + + Returns: + in_str with escape sequences removed. + + .. note:: + What is considered to be an "escape sequence" is defined + by a regular expression. While this gets common ones, + there may exist valid sequences that it doesn't match. + + >>> strip_escape_sequences('this is a test!') + 'this is a test!' + """ + in_str = ESCAPE_SEQUENCE_RE.sub("", in_str) + return in_str + + +def add_thousands_separator(in_str: str, *, separator_char=',', places=3) -> str: + """ + Args: + in_str: string or number to which to add thousands separator(s) + separator_char: the separator character to add (defaults to comma) + places: add a separator every N places (defaults to three) + + Returns: + A numeric string with thousands separators added appropriately. + + >>> add_thousands_separator('12345678') + '12,345,678' + >>> add_thousands_separator(12345678) + '12,345,678' + >>> add_thousands_separator(12345678.99) + '12,345,678.99' + >>> add_thousands_separator('test') + Traceback (most recent call last): + ... 
+    ValueError: test
+
+    """
+    if isinstance(in_str, numbers.Number):
+        in_str = f'{in_str}'
+    if is_number(in_str):
+        return _add_thousands_separator(
+            in_str, separator_char=separator_char, places=places
+        )
+    raise ValueError(in_str)
+
+
+def _add_thousands_separator(in_str: str, *, separator_char=',', places=3) -> str:
+    decimal_part = ""
+    if '.' in in_str:
+        (in_str, decimal_part) = in_str.split('.')
+    tmp = [iter(in_str[::-1])] * places
+    ret = separator_char.join("".join(x) for x in zip_longest(*tmp, fillvalue=""))[::-1]
+    if len(decimal_part) > 0:
+        ret += '.'
+        ret += decimal_part
+    return ret
+
+
+def is_url(in_str: Any, allowed_schemes: Optional[List[str]] = None) -> bool:
+    """
+    Args:
+        in_str: the string to test
+        allowed_schemes: an optional list of allowed schemes (e.g.
+            ['http', 'https', 'ftp']). If passed, only URLs that
+            begin with one of the schemes passed will be considered
+            to be valid. Otherwise, any scheme:// will be considered
+            valid.
+
+    Returns:
+        True if in_str contains a valid URL and False otherwise.
+
+    >>> is_url('http://www.mysite.com')
+    True
+    >>> is_url('https://mysite.com')
+    True
+    >>> is_url('.mysite.com')
+    False
+    >>> is_url('scheme://username:password@www.domain.com:8042/folder/subfolder/file.extension?param=value&param2=value2#hash')
+    True
+    """
+    if not is_full_string(in_str):
+        return False
+
+    valid = URL_RE.match(in_str) is not None
+
+    if allowed_schemes:
+        return valid and any([in_str.startswith(s) for s in allowed_schemes])
+    return valid
+
+
+def is_email(in_str: Any) -> bool:
+    """
+    Args:
+        in_str: the email address to check
+
+    Returns: True if the in_str contains a valid email (as defined by
+        https://tools.ietf.org/html/rfc3696#section-3) or False
+        otherwise.
+
+    >>> is_email('my.email@the-provider.com')
+    True
+    >>> is_email('@gmail.com')
+    False
+    """
+    if not is_full_string(in_str) or len(in_str) > 320 or in_str.startswith("."):
+        return False
+
+    try:
+        # we expect 2 tokens, one before "@" and one after, otherwise
+        # we have an exception and the email is not valid.
+        head, tail = in_str.split("@")
+
+        # head's size must be <= 64, tail <= 255, head must not start
+        # with a dot or contain multiple consecutive dots.
+        if len(head) > 64 or len(tail) > 255 or head.endswith(".") or (".." in head):
+            return False
+
+        # removes escaped spaces, so that later on the test regex will
+        # accept the string.
+        head = head.replace("\\ ", "")
+        if head.startswith('"') and head.endswith('"'):
+            head = head.replace(" ", "")[1:-1]
+        return EMAIL_RE.match(head + "@" + tail) is not None
+
+    except ValueError:
+        # borderline case in which we have multiple "@" signs but the
+        # head part is correctly escaped.
+        if ESCAPED_AT_SIGN.search(in_str) is not None:
+            # replace "@" with "a" in the head
+            return is_email(ESCAPED_AT_SIGN.sub("a", in_str))
+        return False
+
+
+def suffix_string_to_number(in_str: str) -> Optional[int]:
+    """Takes a string like "33Gb" and converts it into a number (of bytes)
+    like 35433480192.
+
+    Args:
+        in_str: the string with a suffix to be interpreted and removed.
+
+    Returns:
+        An integer number of bytes or None to indicate an error.
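+
+    .. note::
+        Suffixes are interpreted using the binary multipliers defined
+        in NUM_SUFFIXES (K = 1024, M = 1024**2, etc.) and both one
+        letter ("K") and two letter ("Kb") forms are accepted.
+
+    >>> suffix_string_to_number('2K')
+    2048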
+ + >>> suffix_string_to_number('1Mb') + 1048576 + >>> suffix_string_to_number('13.1Gb') + 14066017894 + """ + + def suffix_capitalize(s: str) -> str: + if len(s) == 1: + return s.upper() + elif len(s) == 2: + return f"{s[0].upper()}{s[1].lower()}" + return suffix_capitalize(s[0:1]) + + if is_string(in_str): + if is_integer_number(in_str): + return to_int(in_str) + suffixes = [in_str[-2:], in_str[-1:]] + rest = [in_str[:-2], in_str[:-1]] + for x in range(len(suffixes)): + s = suffixes[x] + s = suffix_capitalize(s) + multiplier = NUM_SUFFIXES.get(s, None) + if multiplier is not None: + r = rest[x] + if is_integer_number(r): + return to_int(r) * multiplier + if is_decimal_number(r): + return int(float(r) * multiplier) + return None + + +def number_to_suffix_string(num: int) -> Optional[str]: + """Take a number (of bytes) and returns a string like "43.8Gb". + + Args: + num: an integer number of bytes + + Returns: + A string with a suffix representing num bytes concisely or + None to indicate an error. + + >>> number_to_suffix_string(14066017894) + '13.1Gb' + >>> number_to_suffix_string(1024 * 1024) + '1.0Mb' + """ + d = 0.0 + suffix = None + for (sfx, size) in NUM_SUFFIXES.items(): + if num >= size: + d = num / size + suffix = sfx + break + if suffix is not None: + return f"{d:.1f}{suffix}" + else: + return f'{num:d}' + + +def is_credit_card(in_str: Any, card_type: str = None) -> bool: + """ + Args: + in_str: a string to check + card_type: if provided, contains the card type to validate + with. Otherwise, all known credit card number types will + be accepted. + + Supported card types are the following: + + * VISA + * MASTERCARD + * AMERICAN_EXPRESS + * DINERS_CLUB + * DISCOVER + * JCB + + Returns: + True if in_str is a valid credit card number. + """ + if not is_full_string(in_str): + return False + + if card_type is not None: + if card_type not in CREDIT_CARDS: + raise KeyError( + f'Invalid card type "{card_type}". Valid types are: {CREDIT_CARDS.keys()}' + ) + return CREDIT_CARDS[card_type].match(in_str) is not None + for c in CREDIT_CARDS: + if CREDIT_CARDS[c].match(in_str) is not None: + return True + return False + + +def is_camel_case(in_str: Any) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the string is formatted as camel case and False otherwise. + A string is considered camel case when: + + * it's composed only by letters ([a-zA-Z]) and optionally numbers ([0-9]) + * it contains both lowercase and uppercase letters + * it does not start with a number + """ + return is_full_string(in_str) and CAMEL_CASE_TEST_RE.match(in_str) is not None + + +def is_snake_case(in_str: Any, *, separator: str = "_") -> bool: + """ + Args: + in_str: the string to test + + Returns: True if the string is snake case and False otherwise. 
A + string is considered snake case when: + + * it's composed only by lowercase/uppercase letters and digits + * it contains at least one underscore (or provided separator) + * it does not start with a number + + >>> is_snake_case('this_is_a_test') + True + >>> is_snake_case('___This_Is_A_Test_1_2_3___') + True + >>> is_snake_case('this-is-a-test') + False + >>> is_snake_case('this-is-a-test', separator='-') + True + """ + if is_full_string(in_str): + re_map = {"_": SNAKE_CASE_TEST_RE, "-": SNAKE_CASE_TEST_DASH_RE} + re_template = r"([a-z]+\d*{sign}[a-z\d{sign}]*|{sign}+[a-z\d]+[a-z\d{sign}]*)" + r = re_map.get( + separator, + re.compile(re_template.format(sign=re.escape(separator)), re.IGNORECASE), + ) + return r.match(in_str) is not None + return False + + +def is_json(in_str: Any) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the in_str contains valid JSON and False otherwise. + + >>> is_json('{"name": "Peter"}') + True + >>> is_json('[1, 2, 3]') + True + >>> is_json('{nope}') + False + """ + if is_full_string(in_str) and JSON_WRAPPER_RE.match(in_str) is not None: + try: + return isinstance(json.loads(in_str), (dict, list)) + except (TypeError, ValueError, OverflowError): + pass + return False + + +def is_uuid(in_str: Any, allow_hex: bool = False) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if the in_str contains a valid UUID and False otherwise. + + >>> is_uuid('6f8aa2f9-686c-4ac3-8766-5712354a04cf') + True + >>> is_uuid('6f8aa2f9686c4ac387665712354a04cf') + False + >>> is_uuid('6f8aa2f9686c4ac387665712354a04cf', allow_hex=True) + True + """ + # string casting is used to allow UUID itself as input data type + s = str(in_str) + if allow_hex: + return UUID_HEX_OK_RE.match(s) is not None + return UUID_RE.match(s) is not None + + +def is_ip_v4(in_str: Any) -> bool: + """ + Args: + in_str: the string to test + + Returns: + True if in_str contains a valid IPv4 address and False otherwise. + + >>> is_ip_v4('255.200.100.75') + True + >>> is_ip_v4('nope') + False + >>> is_ip_v4('255.200.100.999') # 999 out of range + False + """ + if not is_full_string(in_str) or SHALLOW_IP_V4_RE.match(in_str) is None: + return False + + # checks that each entry in the ip is in the valid range (0 to 255) + for token in in_str.split("."): + if not 0 <= int(token) <= 255: + return False + return True + + +def extract_ip_v4(in_str: Any) -> Optional[str]: + """ + Args: + in_str: the string to extract an IPv4 address from. + + Returns: + The first extracted IPv4 address from in_str or None if + none were found or an error occurred. + + >>> extract_ip_v4(' The secret IP address: 127.0.0.1 (use it wisely) ') + '127.0.0.1' + >>> extract_ip_v4('Your mom dresses you funny.') + """ + if not is_full_string(in_str): + return None + m = ANYWHERE_IP_V4_RE.search(in_str) + if m is not None: + return m.group(0) + return None + + +def is_ip_v6(in_str: Any) -> bool: + """ + Args: + in_str: the string to test. + + Returns: + True if in_str contains a valid IPv6 address and False otherwise. + + >>> is_ip_v6('2001:db8:85a3:0000:0000:8a2e:370:7334') + True + >>> is_ip_v6('2001:db8:85a3:0000:0000:8a2e:370:?') # invalid "?" + False + """ + return is_full_string(in_str) and IP_V6_RE.match(in_str) is not None + + +def extract_ip_v6(in_str: Any) -> Optional[str]: + """ + Args: + in_str: the string from which to extract an IPv6 address. + + Returns: + The first IPv6 address found in in_str or None if no address + was found or an error occurred. 
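+
+    .. note::
+        Like :meth:`is_ip_v6`, this only recognizes fully-specified
+        (eight group) IPv6 addresses; abbreviated '::' forms are not
+        matched by the underlying regular expression.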
+
+    >>> extract_ip_v6('IP: 2001:db8:85a3:0000:0000:8a2e:370:7334')
+    '2001:db8:85a3:0000:0000:8a2e:370:7334'
+    >>> extract_ip_v6("(and she's ugly too, btw)")
+    """
+    if not is_full_string(in_str):
+        return None
+    m = ANYWHERE_IP_V6_RE.search(in_str)
+    if m is not None:
+        return m.group(0)
+    return None
+
+
+def is_ip(in_str: Any) -> bool:
+    """
+    Args:
+        in_str: the string to test.
+
+    Returns:
+        True if in_str contains a valid IP address (either IPv4 or
+        IPv6).
+
+    >>> is_ip('255.200.100.75')
+    True
+    >>> is_ip('2001:db8:85a3:0000:0000:8a2e:370:7334')
+    True
+    >>> is_ip('1.2.3')
+    False
+    >>> is_ip('1.2.3.999')
+    False
+    """
+    return is_ip_v6(in_str) or is_ip_v4(in_str)
+
+
+def extract_ip(in_str: Any) -> Optional[str]:
+    """
+    Args:
+        in_str: the string from which to extract an IP address.
+
+    Returns:
+        The first IP address (IPv4 or IPv6) found in in_str or
+        None to indicate none found or an error condition.
+
+    >>> extract_ip('Attacker: 255.200.100.75')
+    '255.200.100.75'
+    >>> extract_ip('Remote host: 2001:db8:85a3:0000:0000:8a2e:370:7334')
+    '2001:db8:85a3:0000:0000:8a2e:370:7334'
+    >>> extract_ip('1.2.3')
+    """
+    ip = extract_ip_v4(in_str)
+    if ip is None:
+        ip = extract_ip_v6(in_str)
+    return ip
+
+
+def is_mac_address(in_str: Any) -> bool:
+    """
+    Args:
+        in_str: the string to test
+
+    Returns:
+        True if in_str is a valid MAC address and False otherwise.
+
+    >>> is_mac_address("34:29:8F:12:0D:2F")
+    True
+    >>> is_mac_address('34:29:8f:12:0d:2f')
+    True
+    >>> is_mac_address('34-29-8F-12-0D-2F')
+    True
+    >>> is_mac_address("test")
+    False
+    """
+    return is_full_string(in_str) and MAC_ADDRESS_RE.match(in_str) is not None
+
+
+def extract_mac_address(in_str: Any, *, separator: str = ":") -> Optional[str]:
+    """
+    Args:
+        in_str: the string from which to extract a MAC address.
+        separator: the separator character to use in the returned address.
+
+    Returns:
+        The first MAC address found in in_str or None to indicate no
+        match or an error.
+
+    >>> extract_mac_address(' MAC Address: 34:29:8F:12:0D:2F')
+    '34:29:8F:12:0D:2F'
+
+    >>> extract_mac_address('? (10.0.0.30) at d8:5d:e2:34:54:86 on em0 expires in 1176 seconds [ethernet]')
+    'd8:5d:e2:34:54:86'
+    """
+    if not is_full_string(in_str):
+        return None
+    in_str = in_str.strip()
+    m = ANYWHERE_MAC_ADDRESS_RE.search(in_str)
+    if m is not None:
+        mac = m.group(0)
+        # str.replace returns a new string; assign the results so that
+        # the requested separator actually takes effect.
+        mac = mac.replace(":", separator)
+        mac = mac.replace("-", separator)
+        return mac
+    return None
+
+
+def is_slug(in_str: Any, separator: str = "-") -> bool:
+    """
+    Args:
+        in_str: string to test
+
+    Returns:
+        True if in_str is a slug string and False otherwise.
+
+    >>> is_slug('my-blog-post-title')
+    True
+    >>> is_slug('My blog post title')
+    False
+    """
+    if not is_full_string(in_str):
+        return False
+    rex = r"^([a-z\d]+" + re.escape(separator) + r"*?)*[a-z\d]$"
+    return re.match(rex, in_str) is not None
+
+
+def contains_html(in_str: str) -> bool:
+    """
+    Args:
+        in_str: the string to check for tags in
+
+    Returns:
+        True if the given string contains HTML/XML tags and False
+        otherwise.
+
+    .. warning::
+        By design, this function matches ANY type of tag, so don't expect
+        to use it as an HTML validator. It's a quick sanity check at
+        best. See something like BeautifulSoup for a more full-featured
+        HTML parser.
+
+    >>> contains_html('my string is <strong>bold</strong>')
+    True
+    >>> contains_html('my string is not bold')
+    False
+
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+    return HTML_RE.search(in_str) is not None
+
+
+def words_count(in_str: str) -> int:
+    """
+    Args:
+        in_str: the string to count words in
+
+    Returns:
+        The number of words contained in the given string.
+
+    .. note::
+
+        This method is "smart" in that it considers only sequences
+        of one or more letters and/or numbers to be "words". Thus a
+        string like this: "! @ # % ... []" will return zero. Moreover
+        it is aware of punctuation, so the count for a string like
+        "one,two,three.stop" will be 4 not 1 (even if there are no spaces
+        in the string).
+
+    >>> words_count('hello world')
+    2
+    >>> words_count('one,two,three.stop')
+    4
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+    return len(WORDS_COUNT_RE.findall(in_str))
+
+
+def word_count(in_str: str) -> int:
+    """
+    Args:
+        in_str: the string to count words in
+
+    Returns:
+        The number of words contained in the given string.
+
+    .. note::
+
+        This method is "smart" in that it considers only sequences
+        of one or more letters and/or numbers to be "words". Thus a
+        string like this: "! @ # % ... []" will return zero. Moreover
+        it is aware of punctuation, so the count for a string like
+        "one,two,three.stop" will be 4 not 1 (even if there are no spaces
+        in the string).
+
+    >>> word_count('hello world')
+    2
+    >>> word_count('one,two,three.stop')
+    4
+    """
+    return words_count(in_str)
+
+
+def generate_uuid(omit_dashes: bool = False) -> str:
+    """
+    Args:
+        omit_dashes: should we omit the dashes in the generated UUID?
+
+    Returns:
+        A generated UUID string (using `uuid.uuid4()`) with or without
+        dashes per the omit_dashes arg.
+
+    generate_uuid() # possible output: '97e3a716-6b33-4ab9-9bb1-8128cb24d76b'
+    generate_uuid(omit_dashes=True) # possible output: '97e3a7166b334ab99bb18128cb24d76b'
+    """
+    uid = uuid4()
+    if omit_dashes:
+        return uid.hex
+    return str(uid)
+
+
+def generate_random_alphanumeric_string(size: int) -> str:
+    """
+    Args:
+        size: number of characters to generate
+
+    Returns:
+        A string of the specified size containing random characters
+        (uppercase/lowercase ascii letters and digits).
+
+    >>> random.seed(22)
+    >>> generate_random_alphanumeric_string(9)
+    '96ipbNClS'
+    """
+    if size < 1:
+        raise ValueError("size must be >= 1")
+    chars = string.ascii_letters + string.digits
+    buffer = [random.choice(chars) for _ in range(size)]
+    return from_char_list(buffer)
+
+
+def reverse(in_str: str) -> str:
+    """
+    Args:
+        in_str: the string to reverse
+
+    Returns:
+        The reversed (character by character) string.
+
+    >>> reverse('test')
+    'tset'
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+    return in_str[::-1]
+
+
+def camel_case_to_snake_case(in_str, *, separator="_"):
+    """
+    Args:
+        in_str: the camel case string to convert
+
+    Returns:
+        A snake case string equivalent to the camel case input or the
+        original string if it is not a valid camel case string or some
+        other error occurs.
+ + >>> camel_case_to_snake_case('MacAddressExtractorFactory') + 'mac_address_extractor_factory' + >>> camel_case_to_snake_case('Luke Skywalker') + 'Luke Skywalker' + """ + if not is_string(in_str): + raise ValueError(in_str) + if not is_camel_case(in_str): + return in_str + return CAMEL_CASE_REPLACE_RE.sub(lambda m: m.group(1) + separator, in_str).lower() + + +def snake_case_to_camel_case( + in_str: str, *, upper_case_first: bool = True, separator: str = "_" +) -> str: + """ + Args: + in_str: the snake case string to convert + + Returns: + A camel case string that is equivalent to the snake case string + provided or the original string back again if it is not valid + snake case or another error occurs. + + >>> snake_case_to_camel_case('this_is_a_test') + 'ThisIsATest' + >>> snake_case_to_camel_case('Han Solo') + 'Han Solo' + """ + if not is_string(in_str): + raise ValueError(in_str) + if not is_snake_case(in_str, separator=separator): + return in_str + tokens = [s.title() for s in in_str.split(separator) if is_full_string(s)] + if not upper_case_first: + tokens[0] = tokens[0].lower() + return from_char_list(tokens) + + +def to_char_list(in_str: str) -> List[str]: + """ + Args: + in_str: the string to split into a char list + + Returns: + A list of strings of length one each. + + >>> to_char_list('test') + ['t', 'e', 's', 't'] + """ + if not is_string(in_str): + return [] + return list(in_str) + + +def from_char_list(in_list: List[str]) -> str: + """ + Args: + in_list: A list of characters to convert into a string. + + Returns: + The string resulting from gluing the characters in in_list + together. + + >>> from_char_list(['t', 'e', 's', 't']) + 'test' + """ + return "".join(in_list) + + +def shuffle(in_str: str) -> Optional[str]: + """ + Args: + in_str: a string to shuffle randomly by character + + Returns: + A new string containing same chars of the given one but in + a randomized order. Note that in rare cases this could result + in the same original string as no check is done. Returns + None to indicate error conditions. + + >>> random.seed(22) + >>> shuffle('awesome') + 'meosaew' + """ + if not is_string(in_str): + return None + chars = to_char_list(in_str) + random.shuffle(chars) + return from_char_list(chars) + + +def scramble(in_str: str) -> Optional[str]: + """ + Args: + in_str: a string to shuffle randomly by character + + Returns: + A new string containing same chars of the given one but in + a randomized order. Note that in rare cases this could result + in the same original string as no check is done. Returns + None to indicate error conditions. + + >>> random.seed(22) + >>> scramble('awesome') + 'meosaew' + """ + return shuffle(in_str) + + +def strip_html(in_str: str, keep_tag_content: bool = False) -> str: + """ + Args: + in_str: the string to strip tags from + keep_tag_content: should we keep the inner contents of tags? + + Returns: + A string with all HTML tags removed (optionally with tag contents + preserved). + + .. note:: + This method uses simple regular expressions to strip tags and is + not a full fledged HTML parser by any means. Consider using + something like BeautifulSoup if your needs are more than this + simple code can fulfill. 
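+
+    .. note::
+        When keep_tag_content is False (the default), the text between
+        matched open and close tags is removed along with the tags
+        themselves, as the doctests below illustrate.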
+
+    >>> strip_html('test: <a href="foo" target="_blank">click here</a>')
+    'test: '
+    >>> strip_html('test: <a href="foo" target="_blank">click here</a>', keep_tag_content=True)
+    'test: click here'
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+    r = HTML_TAG_ONLY_RE if keep_tag_content else HTML_RE
+    return r.sub("", in_str)
+
+
+def asciify(in_str: str) -> str:
+    """
+    Args:
+        in_str: the string to asciify.
+
+    Returns:
+        An output string roughly equivalent to the original string
+        where all of the content is ASCII-only. This is accomplished
+        by translating all non-ascii chars into their closest possible
+        ASCII representation (eg: ó -> o, Ë -> E, ç -> c...).
+
+    .. warning::
+        Some chars may be lost if impossible to translate.
+
+    >>> asciify('èéùúòóäåëýñÅÀÁÇÌÍÑÓË')
+    'eeuuooaaeynAAACIINOE'
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+
+    # "NFKD" is the normalization form able to successfully translate
+    # most non-ascii chars.
+    normalized = unicodedata.normalize("NFKD", in_str)
+
+    # encode string forcing ascii and ignore any errors
+    # (unrepresentable chars will be stripped out)
+    ascii_bytes = normalized.encode("ascii", "ignore")
+
+    # turns encoded bytes into a utf-8 string
+    return ascii_bytes.decode("utf-8")
+
+
+def slugify(in_str: str, *, separator: str = "-") -> str:
+    """
+    Args:
+        in_str: the string to slugify
+        separator: the character to use during slugification (default
+            is a dash)
+
+    Returns:
+        The converted string. The returned string has the following properties:
+
+        * it has no spaces
+        * all letters are in lower case
+        * all punctuation signs and non alphanumeric chars are removed
+        * words are divided using provided separator
+        * all chars are encoded as ascii (by using :meth:`asciify`)
+        * is safe for URL
+
+    >>> slugify('Top 10 Reasons To Love Dogs!!!')
+    'top-10-reasons-to-love-dogs'
+    >>> slugify('Mönstér Mägnët')
+    'monster-magnet'
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+
+    # replace any character that is NOT letter or number with spaces
+    out = NO_LETTERS_OR_NUMBERS_RE.sub(" ", in_str.lower()).strip()
+
+    # replace spaces with join sign
+    out = SPACES_RE.sub(separator, out)
+
+    # normalize joins (remove duplicates)
+    out = re.sub(re.escape(separator) + r"+", separator, out)
+    return asciify(out)
+
+
+def to_bool(in_str: str) -> bool:
+    """
+    Args:
+        in_str: the string to convert to boolean
+
+    Returns:
+        A boolean equivalent of the original string based on its contents.
+        All conversion is case insensitive. A positive boolean (True) is
+        returned if the string value is any of the following:
+
+        * "true"
+        * "t"
+        * "1"
+        * "yes"
+        * "y"
+        * "on"
+
+    Otherwise False is returned.
+
+    >>> to_bool('True')
+    True
+
+    >>> to_bool('1')
+    True
+
+    >>> to_bool('yes')
+    True
+
+    >>> to_bool('no')
+    False
+
+    >>> to_bool('huh?')
+    False
+
+    >>> to_bool('on')
+    True
+    """
+    if not is_string(in_str):
+        raise ValueError(in_str)
+    return in_str.lower() in ("true", "1", "yes", "y", "t", "on")
+
+
+def to_date(in_str: str) -> Optional[datetime.date]:
+    """
+    Args:
+        in_str: the string to convert into a date
+
+    Returns:
+        The datetime.date the string contained or None to indicate
+        an error. This parser is relatively clever; see
+        :class:`datetimez.dateparse_utils` docs for details.
+
+    >>> to_date('9/11/2001')
+    datetime.date(2001, 9, 11)
+    >>> to_date('xyzzy')
+    """
+    import pyutils.datetimez.dateparse_utils as du
+
+    try:
+        d = du.DateParser()  # type: ignore
+        d.parse(in_str)
+        return d.get_date()
+    except du.ParseException:  # type: ignore
+        msg = f'Unable to parse date {in_str}.'
+        logger.warning(msg)
+    return None
+
+
+def extract_date(in_str: Any) -> Optional[datetime.datetime]:
+    """Finds and extracts a date from the string, if possible.
+
+    Args:
+        in_str: the string to extract a date from
+
+    Returns:
+        a datetime if date was found, otherwise None
+
+    >>> extract_date("filename.txt dec 13, 2022")
+    datetime.datetime(2022, 12, 13, 0, 0)
+
+    >>> extract_date("Dear Santa, please get me a pony.")
+
+    """
+    import itertools
+
+    import pyutils.datetimez.dateparse_utils as du
+
+    d = du.DateParser()  # type: ignore
+    chunks = in_str.split()
+    for ngram in itertools.chain(
+        list_utils.ngrams(chunks, 5),
+        list_utils.ngrams(chunks, 4),
+        list_utils.ngrams(chunks, 3),
+        list_utils.ngrams(chunks, 2),
+    ):
+        try:
+            expr = " ".join(ngram)
+            logger.debug(f"Trying {expr}")
+            if d.parse(expr):
+                return d.get_datetime()
+        except du.ParseException:  # type: ignore
+            pass
+    return None
+
+
+def is_valid_date(in_str: str) -> bool:
+    """
+    Args:
+        in_str: the string to check
+
+    Returns:
+        True if the string represents a valid date that we can recognize
+        and False otherwise. This parser is relatively clever; see
+        :class:`datetimez.dateparse_utils` docs for details.
+
+    >>> is_valid_date('1/2/2022')
+    True
+    >>> is_valid_date('christmas')
+    True
+    >>> is_valid_date('next wednesday')
+    True
+    >>> is_valid_date('xyzzy')
+    False
+    """
+    import pyutils.datetimez.dateparse_utils as dp
+
+    try:
+        d = dp.DateParser()  # type: ignore
+        _ = d.parse(in_str)
+        return True
+    except dp.ParseException:  # type: ignore
+        msg = f'Unable to parse date {in_str}.'
+        logger.warning(msg)
+    return False
+
+
+def to_datetime(in_str: str) -> Optional[datetime.datetime]:
+    """
+    Args:
+        in_str: string to parse into a datetime
+
+    Returns:
+        A python datetime parsed from in_str or None to indicate
+        an error. This parser is relatively clever; see
+        :class:`datetimez.dateparse_utils` docs for details.
+
+    >>> to_datetime('7/20/1969 02:56 GMT')
+    datetime.datetime(1969, 7, 20, 2, 56, tzinfo=<StaticTzInfo 'GMT'>)
+    """
+    import pyutils.datetimez.dateparse_utils as dp
+
+    try:
+        d = dp.DateParser()  # type: ignore
+        dt = d.parse(in_str)
+        if isinstance(dt, datetime.datetime):
+            return dt
+    except Exception:
+        msg = f'Unable to parse datetime {in_str}.'
+        logger.warning(msg)
+    return None
+
+
+def valid_datetime(in_str: str) -> bool:
+    """
+    Args:
+        in_str: the string to check
+
+    Returns:
+        True if in_str contains a valid datetime and False otherwise.
+        This parser is relatively clever; see
+        :class:`datetimez.dateparse_utils` docs for details.
+
+    >>> valid_datetime('next wednesday at noon')
+    True
+    >>> valid_datetime('3 weeks ago at midnight')
+    True
+    >>> valid_datetime('next easter at 5:00 am')
+    True
+    >>> valid_datetime('sometime soon')
+    False
+    """
+    _ = to_datetime(in_str)
+    if _ is not None:
+        return True
+    msg = f'Unable to parse datetime {in_str}.'
+    logger.warning(msg)
+    return False
+
+
+def squeeze(in_str: str, character_to_squeeze: str = ' ') -> str:
+    """
+    Args:
+        in_str: the string to squeeze
+        character_to_squeeze: the character to remove runs of
+            more than one in a row (default = space)
+
+    Returns: A "squeezed string" where runs of more than one
+        character_to_squeeze are collapsed into one.
+ + >>> squeeze(' this is a test ') + ' this is a test ' + + >>> squeeze('one|!||!|two|!||!|three', character_to_squeeze='|!|') + 'one|!|two|!|three' + + """ + return re.sub( + r'(' + re.escape(character_to_squeeze) + r')+', + character_to_squeeze, + in_str, + ) + + +def dedent(in_str: str) -> Optional[str]: + """ + Args: + in_str: the string to dedent + + Returns: + A string with tab indentation removed or None on error. + + .. note:: + + Inspired by analogous Scala function. + + >>> dedent('\t\ttest\\n\t\ting') + 'test\\ning' + """ + if not is_string(in_str): + return None + line_separator = '\n' + lines = [MARGIN_RE.sub('', line) for line in in_str.split(line_separator)] + return line_separator.join(lines) + + +def indent(in_str: str, amount: int) -> str: + """ + Args: + in_str: the string to indent + amount: count of spaces to indent each line by + + Returns: + An indented string created by prepending amount spaces. + + >>> indent('This is a test', 4) + ' This is a test' + """ + if not is_string(in_str): + raise ValueError(in_str) + line_separator = '\n' + lines = [" " * amount + line for line in in_str.split(line_separator)] + return line_separator.join(lines) + + +def sprintf(*args, **kwargs) -> str: + """ + Args: + This function uses the same syntax as the builtin print + function. + + Returns: + An interpolated string capturing print output, like man(3) + :code:sprintf. + """ + ret = "" + + sep = kwargs.pop("sep", None) + if sep is not None: + if not isinstance(sep, str): + raise TypeError("sep must be None or a string") + + end = kwargs.pop("end", None) + if end is not None: + if not isinstance(end, str): + raise TypeError("end must be None or a string") + + if kwargs: + raise TypeError("invalid keyword arguments to sprint()") + + if sep is None: + sep = " " + if end is None: + end = "\n" + for i, arg in enumerate(args): + if i: + ret += sep + if isinstance(arg, str): + ret += arg + else: + ret += str(arg) + ret += end + return ret + + +def strip_ansi_sequences(in_str: str) -> str: + """ + Args: + in_str: the string to strip + + Returns: + in_str with recognized ANSI escape sequences removed. + + .. warning:: + This method works by using a regular expression. + It works for all ANSI escape sequences I've tested with but + may miss some; caveat emptor. + + >>> import ansi as a + >>> s = a.fg('blue') + 'blue!' + a.reset() + >>> len(s) # '\x1b[38;5;21mblue!\x1b[m' + 18 + >>> len(strip_ansi_sequences(s)) + 5 + >>> strip_ansi_sequences(s) + 'blue!' + + """ + return re.sub(r'\x1b\[[\d+;]*[a-z]', '', in_str) + + +class SprintfStdout(contextlib.AbstractContextManager): + """ + A context manager that captures outputs to stdout to a buffer + without printing them. + + >>> with SprintfStdout() as buf: + ... print("test") + ... print("1, 2, 3") + ... + >>> print(buf(), end='') + test + 1, 2, 3 + + """ + + def __init__(self) -> None: + self.destination = io.StringIO() + self.recorder: contextlib.redirect_stdout + + def __enter__(self) -> Callable[[], str]: + self.recorder = contextlib.redirect_stdout(self.destination) + self.recorder.__enter__() + return lambda: self.destination.getvalue() + + def __exit__(self, *args) -> Literal[False]: + self.recorder.__exit__(*args) + self.destination.seek(0) + return False + + +def capitalize_first_letter(in_str: str) -> str: + """ + Args: + in_str: the string to capitalize + + Returns: + in_str with the first character capitalized. + + >>> capitalize_first_letter('test') + 'Test' + >>> capitalize_first_letter("ALREADY!") + 'ALREADY!' 
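+    >>> capitalize_first_letter('')  # no first character to grab
+    Traceback (most recent call last):
+    ...
+    IndexError: string index out of range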
+
+    """
+    return in_str[0].upper() + in_str[1:]
+
+
+def it_they(n: int) -> str:
+    """
+    Args:
+        n: how many of them are there?
+
+    Returns:
+        'it' if n is one or 'they' otherwise.
+
+    Suggested usage::
+
+        n = num_files_saved_to_tmp()
+        print(f'Saved file{pluralize(n)} successfully.')
+        print(f'{it_they(n)} {is_are(n)} located in /tmp.')
+
+    >>> it_they(1)
+    'it'
+    >>> it_they(100)
+    'they'
+    """
+    if n == 1:
+        return "it"
+    return "they"
+
+
+def is_are(n: int) -> str:
+    """
+    Args:
+        n: how many of them are there?
+
+    Returns:
+        'is' if n is one or 'are' otherwise.
+
+    Suggested usage::
+
+        n = num_files_saved_to_tmp()
+        print(f'Saved file{pluralize(n)} successfully.')
+        print(f'{it_they(n)} {is_are(n)} located in /tmp.')
+
+    >>> is_are(1)
+    'is'
+    >>> is_are(2)
+    'are'
+
+    """
+    if n == 1:
+        return "is"
+    return "are"
+
+
+def pluralize(n: int) -> str:
+    """
+    Args:
+        n: how many of them are there?
+
+    Returns:
+        's' if n is greater than one, otherwise ''.
+
+    Suggested usage::
+
+        n = num_files_saved_to_tmp()
+        print(f'Saved file{pluralize(n)} successfully.')
+        print(f'{it_they(n)} {is_are(n)} located in /tmp.')
+
+    >>> pluralize(15)
+    's'
+    >>> count = 1
+    >>> print(f'There {is_are(count)} {count} file{pluralize(count)}.')
+    There is 1 file.
+    >>> count = 4
+    >>> print(f'There {is_are(count)} {count} file{pluralize(count)}.')
+    There are 4 files.
+    """
+    if n == 1:
+        return ""
+    return "s"
+
+
+def make_contractions(txt: str) -> str:
+    """This code glues words in txt together to form (English)
+    contractions.
+
+    Args:
+        txt: the input text to be contractionized.
+
+    Returns:
+        Output text identical to the original input except that any
+        recognized contractions are formed.
+
+    .. note::
+        The order in which we create contractions is defined by the
+        implementation and what I thought made more sense when writing
+        this code.
+
+    >>> make_contractions('It is nice today.')
+    "It's nice today."
+
+    >>> make_contractions('I can not even...')
+    "I can't even..."
+
+    >>> make_contractions('She could not see!')
+    "She couldn't see!"
+
+    >>> make_contractions('But she will not go.')
+    "But she won't go."
+
+    >>> make_contractions('Verily, I shall not.')
+    "Verily, I shan't."
+
+    >>> make_contractions('No you cannot.')
+    "No you can't."
+
+    >>> make_contractions('I said you can not go.')
+    "I said you can't go."
+    """
+
+    first_second = [
+        (
+            [
+                'are',
+                'could',
+                'did',
+                'has',
+                'have',
+                'is',
+                'must',
+                'should',
+                'was',
+                'were',
+                'would',
+            ],
+            ['(n)o(t)'],
+        ),
+        (
+            [
+                "I",
+                "you",
+                "he",
+                "she",
+                "it",
+                "we",
+                "they",
+                "how",
+                "why",
+                "when",
+                "where",
+                "who",
+                "there",
+            ],
+            ['woul(d)', 'i(s)', 'a(re)', 'ha(s)', 'ha(ve)', 'ha(d)', 'wi(ll)'],
+        ),
+    ]
+
+    # Special cases: can't, shan't and won't.
+    txt = re.sub(r'\b(can)\s*no(t)\b', r"\1'\2", txt, count=0, flags=re.IGNORECASE)
+    txt = re.sub(
+        r'\b(sha)ll\s*(n)o(t)\b', r"\1\2'\3", txt, count=0, flags=re.IGNORECASE
+    )
+    txt = re.sub(
+        r'\b(w)ill\s*(n)(o)(t)\b',
+        r"\1\3\2'\4",
+        txt,
+        count=0,
+        flags=re.IGNORECASE,
+    )
+
+    for first_list, second_list in first_second:
+        for first in first_list:
+            for second in second_list:
+                # Disallow there're/where're. They're valid English
+                # but sound weird.
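+                # (Aside: the parenthesized letters in each "second"
+                # pattern, e.g. "i(s)" or "woul(d)", become regex
+                # capture groups; the replacement templates below keep
+                # only those letters, splicing them on after an
+                # apostrophe.)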
+                if (first in ('there', 'where')) and second == 'a(re)':
+                    continue
+
+                pattern = fr'\b({first})\s+{second}\b'
+                if second == '(n)o(t)':
+                    replacement = r"\1\2'\3"
+                else:
+                    replacement = r"\1'\2"
+                txt = re.sub(pattern, replacement, txt, count=0, flags=re.IGNORECASE)
+
+    return txt
+
+
+def thify(n: int) -> str:
+    """
+    Args:
+        n: how many of them are there?
+
+    Returns:
+        The proper ordinal suffix for a number.
+
+    Suggested usage::
+
+        attempt_count = 0
+        while True:
+            attempt_count += 1
+            if try_the_thing():
+                break
+            print(f'The {attempt_count}{thify(attempt_count)} attempt failed, trying again.')
+
+    >>> thify(1)
+    'st'
+    >>> thify(33)
+    'rd'
+    >>> thify(16)
+    'th'
+    """
+    digit = str(n)
+    assert is_integer_number(digit)
+    digit = digit[-1:]
+    if digit == "1":
+        return "st"
+    elif digit == "2":
+        return "nd"
+    elif digit == "3":
+        return "rd"
+    else:
+        return "th"
+
+
+def ngrams(txt: str, n: int):
+    """
+    Args:
+        txt: the string to create ngrams using
+        n: how many words per ngram?
+
+    Returns:
+        Generates the ngrams from the input string.
+
+    >>> [x for x in ngrams('This is a test', 2)]
+    ['This is', 'is a', 'a test']
+    """
+    words = txt.split()
+    for ngram in ngrams_presplit(words, n):
+        ret = ''
+        for word in ngram:
+            ret += f'{word} '
+        yield ret.strip()
+
+
+def ngrams_presplit(words: Sequence[str], n: int):
+    """
+    Same as :meth:`ngrams` but with the string pre-split.
+    """
+    return list_utils.ngrams(words, n)
+
+
+def bigrams(txt: str):
+    """Generates the bigrams (n=2) of the given string."""
+    return ngrams(txt, 2)
+
+
+def trigrams(txt: str):
+    """Generates the trigrams (n=3) of the given string."""
+    return ngrams(txt, 3)
+
+
+def shuffle_columns_into_list(
+    input_lines: Sequence[str], column_specs: Iterable[Iterable[int]], delim=''
+) -> Iterable[str]:
+    """Helper to shuffle / parse columnar data and return the results as a
+    list.
+
+    Args:
+        input_lines: A sequence of strings that represents text that
+            has been broken into columns by the caller
+        column_specs: an iterable collection of numeric sequences that
+            indicate one or more column numbers to copy to form the Nth
+            position in the output list. See example below.
+        delim: for column_specs that indicate we should copy more than
+            one column from the input into this position, use delim to
+            separate source data. Defaults to ''.
+
+    Returns:
+        A list of strings created by following the instructions set forth
+        in column_specs.
+
+    >>> cols = '-rwxr-xr-x 1 scott wheel 3.1K Jul 9 11:34 acl_test.py'.split()
+    >>> shuffle_columns_into_list(
+    ...     cols,
+    ...     [ [8], [2, 3], [5, 6, 7] ],
+    ...     delim='!',
+    ... )
+    ['acl_test.py', 'scott!wheel', 'Jul!9!11:34']
+    """
+    out = []
+
+    # Column specs map input lines' columns into outputs.
+    # [col1, col2...]
+    for spec in column_specs:
+        hunk = ''
+        for n in spec:
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out.append(hunk)
+    return out
+
+
+def shuffle_columns_into_dict(
+    input_lines: Sequence[str],
+    column_specs: Iterable[Tuple[str, Iterable[int]]],
+    delim='',
+) -> Dict[str, str]:
+    """Helper to shuffle / parse columnar data and return the results
+    as a dict.
+
+    Args:
+        input_lines: a sequence of strings that represents text that
+            has been broken into columns by the caller
+        column_specs: instructions for what dictionary keys to apply
+            to individual or compound input column data. See example
+            below.
+        delim: when forming compound output data by gluing more than
+            one input column together, use this character to separate
+            the source data. Defaults to ''.
+
+    Returns:
+        A dict formed by applying the column_specs instructions.
+
+    >>> cols = '-rwxr-xr-x 1 scott wheel 3.1K Jul 9 11:34 acl_test.py'.split()
+    >>> shuffle_columns_into_dict(
+    ...     cols,
+    ...     [ ('filename', [8]), ('owner', [2, 3]), ('mtime', [5, 6, 7]) ],
+    ...     delim='!',
+    ... )
+    {'filename': 'acl_test.py', 'owner': 'scott!wheel', 'mtime': 'Jul!9!11:34'}
+    """
+    out = {}
+
+    # Column specs map input lines' columns into outputs.
+    # "key", [col1, col2...]
+    for spec in column_specs:
+        hunk = ''
+        for n in spec[1]:
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out[spec[0]] = hunk
+    return out
+
+
+def interpolate_using_dict(txt: str, values: Dict[str, str]) -> str:
+    """
+    Interpolate a string with data from a dict.
+
+    Args:
+        txt: the mad libs template
+        values: what you and your kids chose for each category.
+
+    >>> interpolate_using_dict('This is a {adjective} {noun}.',
+    ...                        {'adjective': 'good', 'noun': 'example'})
+    'This is a good example.'
+    """
+    return sprintf(txt.format(**values), end='')
+
+
+def to_ascii(txt: str):
+    """
+    Args:
+        txt: the input data to encode
+
+    Returns:
+        txt encoded as an ASCII byte string.
+
+    >>> to_ascii('test')
+    b'test'
+
+    >>> to_ascii(b'1, 2, 3')
+    b'1, 2, 3'
+    """
+    if isinstance(txt, str):
+        return txt.encode('ascii')
+    if isinstance(txt, bytes):
+        return txt
+    raise Exception('to_ascii works with strings and bytes')
+
+
+def to_base64(txt: str, *, encoding='utf-8', errors='surrogatepass') -> bytes:
+    """
+    Args:
+        txt: the input data to encode
+
+    Returns:
+        txt encoded with a 64-character alphabet. Similar to and compatible
+        with uuencode/uudecode.
+
+    >>> to_base64('hello?')
+    b'aGVsbG8/\\n'
+    """
+    return base64.encodebytes(txt.encode(encoding, errors))
+
+
+def is_base64(txt: str) -> bool:
+    """
+    Args:
+        txt: the string to check
+
+    Returns:
+        True if txt is a valid base64 encoded string. This assumes
+        txt was encoded with Python's standard base64 alphabet (which
+        is the same as what uuencode/uudecode uses).
+
+    >>> is_base64('test')  # all letters in the b64 alphabet
+    True
+
+    >>> is_base64('another test, how do you like this one?')
+    False
+
+    >>> is_base64(b'aGVsbG8/\\n')  # Ending newline is ok.
+    True
+
+    """
+    a = string.ascii_uppercase + string.ascii_lowercase + string.digits + '+/'
+    alphabet = set(a.encode('ascii'))
+    for char in to_ascii(txt.strip()):
+        if char not in alphabet:
+            return False
+    return True
+
+
+def from_base64(b64: bytes, encoding='utf-8', errors='surrogatepass') -> str:
+    """
+    Args:
+        b64: bytestring of base64-encoded data to decode / convert.
+
+    Returns:
+        The decoded form of b64 as a normal python string. Similar to
+        and compatible with uuencode / uudecode.
+
+    >>> from_base64(b'aGVsbG8/\\n')
+    'hello?'
+    """
+    return base64.decodebytes(b64).decode(encoding, errors)
+
+
+def chunk(txt: str, chunk_size: int):
+    """
+    Args:
+        txt: a string to be chunked into evenly spaced pieces.
+        chunk_size: the size of each chunk to make
+
+    Returns:
+        The original string chunked into evenly spaced pieces.
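+
+    .. note::
+        This is a generator. If the length of txt is not an even
+        multiple of chunk_size, a warning is emitted and the final
+        chunk will be shorter than chunk_size.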
+
+    >>> ' '.join(chunk('010011011100010110101010101010101001111110101000', 8))
+    '01001101 11000101 10101010 10101010 10011111 10101000'
+    """
+    if len(txt) % chunk_size != 0:
+        msg = f'String to chunk\'s length ({len(txt)}) is not an even multiple of chunk_size ({chunk_size})'
+        logger.warning(msg)
+        warnings.warn(msg, stacklevel=2)
+    for x in range(0, len(txt), chunk_size):
+        yield txt[x : x + chunk_size]
+
+
+def to_bitstring(txt: str, *, delimiter='') -> str:
+    """
+    Args:
+        txt: the string to convert into a bitstring
+        delimiter: character to insert between adjacent bytes. Note that
+            only bitstrings with delimiter='' are interpretable by
+            :meth:`from_bitstring`.
+
+    Returns:
+        txt converted to ascii/binary and then chopped into bytes.
+
+    >>> to_bitstring('hello?')
+    '011010000110010101101100011011000110111100111111'
+
+    >>> to_bitstring('test', delimiter=' ')
+    '01110100 01100101 01110011 01110100'
+
+    >>> to_bitstring(b'test')
+    '01110100011001010111001101110100'
+    """
+    etxt = to_ascii(txt)
+    bits = bin(int.from_bytes(etxt, 'big'))
+    bits = bits[2:]
+    return delimiter.join(chunk(bits.zfill(8 * ((len(bits) + 7) // 8)), 8))
+
+
+def is_bitstring(txt: str) -> bool:
+    """
+    Args:
+        txt: the string to check
+
+    Returns:
+        True if txt is a recognized bitstring and False otherwise.
+        Note that if delimiter is non-empty this code will not
+        recognize the bitstring.
+
+    >>> is_bitstring('011010000110010101101100011011000110111100111111')
+    True
+
+    >>> is_bitstring('1234')
+    False
+    """
+    return is_binary_integer_number(f'0b{txt}')
+
+
+def from_bitstring(bits: str, encoding='utf-8', errors='surrogatepass') -> str:
+    """
+    Args:
+        bits: the bitstring to convert back into a python string
+        encoding: the encoding to use
+
+    Returns:
+        The regular python string represented by bits. Note that this
+        code does not work with to_bitstring when delimiter is non-empty.
+
+    >>> from_bitstring('011010000110010101101100011011000110111100111111')
+    'hello?'
+    """
+    n = int(bits, 2)
+    return n.to_bytes((n.bit_length() + 7) // 8, 'big').decode(encoding, errors) or '\0'
+
+
+def ip_v4_sort_key(txt: str) -> Optional[Tuple[int, ...]]:
+    """
+    Args:
+        txt: an IP address to chunk up for sorting purposes
+
+    Returns:
+        A tuple of IP components arranged such that the sorting of
+        IP addresses using a normal comparator will do something sane
+        and desirable.
+
+    >>> ip_v4_sort_key('10.0.0.18')
+    (10, 0, 0, 18)
+
+    >>> ips = ['10.0.0.10', '100.0.0.1', '1.2.3.4', '10.0.0.9']
+    >>> sorted(ips, key=lambda x: ip_v4_sort_key(x))
+    ['1.2.3.4', '10.0.0.9', '10.0.0.10', '100.0.0.1']
+    """
+    if not is_ip_v4(txt):
+        print(f"not IP: {txt}")
+        return None
+    return tuple(int(x) for x in txt.split('.'))
+
+
+def path_ancestors_before_descendants_sort_key(volume: str) -> Tuple[str, ...]:
+    """
+    Args:
+        volume: the string to chunk up for sorting purposes
+
+    Returns:
+        A tuple of volume's components such that the sorting of
+        volumes using a normal comparator will do something sane
+        and desirable.
+
+    >>> path_ancestors_before_descendants_sort_key('/usr/local/bin')
+    ('usr', 'local', 'bin')
+
+    >>> paths = ['/usr/local', '/usr/local/bin', '/usr']
+    >>> sorted(paths, key=lambda x: path_ancestors_before_descendants_sort_key(x))
+    ['/usr', '/usr/local', '/usr/local/bin']
+    """
+    return tuple(x for x in volume.split('/') if len(x) > 0)
+
+
+def replace_all(in_str: str, replace_set: str, replacement: str) -> str:
+    """
+    Execute several replace operations in a row.
+
+    Args:
+        in_str: the string in which to replace characters
+        replace_set: the set of target characters to replace
+        replacement: the character to replace any member of replace_set
+            with
+
+    Returns:
+        The string with replacements executed.
+
+    >>> s = 'this_is a-test!'
+    >>> replace_all(s, ' _-!', '')
+    'thisisatest'
+    """
+    for char in replace_set:
+        in_str = in_str.replace(char, replacement)
+    return in_str
+
+
+def replace_nth(in_str: str, source: str, target: str, nth: int):
+    """
+    Replaces the nth occurrence of a substring within a string.
+
+    Args:
+        in_str: the string in which to run the replacement
+        source: the substring to replace
+        target: the replacement text
+        nth: which occurrence of source to replace?
+
+    >>> replace_nth('this is a test', ' ', '-', 3)
+    'this is a-test'
+    """
+    where = [m.start() for m in re.finditer(source, in_str)][nth - 1]
+    before = in_str[:where]
+    after = in_str[where:]
+    after = after.replace(source, target, 1)
+    return before + after
+
+
+if __name__ == '__main__':
+    import doctest
+
+    doctest.testmod()
diff --git a/src/pyutils/text_utils.py b/src/pyutils/text_utils.py
new file mode 100644
index 0000000..93355aa
--- /dev/null
+++ b/src/pyutils/text_utils.py
@@ -0,0 +1,707 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""Utilities for dealing with "text"."""
+
+import contextlib
+import enum
+import logging
+import math
+import os
+import re
+import sys
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, Generator, List, Literal, Optional, Tuple
+
+from pyutils import string_utils
+from pyutils.ansi import fg, reset
+
+logger = logging.getLogger(__file__)
+
+
+@dataclass
+class RowsColumns:
+    """Row + Column"""
+
+    rows: int = 0
+    """Number of rows"""
+
+    columns: int = 0
+    """Number of columns"""
+
+
+def get_console_rows_columns() -> RowsColumns:
+    """
+    Returns:
+        The number of rows/columns on the current console. Raises
+        an exception if the console size cannot be determined.
+    """
+    from pyutils.exec_utils import cmd
+
+    rows: Optional[str] = os.environ.get('LINES', None)
+    cols: Optional[str] = os.environ.get('COLUMNS', None)
+    if not rows or not cols:
+        logger.debug('Rows: %s, cols: %s, trying stty.', rows, cols)
+        try:
+            rows, cols = cmd(
+                "stty size",
+                timeout_seconds=1.0,
+            ).split()
+        except Exception:
+            rows = None
+            cols = None
+
+    if rows is None:
+        logger.debug('Rows: %s, cols: %s, tput rows.', rows, cols)
+        try:
+            rows = cmd(
+                "tput rows",
+                timeout_seconds=1.0,
+            )
+        except Exception:
+            rows = None
+
+    if cols is None:
+        logger.debug('Rows: %s, cols: %s, tput cols.', rows, cols)
+        try:
+            cols = cmd(
+                "tput cols",
+                timeout_seconds=1.0,
+            )
+        except Exception:
+            cols = None
+
+    if not rows or not cols:
+        raise Exception('Can\'t determine console size?!')
+    return RowsColumns(int(rows), int(cols))
+
+
+class BarGraphText(enum.Enum):
+    """What kind of text to include at the end of the bar graph?"""
+
+    NONE = (0,)
+    """None, leave it blank."""
+
+    PERCENTAGE = (1,)
+    """XX.X%"""
+
+    FRACTION = (2,)
+    """N / K"""
+
+
+def bar_graph(
+    current: int,
+    total: int,
+    *,
+    width=70,
+    text: BarGraphText = BarGraphText.PERCENTAGE,
+    fgcolor=fg("school bus yellow"),
+    left_end="[",
+    right_end="]",
+    redraw=True,
+) -> None:
+    """Draws a progress graph at the current cursor position.
+
+    Args:
+        current: how many have we done so far?
+        total: how many are there to do total?
+        text: how should we render the text at the end?
+        width: how many columns wide should the progress graph be?
+        fgcolor: what color should the "done" part of the graph be?
+        left_end: the character at the left side of the graph
+        right_end: the character at the right side of the graph
+        redraw: if True, omit a line feed after the carriage return
+            so that subsequent calls to this method redraw the graph
+            iteratively.
+    """
+    ret = "\r" if redraw else "\n"
+    bar = bar_graph_string(
+        current,
+        total,
+        text=text,
+        width=width,
+        fgcolor=fgcolor,
+        left_end=left_end,
+        right_end=right_end,
+    )
+    print(bar, end=ret, flush=True, file=sys.stderr)
+
+
+def _make_bar_graph_text(
+    text: BarGraphText, current: int, total: int, percentage: float
+):
+    if text == BarGraphText.NONE:
+        return ""
+    elif text == BarGraphText.PERCENTAGE:
+        return f'{percentage:.1f}'
+    elif text == BarGraphText.FRACTION:
+        return f'{current} / {total}'
+    raise ValueError(text)
+
+
+def bar_graph_string(
+    current: int,
+    total: int,
+    *,
+    text: BarGraphText = BarGraphText.PERCENTAGE,
+    width=70,
+    fgcolor=fg("school bus yellow"),
+    reset_seq=reset(),
+    left_end="[",
+    right_end="]",
+) -> str:
+    """Returns a string containing a bar graph.
+
+    Args:
+        current: how many have we done so far?
+        total: how many are there to do total?
+        text: how should we render the text at the end?
+        width: how many columns wide should the progress graph be?
+        fgcolor: what color should the "done" part of the graph be?
+        reset_seq: sequence to use to turn off color
+        left_end: the character at the left side of the graph
+        right_end: the character at the right side of the graph
+
+    >>> bar_graph_string(5, 10, fgcolor='', reset_seq='')
+    '[███████████████████████████████████                                   ] 0.5'
+
+    """
+
+    if total != 0:
+        percentage = float(current) / float(total)
+    else:
+        percentage = 0.0
+    if percentage < 0.0 or percentage > 1.0:
+        raise ValueError(percentage)
+    text = _make_bar_graph_text(text, current, total, percentage)
+    whole_width = math.floor(percentage * width)
+    if whole_width == width:
+        whole_width -= 1
+        part_char = "▉"
+    elif whole_width == 0 and percentage > 0.0:
+        part_char = "▏"
+    else:
+        remainder_width = (percentage * width) % 1
+        part_width = math.floor(remainder_width * 8)
+        part_char = [" ", "▏", "▎", "▍", "▌", "▋", "▊", "▉"][part_width]
+    return (
+        left_end
+        + fgcolor
+        + "█" * whole_width
+        + part_char
+        + " " * (width - whole_width - 1)
+        + reset_seq
+        + right_end
+        + " "
+        + text
+    )
+
+
+def sparkline(numbers: List[float]) -> Tuple[float, float, str]:
+    """
+    Makes a "sparkline" little inline histogram graph. Auto scales.
+
+    Args:
+        numbers: the population over which to create the sparkline
+
+    Returns:
+        a three tuple containing:
+
+        * the minimum number in the population
+        * the maximum number in the population
+        * a string representation of the population in a concise format
+
+    >>> sparkline([1, 2, 3, 5, 10, 3, 5, 7])
+    (1, 10, '▁▁▂▄█▂▄▆')
+
+    >>> sparkline([104, 99, 93, 96, 82, 77, 85, 73])
+    (73, 104, '█▇▆▆▃▂▄▁')
+
+    """
+    _bar = '▁▂▃▄▅▆▇█'  # Unicode: 9601, 9602, 9603, 9604, 9605, 9606, 9607, 9608
+
+    barcount = len(_bar)
+    min_num, max_num = min(numbers), max(numbers)
+    span = max_num - min_num
+    sline = ''.join(
+        _bar[min([barcount - 1, int((n - min_num) / span * barcount)])] for n in numbers
+    )
+    return min_num, max_num, sline
+
+
+def distribute_strings(
+    strings: List[str],
+    *,
+    width: int = 80,
+    padding: str = " ",
+) -> str:
+    """
+    Distributes strings into a line for justified text.
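+
+    .. note::
+        ANSI escape sequences are stripped before measuring lengths so
+        that colorized tokens do not throw off the spacing math.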
+ + Args: + strings: a list of string tokens to distribute + width: the width of the line to create + padding: the padding character to place between string chunks + + Returns: + The distributed, justified string. + + >>> distribute_strings(['this', 'is', 'a', 'test'], width=40) + ' this is a test ' + """ + ret = ' ' + ' '.join(strings) + ' ' + assert len(string_utils.strip_ansi_sequences(ret)) < width + x = 0 + while len(string_utils.strip_ansi_sequences(ret)) < width: + spaces = [m.start() for m in re.finditer(r' ([^ ]|$)', ret)] + where = spaces[x] + before = ret[:where] + after = ret[where:] + ret = before + padding + after + x += 1 + if x >= len(spaces): + x = 0 + return ret + + +def _justify_string_by_chunk(string: str, width: int = 80, padding: str = " ") -> str: + """ + Justifies a string chunk by chunk. + + Args: + string: the string to be justified + width: how wide to make the output + padding: what padding character to use between chunks + + Returns: + the justified string + + >>> _justify_string_by_chunk("This is a test", 40) + 'This is a test' + >>> _justify_string_by_chunk("This is a test", 20) + 'This is a test' + + """ + assert len(string_utils.strip_ansi_sequences(string)) <= width + padding = padding[0] + first, *rest, last = string.split() + w = width - ( + len(string_utils.strip_ansi_sequences(first)) + + len(string_utils.strip_ansi_sequences(last)) + ) + ret = first + distribute_strings(rest, width=w, padding=padding) + last + return ret + + +def justify_string( + string: str, *, width: int = 80, alignment: str = "c", padding: str = " " +) -> str: + """Justify a string to width with left, right, center of justified + alignment. + + Args: + string: the string to justify + width: the width to justify the string to + alignment: a single character indicating the desired alignment: + * 'c' = centered within the width + * 'j' = justified at width + * 'l' = left alignment + * 'r' = right alignment + padding: the padding character to use while justifying + + >>> justify_string('This is another test', width=40, alignment='c') + ' This is another test ' + >>> justify_string('This is another test', width=40, alignment='l') + 'This is another test ' + >>> justify_string('This is another test', width=40, alignment='r') + ' This is another test' + >>> justify_string('This is another test', width=40, alignment='j') + 'This is another test' + """ + alignment = alignment[0] + padding = padding[0] + while len(string_utils.strip_ansi_sequences(string)) < width: + if alignment == "l": + string += padding + elif alignment == "r": + string = padding + string + elif alignment == "j": + return _justify_string_by_chunk(string, width=width, padding=padding) + elif alignment == "c": + if len(string) % 2 == 0: + string += padding + else: + string = padding + string + else: + raise ValueError + return string + + +def justify_text( + text: str, *, width: int = 80, alignment: str = "c", indent_by: int = 0 +) -> str: + """Justifies text with left, right, centered or justified alignment + and optionally with initial indentation. + + Args: + text: the text to be justified + width: the width at which to justify text + alignment: a single character indicating the desired alignment: + * 'c' = centered within the width + * 'j' = justified at width + * 'l' = left alignment + * 'r' = right alignment + indent_by: if non-zero, adds n prefix spaces to indent the text. + + Returns: + The justified text. + + >>> justify_text('This is a test of the emergency broadcast system. This is only a test.', + ... 
width=40, alignment='j') #doctest: +NORMALIZE_WHITESPACE
+    'This is a test of the emergency\\nbroadcast system. This is only a test.'
+
+    """
+    retval = ''
+    indent = ''
+    if indent_by > 0:
+        indent += ' ' * indent_by
+    line = indent
+
+    for word in text.split():
+        if (
+            len(string_utils.strip_ansi_sequences(line))
+            + len(string_utils.strip_ansi_sequences(word))
+        ) > width:
+            line = line[1:]
+            line = justify_string(line, width=width, alignment=alignment)
+            retval = retval + '\n' + line
+            line = indent
+        line = line + ' ' + word
+    if len(string_utils.strip_ansi_sequences(line)) > 0:
+        if alignment != 'j':
+            retval += "\n" + justify_string(line[1:], width=width, alignment=alignment)
+        else:
+            retval += "\n" + line[1:]
+    return retval[1:]
+
+
+def generate_padded_columns(text: List[str]) -> Generator:
+    """Given a list of strings, break them into columns using :meth:`split`
+    and then compute the maximum width of each column. Finally,
+    distribute the columnar chunks into the output, padding each to
+    the proper width.
+
+    Args:
+        text: a list of strings to chunk into padded columns
+
+    Returns:
+        padded columns based on text.split()
+
+    >>> for x in generate_padded_columns(
+    ...     [ 'reading writing arithmetic',
+    ...       'mathematics psychology physics',
+    ...       'communications sociology anthropology' ]):
+    ...     print(x.strip())
+    reading        writing    arithmetic
+    mathematics    psychology physics
+    communications sociology  anthropology
+    """
+    max_width: Dict[int, int] = defaultdict(int)
+    for line in text:
+        for pos, word in enumerate(line.split()):
+            max_width[pos] = max(
+                max_width[pos], len(string_utils.strip_ansi_sequences(word))
+            )
+
+    for line in text:
+        out = ""
+        for pos, word in enumerate(line.split()):
+            width = max_width[pos]
+            word = justify_string(word, width=width, alignment='l')
+            out += f'{word} '
+        yield out
+
+
+def wrap_string(text: str, n: int) -> str:
+    """
+    Args:
+        text: the string to be wrapped
+        n: the width after which to wrap text
+
+    Returns:
+        The wrapped form of text
+    """
+    chunks = text.split()
+    out = ''
+    width = 0
+    for chunk in chunks:
+        if width + len(string_utils.strip_ansi_sequences(chunk)) > n:
+            out += '\n'
+            width = 0
+        out += chunk + ' '
+        width += len(string_utils.strip_ansi_sequences(chunk)) + 1
+    return out
+
+
+class Indenter(contextlib.AbstractContextManager):
+    """
+    Context manager that indents stuff (even recursively). e.g.::
+
+        with Indenter(pad_count = 8) as i:
+            i.print('test')
+            with i:
+                i.print('-ing')
+                with i:
+                    i.print('1, 2, 3')
+
+    Yields::
+
+        test
+                -ing
+                        1, 2, 3
+    """
+
+    def __init__(
+        self,
+        *,
+        pad_prefix: Optional[str] = None,
+        pad_char: str = ' ',
+        pad_count: int = 4,
+    ):
+        """Construct an Indenter.
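+
+        The indent level starts at -1 and each nested 'with' block
+        bumps it by one, so output printed inside the outermost block
+        appears flush left (after any pad_prefix).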
+ + Args: + pad_prefix: an optional prefix to prepend to each line + pad_char: the character used to indent + pad_count: the number of pad_chars to use to indent + """ + self.level = -1 + if pad_prefix is not None: + self.pad_prefix = pad_prefix + else: + self.pad_prefix = '' + self.padding = pad_char * pad_count + + def __enter__(self): + self.level += 1 + return self + + def __exit__(self, exc_type, exc_value, exc_tb) -> Literal[False]: + self.level -= 1 + if self.level < -1: + self.level = -1 + return False + + def print(self, *arg, **kwargs): + text = string_utils.sprintf(*arg, **kwargs) + print(self.pad_prefix + self.padding * self.level + text, end='') + + +def header( + title: str, + *, + width: Optional[int] = None, + align: Optional[str] = None, + style: Optional[str] = 'solid', + color: Optional[str] = None, +): + """ + Creates a nice header line with a title. + + Args: + title: the title + width: how wide to make the header + align: "left" or "right" + style: "ascii", "solid" or "dashed" + + Returns: + The header as a string. + + >>> header('title', width=60, style='ascii') + '----[ title ]-----------------------------------------------' + """ + if not width: + try: + width = get_console_rows_columns().columns + except Exception: + width = 80 + if not align: + align = 'left' + if not style: + style = 'ascii' + + text_len = len(string_utils.strip_ansi_sequences(title)) + if align == 'left': + left = 4 + right = width - (left + text_len + 4) + elif align == 'right': + right = 4 + left = width - (right + text_len + 4) + else: + left = int((width - (text_len + 4)) / 2) + right = left + while left + text_len + 4 + right < width: + right += 1 + + if style == 'solid': + line_char = '━' + begin = '' + end = '' + elif style == 'dashed': + line_char = '┅' + begin = '' + end = '' + else: + line_char = '-' + begin = '[' + end = ']' + if color: + col = color + reset_seq = reset() + else: + col = '' + reset_seq = '' + return ( + line_char * left + + begin + + col + + ' ' + + title + + ' ' + + reset_seq + + end + + line_char * right + ) + + +def box( + title: Optional[str] = None, + text: Optional[str] = None, + *, + width: int = 80, + color: str = '', +) -> str: + """ + Make a nice unicode box (optionally with color) around some text. + + Args: + title: the title of the box + text: the text in the box + width: the box's width + color: the box's color + + Returns: + the box as a string + + >>> print(box('title', 'this is some text', width=20).strip()) + ╭──────────────────╮ + │ title │ + │ │ + │ this is some │ + │ text │ + ╰──────────────────╯ + """ + assert width > 4 + if text is not None: + text = justify_text(text, width=width - 4, alignment='l') + return preformatted_box(title, text, width=width, color=color) + + +def preformatted_box( + title: Optional[str] = None, + text: Optional[str] = None, + *, + width=80, + color: str = '', +) -> str: + """Creates a nice box with rounded corners and returns it as a string. 
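+
+    Unlike :meth:`box`, the text is used as-is: every line must already
+    fit inside the box, since no justification or wrapping is applied.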
+ + Args: + title: the title of the box + text: the text inside the box + width: the width of the box + color: the box's color + + Returns: + the box as a string + + >>> print(preformatted_box('title', 'this\\nis\\nsome\\ntext', width=20).strip()) + ╭──────────────────╮ + │ title │ + │ │ + │ this │ + │ is │ + │ some │ + │ text │ + ╰──────────────────╯ + """ + assert width > 4 + ret = '' + if color == '': + rset = '' + else: + rset = reset() + w = width - 2 + ret += color + '╭' + '─' * w + '╮' + rset + '\n' + if title is not None: + ret += ( + color + + '│' + + rset + + justify_string(title, width=w, alignment='c') + + color + + '│' + + rset + + '\n' + ) + ret += color + '│' + ' ' * w + '│' + rset + '\n' + if text is not None: + for line in text.split('\n'): + tw = len(string_utils.strip_ansi_sequences(line)) + assert tw <= w + ret += ( + color + + '│ ' + + rset + + line + + ' ' * (w - tw - 2) + + color + + ' │' + + rset + + '\n' + ) + ret += color + '╰' + '─' * w + '╯' + rset + '\n' + return ret + + +def print_box( + title: Optional[str] = None, + text: Optional[str] = None, + *, + width: int = 80, + color: str = '', +) -> None: + """Draws a box with nice rounded corners. + + >>> print_box('Title', 'This is text', width=30) + ╭────────────────────────────╮ + │ Title │ + │ │ + │ This is text │ + ╰────────────────────────────╯ + + >>> print_box(None, 'OK', width=6) + ╭────╮ + │ OK │ + ╰────╯ + """ + print(preformatted_box(title, text, width=width, color=color), end='') + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/typez/__init__.py b/src/pyutils/typez/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyutils/typez/centcount.py b/src/pyutils/typez/centcount.py new file mode 100644 index 0000000..b37898f --- /dev/null +++ b/src/pyutils/typez/centcount.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""An amount of money (USD) represented as an integral count of +cents.""" + +import re +from typing import Optional, Tuple + +from pyutils import math_utils + + +class CentCount(object): + """A class for representing monetary amounts potentially with + different currencies meant to avoid floating point rounding + issues by treating amount as a simple integral count of cents. 
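+
+    For example (illustrative values):
+
+    >>> CentCount(100) + CentCount('2.50 USD')
+    3.50 USD
+
+    >>> CentCount('1.50 USD') == CentCount(150)
+    True
+
+    >>> CentCount(100, 'USD') + CentCount(100, 'CAD')
+    Traceback (most recent call last):
+    ...
+    TypeError: Incompatible currencies in add expression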
+ """ + + def __init__(self, centcount, currency: str = 'USD', *, strict_mode=False): + self.strict_mode = strict_mode + if isinstance(centcount, str): + ret = CentCount._parse(centcount) + if ret is None: + raise Exception(f'Unable to parse money string "{centcount}"') + centcount = ret[0] + currency = ret[1] + if isinstance(centcount, float): + centcount = int(centcount * 100.0) + if not isinstance(centcount, int): + centcount = int(centcount) + self.centcount = centcount + if not currency: + self.currency: Optional[str] = None + else: + self.currency = currency + + def __repr__(self): + a = float(self.centcount) + a /= 100 + a = round(a, 2) + s = f'{a:,.2f}' + if self.currency is not None: + return f'{s} {self.currency}' + else: + return f'${s}' + + def __pos__(self): + return CentCount(centcount=self.centcount, currency=self.currency) + + def __neg__(self): + return CentCount(centcount=-self.centcount, currency=self.currency) + + def __add__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount=self.centcount + other.centcount, + currency=self.currency, + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return self.__add__(CentCount(other, self.currency)) + + def __sub__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount=self.centcount - other.centcount, + currency=self.currency, + ) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return self.__sub__(CentCount(other, self.currency)) + + def __mul__(self, other): + if isinstance(other, CentCount): + raise TypeError('can not multiply monetary quantities') + else: + return CentCount( + centcount=int(self.centcount * float(other)), + currency=self.currency, + ) + + def __truediv__(self, other): + if isinstance(other, CentCount): + raise TypeError('can not divide monetary quantities') + else: + return CentCount( + centcount=int(float(self.centcount) / float(other)), + currency=self.currency, + ) + + def __int__(self): + return self.centcount.__int__() + + def __float__(self): + return self.centcount.__float__() / 100.0 + + def truncate_fractional_cents(self): + x = int(self) + self.centcount = int(math_utils.truncate_float(x)) + return self.centcount + + def round_fractional_cents(self): + x = int(self) + self.centcount = int(round(x, 2)) + return self.centcount + + __radd__ = __add__ + + def __rsub__(self, other): + if isinstance(other, CentCount): + if self.currency == other.currency: + return CentCount( + centcount=other.centcount - self.centcount, + currency=self.currency, + ) + else: + raise TypeError('Incompatible currencies in sub expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return CentCount( + centcount=int(other) - self.centcount, + currency=self.currency, + ) + + __rmul__ = __mul__ + + # + # Override comparison operators to also compare currency. 
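+    # Note: comparing mismatched currencies raises TypeError rather than
+    # guessing an exchange rate; e.g. CentCount(1, 'USD') < CentCount(1, 'CAD')
+    # is an error.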
+    #
+    def __eq__(self, other):
+        if other is None:
+            return False
+        if isinstance(other, CentCount):
+            return self.centcount == other.centcount and self.currency == other.currency
+        if self.strict_mode:
+            raise TypeError("In strict mode only two CentCounts can be compared")
+        else:
+            return self.centcount == int(other)
+
+    def __ne__(self, other):
+        result = self.__eq__(other)
+        if result is NotImplemented:
+            return result
+        return not result
+
+    def __lt__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return self.centcount < other.centcount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two CentCounts can be compared')
+            else:
+                return self.centcount < int(other)
+
+    def __gt__(self, other):
+        if isinstance(other, CentCount):
+            if self.currency == other.currency:
+                return self.centcount > other.centcount
+            else:
+                raise TypeError('can not directly compare different currencies')
+        else:
+            if self.strict_mode:
+                raise TypeError('In strict mode, only two CentCounts can be compared')
+            else:
+                return self.centcount > int(other)
+
+    def __le__(self, other):
+        return self < other or self == other
+
+    def __ge__(self, other):
+        return self > other or self == other
+
+    def __hash__(self) -> int:
+        return hash(self.__repr__())
+
+    CENTCOUNT_RE = re.compile(r"^([+|-]?)(\d+)(\.\d+)$")
+    CURRENCY_RE = re.compile(r"^[A-Z][A-Z][A-Z]$")
+
+    @classmethod
+    def _parse(cls, s: str) -> Optional[Tuple[int, str]]:
+        centcount = None
+        currency = None
+        s = s.strip()
+        chunks = s.split(' ')
+        try:
+            for chunk in chunks:
+                if CentCount.CENTCOUNT_RE.match(chunk) is not None:
+                    centcount = int(float(chunk) * 100.0)
+                elif CentCount.CURRENCY_RE.match(chunk) is not None:
+                    currency = chunk
+        except Exception:
+            pass
+        if centcount is not None and currency is not None:
+            return (centcount, currency)
+        elif centcount is not None:
+            return (centcount, 'USD')
+        return None
+
+    @classmethod
+    def parse(cls, s: str) -> 'CentCount':
+        chunks = CentCount._parse(s)
+        if chunks is not None:
+            return CentCount(chunks[0], chunks[1])
+        raise Exception(f'Unable to parse money string "{s}"')
diff --git a/src/pyutils/typez/histogram.py b/src/pyutils/typez/histogram.py
new file mode 100644
index 0000000..2887525
--- /dev/null
+++ b/src/pyutils/typez/histogram.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""A text-based simple histogram helper class."""
+
+import math
+from dataclasses import dataclass
+from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar
+
+T = TypeVar("T", int, float)
+Bound = int
+Count = int
+
+
+@dataclass
+class BucketDetails:
+    """A collection of details about the internal histogram buckets."""
+
+    num_populated_buckets: int = 0
+    """Count of populated buckets"""
+
+    max_population: Optional[int] = None
+    """The max population in a bucket currently"""
+
+    last_bucket_start: Optional[int] = None
+    """The last bucket starting point"""
+
+    lowest_start: Optional[int] = None
+    """The lowest populated bucket's starting point"""
+
+    highest_end: Optional[int] = None
+    """The highest populated bucket's ending point"""
+
+    max_label_width: Optional[int] = None
+    """The maximum label width (for display purposes)"""
+
+
+class SimpleHistogram(Generic[T]):
+    """A simple histogram."""
+
+    # Useful in defining wide open bottom/top bucket bounds:
+    POSITIVE_INFINITY = math.inf
+    NEGATIVE_INFINITY = -math.inf
+
+    def __init__(self, buckets: List[Tuple[Bound, Bound]]):
+        """C'tor.
+
+        Args:
+            buckets: a list of [start..end] tuples that define the
+                buckets we are counting population in.  See also
+                :meth:`n_evenly_spaced_buckets` to generate these
+                buckets more easily.
+        """
+        from pyutils.math_utils import NumericPopulation
+
+        self.buckets: Dict[Tuple[Bound, Bound], Count] = {}
+        for start_end in buckets:
+            if self._get_bucket(start_end[0]) is not None:
+                raise Exception("Buckets overlap?!")
+            self.buckets[start_end] = 0
+        self.sigma: float = 0.0
+        self.stats: NumericPopulation = NumericPopulation()
+        self.maximum: Optional[T] = None
+        self.minimum: Optional[T] = None
+        self.count: Count = 0
+
+    @staticmethod
+    def n_evenly_spaced_buckets(
+        min_bound: T,
+        max_bound: T,
+        n: int,
+    ) -> List[Tuple[int, int]]:
+        """A helper method for generating the buckets argument to
+        our c'tor provided that you want N evenly spaced buckets.
+
+        Args:
+            min_bound: the minimum possible value
+            max_bound: the maximum possible value
+            n: how many buckets to create
+
+        Returns:
+            A list of bounds that define N evenly spaced buckets
+        """
+        ret: List[Tuple[int, int]] = []
+        stride = int((max_bound - min_bound) / n)
+        if stride <= 0:
+            raise Exception("Min must be < Max")
+        imax = math.ceil(max_bound)
+        imin = math.floor(min_bound)
+        for bucket_start in range(imin, imax, stride):
+            ret.append((bucket_start, bucket_start + stride))
+        return ret
+
+    def _get_bucket(self, item: T) -> Optional[Tuple[int, int]]:
+        """Given an item, what bucket is it in?"""
+        for start_end in self.buckets:
+            if start_end[0] <= item < start_end[1]:
+                return start_end
+        return None
+
+    def add_item(self, item: T) -> bool:
+        """Adds a single item to the histogram (resulting in the population
+        of the correct bucket being incremented).
+
+        Args:
+            item: the item to be added
+
+        Returns:
+            True if the item was successfully added or False if the item
+            is not within the bounds established during class construction.
+        """
+        bucket = self._get_bucket(item)
+        if bucket is None:
+            return False
+        self.count += 1
+        self.buckets[bucket] += 1
+        self.sigma += item
+        self.stats.add_number(item)
+        if self.maximum is None or item > self.maximum:
+            self.maximum = item
+        if self.minimum is None or item < self.minimum:
+            self.minimum = item
+        return True
+
+    def add_items(self, lst: Iterable[T]) -> bool:
+        """Adds a collection of items to the histogram and increments
+        the correct bucket's population for each item.
+
+        Args:
+            lst: An iterable of items to be added
+
+        Returns:
+            True if all items were added successfully or False if any
+            item was not able to be added because it was not within the
+            bounds established at object construction.
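+
+        For example (the bucket bounds here are arbitrary):
+
+        >>> h = SimpleHistogram(SimpleHistogram.n_evenly_spaced_buckets(0, 10, 5))
+        >>> h.add_items([1, 2, 2, 9])
+        True
+        >>> h.add_item(10)
+        False
+        >>> h.count
+        4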
+ """ + all_true = True + for item in lst: + all_true = all_true and self.add_item(item) + return all_true + + def _get_bucket_details(self, label_formatter: str) -> BucketDetails: + """Get the details about one bucket.""" + details = BucketDetails() + for (start, end), pop in sorted(self.buckets.items(), key=lambda x: x[0]): + if pop > 0: + details.num_populated_buckets += 1 + details.last_bucket_start = start + if details.max_population is None or pop > details.max_population: + details.max_population = pop + if details.lowest_start is None or start < details.lowest_start: + details.lowest_start = start + if details.highest_end is None or end > details.highest_end: + details.highest_end = end + label = f'[{label_formatter}..{label_formatter}): ' % (start, end) + label_width = len(label) + if ( + details.max_label_width is None + or label_width > details.max_label_width + ): + details.max_label_width = label_width + return details + + def __repr__(self, *, width: int = 80, label_formatter: str = '%d') -> str: + """Returns a pretty (text) representation of the histogram and + some vital stats about the population in it (min, max, mean, + median, mode, stdev, etc...) + """ + from pyutils.text_utils import BarGraphText, bar_graph_string + + details = self._get_bucket_details(label_formatter) + txt = "" + if details.num_populated_buckets == 0: + return txt + assert details.max_label_width is not None + assert details.lowest_start is not None + assert details.highest_end is not None + assert details.max_population is not None + sigma_label = f'[{label_formatter}..{label_formatter}): ' % ( + details.lowest_start, + details.highest_end, + ) + if len(sigma_label) > details.max_label_width: + details.max_label_width = len(sigma_label) + bar_width = width - (details.max_label_width + 17) + + for (start, end), pop in sorted(self.buckets.items(), key=lambda x: x[0]): + if start < details.lowest_start: + continue + label = f'[{label_formatter}..{label_formatter}): ' % (start, end) + bar = bar_graph_string( + pop, + details.max_population, + text=BarGraphText.NONE, + width=bar_width, + left_end="", + right_end="", + ) + txt += label.rjust(details.max_label_width) + txt += bar + txt += f"({pop/self.count*100.0:5.2f}% n={pop})\n" + if start == details.last_bucket_start: + break + txt += '-' * width + '\n' + txt += sigma_label.rjust(details.max_label_width) + txt += ' ' * (bar_width - 2) + txt += f' pop(Σn)={self.count}\n' + txt += ' ' * (bar_width + details.max_label_width - 2) + txt += f' mean(x̄)={self.stats.get_mean():.3f}\n' + txt += ' ' * (bar_width + details.max_label_width - 2) + txt += f' median(p50)={self.stats.get_median():.3f}\n' + txt += ' ' * (bar_width + details.max_label_width - 2) + txt += f' mode(Mo)={self.stats.get_mode()[0]:.3f}\n' + txt += ' ' * (bar_width + details.max_label_width - 2) + txt += f' stdev(σ)={self.stats.get_stdev():.3f}\n' + txt += '\n' + return txt diff --git a/src/pyutils/typez/money.py b/src/pyutils/typez/money.py new file mode 100644 index 0000000..47c0a8e --- /dev/null +++ b/src/pyutils/typez/money.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A class to represent money. See also centcount.py""" + +import re +from decimal import Decimal +from typing import Optional, Tuple + +from pyutils import math_utils + + +class Money(object): + """A class for representing monetary amounts potentially with + different currencies. 
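+
+    For example (illustrative values):
+
+    >>> Money(10) + Money('4.50 USD')
+    14.50 USD
+
+    >>> Money('5.00 USD') == Money(5)
+    True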
+ """ + + def __init__( + self, + amount: Decimal = Decimal("0"), + currency: str = 'USD', + *, + strict_mode=False, + ): + self.strict_mode = strict_mode + if isinstance(amount, str): + ret = Money._parse(amount) + if ret is None: + raise Exception(f'Unable to parse money string "{amount}"') + amount = ret[0] + currency = ret[1] + if not isinstance(amount, Decimal): + amount = Decimal(float(amount)) + self.amount = amount + if not currency: + self.currency: Optional[str] = None + else: + self.currency = currency + + def __repr__(self): + a = float(self.amount) + a = round(a, 2) + s = f'{a:,.2f}' + if self.currency is not None: + return f'{s} {self.currency}' + else: + return f'${s}' + + def __pos__(self): + return Money(amount=self.amount, currency=self.currency) + + def __neg__(self): + return Money(amount=-self.amount, currency=self.currency) + + def __add__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money(amount=self.amount + other.amount, currency=self.currency) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount=self.amount + Decimal(float(other)), + currency=self.currency, + ) + + def __sub__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money(amount=self.amount - other.amount, currency=self.currency) + else: + raise TypeError('Incompatible currencies in add expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount=self.amount - Decimal(float(other)), + currency=self.currency, + ) + + def __mul__(self, other): + if isinstance(other, Money): + raise TypeError('can not multiply monetary quantities') + else: + return Money( + amount=self.amount * Decimal(float(other)), + currency=self.currency, + ) + + def __truediv__(self, other): + if isinstance(other, Money): + raise TypeError('can not divide monetary quantities') + else: + return Money( + amount=self.amount / Decimal(float(other)), + currency=self.currency, + ) + + def __float__(self): + return self.amount.__float__() + + def truncate_fractional_cents(self): + x = float(self) + self.amount = Decimal(math_utils.truncate_float(x)) + return self.amount + + def round_fractional_cents(self): + x = float(self) + self.amount = Decimal(round(x, 2)) + return self.amount + + __radd__ = __add__ + + def __rsub__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return Money(amount=other.amount - self.amount, currency=self.currency) + else: + raise TypeError('Incompatible currencies in sub expression') + else: + if self.strict_mode: + raise TypeError('In strict_mode only two moneys can be added') + else: + return Money( + amount=Decimal(float(other)) - self.amount, + currency=self.currency, + ) + + __rmul__ = __mul__ + + # + # Override comparison operators to also compare currency. 
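+    # Note: in the default (non-strict) mode the right-hand side of a
+    # comparison is coerced via Decimal(float(other)), so e.g.
+    # Money('1.50 USD') == 1.50 holds.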
+ # + def __eq__(self, other): + if other is None: + return False + if isinstance(other, Money): + return self.amount == other.amount and self.currency == other.currency + if self.strict_mode: + raise TypeError("In strict mode only two Moneys can be compared") + else: + return self.amount == Decimal(float(other)) + + def __ne__(self, other): + result = self.__eq__(other) + if result is NotImplemented: + return result + return not result + + def __lt__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return self.amount < other.amount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two Moneys can be compated') + else: + return self.amount < Decimal(float(other)) + + def __gt__(self, other): + if isinstance(other, Money): + if self.currency == other.currency: + return self.amount > other.amount + else: + raise TypeError('can not directly compare different currencies') + else: + if self.strict_mode: + raise TypeError('In strict mode, only two Moneys can be compated') + else: + return self.amount > Decimal(float(other)) + + def __le__(self, other): + return self < other or self == other + + def __ge__(self, other): + return self > other or self == other + + def __hash__(self) -> int: + return hash(self.__repr__) + + AMOUNT_RE = re.compile(r"^([+|-]?)(\d+)(\.\d+)$") + CURRENCY_RE = re.compile(r"^[A-Z][A-Z][A-Z]$") + + @classmethod + def _parse(cls, s: str) -> Optional[Tuple[Decimal, str]]: + amount = None + currency = None + s = s.strip() + chunks = s.split(' ') + try: + for chunk in chunks: + if Money.AMOUNT_RE.match(chunk) is not None: + amount = Decimal(chunk) + elif Money.CURRENCY_RE.match(chunk) is not None: + currency = chunk + except Exception: + pass + if amount is not None and currency is not None: + return (amount, currency) + elif amount is not None: + return (amount, 'USD') + return None + + @classmethod + def parse(cls, s: str) -> 'Money': + chunks = Money._parse(s) + if chunks is not None: + return Money(chunks[0], chunks[1]) + raise Exception(f'Unable to parse money string "{s}"') diff --git a/src/pyutils/typez/rate.py b/src/pyutils/typez/rate.py new file mode 100644 index 0000000..5fc3f65 --- /dev/null +++ b/src/pyutils/typez/rate.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A class to represent a rate of change.""" + +from typing import Optional + + +class Rate(object): + """A class to represent a rate of change.""" + + def __init__( + self, + multiplier: Optional[float] = None, + *, + percentage: Optional[float] = None, + percent_change: Optional[float] = None, + ): + count = 0 + if multiplier is not None: + if isinstance(multiplier, str): + multiplier = multiplier.replace('%', '') + m = float(multiplier) + m /= 100 + self.multiplier: float = m + else: + self.multiplier = multiplier + count += 1 + if percentage is not None: + self.multiplier = percentage / 100 + count += 1 + if percent_change is not None: + self.multiplier = 1.0 + percent_change / 100 + count += 1 + if count != 1: + raise Exception( + 'Exactly one of percentage, percent_change or multiplier is required.' 
+ ) + + def apply_to(self, other): + return self.__mul__(other) + + def of(self, other): + return self.__mul__(other) + + def __float__(self): + return self.multiplier + + def __mul__(self, other): + return self.multiplier * float(other) + + __rmul__ = __mul__ + + def __truediv__(self, other): + return self.multiplier / float(other) + + def __add__(self, other): + return self.multiplier + float(other) + + __radd__ = __add__ + + def __sub__(self, other): + return self.multiplier - float(other) + + def __eq__(self, other): + return self.multiplier == float(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __lt__(self, other): + return self.multiplier < float(other) + + def __gt__(self, other): + return self.multiplier > float(other) + + def __le__(self, other): + return self < other or self == other + + def __ge__(self, other): + return self > other or self == other + + def __hash__(self): + return self.multiplier + + def __repr__(self, *, relative=False, places=3): + if relative: + percentage = (self.multiplier - 1.0) * 100.0 + else: + percentage = self.multiplier * 100.0 + return f'{percentage:+.{places}f}%' diff --git a/src/pyutils/typez/type_utils.py b/src/pyutils/typez/type_utils.py new file mode 100644 index 0000000..e760dba --- /dev/null +++ b/src/pyutils/typez/type_utils.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Utility functions for dealing with typing.""" + +import logging +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +def unwrap_optional(x: Optional[Any]) -> Any: + """Unwrap an Optional[Type] argument returning a Type value back. + Use this to satisfy most type checkers that a value that could be + None isn't so as to drop the Optional typing hint. + + Args: + x: an Optional[Type] argument + + Returns: + If the Optional[Type] argument is non-None, return it. + If the Optional[Type] argument is None, however, raise an + exception. + + >>> x: Optional[bool] = True + >>> unwrap_optional(x) + True + + >>> y: Optional[str] = None + >>> unwrap_optional(y) + Traceback (most recent call last): + ... + AssertionError: Argument to unwrap_optional was unexpectedly None + """ + if x is None: + msg = 'Argument to unwrap_optional was unexpectedly None' + logger.critical(msg) + raise AssertionError(msg) + return x + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/unittest_utils.py b/src/pyutils/unittest_utils.py new file mode 100644 index 0000000..20b87cd --- /dev/null +++ b/src/pyutils/unittest_utils.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""Helpers for unittests. + +.. note:: + + When you import this we automatically wrap unittest.main() + with a call to bootstrap.initialize so that we getLogger + config, commandline args, logging control, etc... this works + fine but it's a little hacky so caveat emptor. 
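+
+A minimal sketch of what that buys you (the test class is hypothetical):
+
+    import unittest
+
+    from pyutils import unittest_utils  # side effect: wraps unittest.main
+
+    class SmokeTest(unittest.TestCase):
+        def test_truth(self):
+            self.assertTrue(True)
+
+    if __name__ == '__main__':
+        unittest.main()  # also initializes config/logging via bootstrap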
+ +""" + +import contextlib +import functools +import inspect +import logging +import os +import pickle +import random +import statistics +import tempfile +import time +import unittest +import warnings +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Literal, Optional + +from pyutils import bootstrap, config, function_utils + +logger = logging.getLogger(__name__) +cfg = config.add_commandline_args( + f'Logging ({__file__})', 'Args related to function decorators' +) +cfg.add_argument( + '--unittests_ignore_perf', + action='store_true', + default=False, + help='Ignore unittest perf regression in @check_method_for_perf_regressions', +) +cfg.add_argument( + '--unittests_num_perf_samples', + type=int, + default=50, + help='The count of perf timing samples we need to see before blocking slow runs on perf grounds', +) +cfg.add_argument( + '--unittests_drop_perf_traces', + type=str, + nargs=1, + default=None, + help='The identifier (i.e. file!test_fixture) for which we should drop all perf data', +) +cfg.add_argument( + '--unittests_persistance_strategy', + choices=['FILE', 'DATABASE'], + default='FILE', + help='Should we persist perf data in a file or db?', +) +cfg.add_argument( + '--unittests_perfdb_filename', + type=str, + metavar='FILENAME', + default=f'{os.environ["HOME"]}/.python_unittest_performance_db', + help='File in which to store perf data (iff --unittests_persistance_strategy is FILE)', +) +cfg.add_argument( + '--unittests_perfdb_spec', + type=str, + metavar='DBSPEC', + default='mariadb+pymysql://python_unittest:@db.house:3306/python_unittest_performance', + help='Db connection spec for perf data (iff --unittest_persistance_strategy is DATABASE)', +) + +# >>> This is the hacky business, FYI. <<< +unittest.main = bootstrap.initialize(unittest.main) + + +class PerfRegressionDataPersister(ABC): + """A base class for a signature dealing with persisting perf + regression data.""" + + def __init__(self): + pass + + @abstractmethod + def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: + pass + + @abstractmethod + def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): + pass + + @abstractmethod + def delete_performance_data(self, method_id: str): + pass + + +class FileBasedPerfRegressionDataPersister(PerfRegressionDataPersister): + """A perf regression data persister that uses files.""" + + def __init__(self, filename: str): + super().__init__() + self.filename = filename + self.traces_to_delete: List[str] = [] + + def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: + with open(self.filename, 'rb') as f: + return pickle.load(f) + + def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): + for trace in self.traces_to_delete: + if trace in data: + data[trace] = [] + + with open(self.filename, 'wb') as f: + pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) + + def delete_performance_data(self, method_id: str): + self.traces_to_delete.append(method_id) + + +# class DatabasePerfRegressionDataPersister(PerfRegressionDataPersister): +# """A perf regression data persister that uses a database backend.""" +# +# def __init__(self, dbspec: str): +# super().__init__() +# self.dbspec = dbspec +# self.engine = sa.create_engine(self.dbspec) +# self.conn = self.engine.connect() +# +# def load_performance_data(self, method_id: str) -> Dict[str, List[float]]: +# results = self.conn.execute( +# sa.text(f'SELECT * FROM runtimes_by_function WHERE function = "{method_id}";') +# ) +# ret: 
Dict[str, List[float]] = {method_id: []} +# for result in results.all(): +# ret[method_id].append(result['runtime']) +# results.close() +# return ret +# +# def save_performance_data(self, method_id: str, data: Dict[str, List[float]]): +# self.delete_performance_data(method_id) +# for (mid, perf_data) in data.items(): +# sql = 'INSERT INTO runtimes_by_function (function, runtime) VALUES ' +# for perf in perf_data: +# self.conn.execute(sql + f'("{mid}", {perf});') +# +# def delete_performance_data(self, method_id: str): +# sql = f'DELETE FROM runtimes_by_function WHERE function = "{method_id}"' +# self.conn.execute(sql) + + +def check_method_for_perf_regressions(func: Callable) -> Callable: + """ + This is meant to be used on a method in a class that subclasses + unittest.TestCase. When thus decorated it will time the execution + of the code in the method, compare it with a database of + historical perfmance, and fail the test with a perf-related + message if it has become too slow. + + """ + + @functools.wraps(func) + def wrapper_perf_monitor(*args, **kwargs): + if config.config['unittests_ignore_perf']: + return func(*args, **kwargs) + + if config.config['unittests_persistance_strategy'] == 'FILE': + filename = config.config['unittests_perfdb_filename'] + helper = FileBasedPerfRegressionDataPersister(filename) + elif config.config['unittests_persistance_strategy'] == 'DATABASE': + raise NotImplementedError( + 'Persisting to a database is not implemented in this version' + ) + else: + raise Exception('Unknown/unexpected --unittests_persistance_strategy value') + + func_id = function_utils.function_identifier(func) + func_name = func.__name__ + logger.debug('Watching %s\'s performance...', func_name) + logger.debug('Canonical function identifier = "%s"', func_id) + + try: + perfdb = helper.load_performance_data(func_id) + except Exception as e: + logger.exception(e) + msg = 'Unable to load perfdb; skipping it...' + logger.warning(msg) + warnings.warn(msg) + perfdb = {} + + # cmdline arg to forget perf traces for function + drop_id = config.config['unittests_drop_perf_traces'] + if drop_id is not None: + helper.delete_performance_data(drop_id) + + # Run the wrapped test paying attention to latency. + start_time = time.perf_counter() + value = func(*args, **kwargs) + end_time = time.perf_counter() + run_time = end_time - start_time + + # See if it was unexpectedly slow. + hist = perfdb.get(func_id, []) + if len(hist) < config.config['unittests_num_perf_samples']: + hist.append(run_time) + logger.debug('Still establishing a perf baseline for %s', func_name) + else: + stdev = statistics.stdev(hist) + logger.debug('For %s, performance stdev=%.2f', func_name, stdev) + slowest = hist[-1] + logger.debug('For %s, slowest perf on record is %.2fs', func_name, slowest) + limit = slowest + stdev * 4 + logger.debug('For %s, max acceptable runtime is %.2fs', func_name, limit) + logger.debug( + 'For %s, actual observed runtime was %.2fs', func_name, run_time + ) + if run_time > limit: + msg = f'''{func_id} performance has regressed unacceptably. +{slowest:f}s is the slowest runtime on record in {len(hist)} perf samples. +It just ran in {run_time:f}s which is 4+ stdevs slower than the slowest. +Here is the current, full db perf timing distribution: + +''' + for x in hist: + msg += f'{x:f}\n' + logger.error(msg) + slf = args[0] # Peek at the wrapped function's self ref. + slf.fail(msg) # ...to fail the testcase. 
+ else: + hist.append(run_time) + + # Don't spam the database with samples; just pick a random + # sample from what we have and store that back. + n = min(config.config['unittests_num_perf_samples'], len(hist)) + hist = random.sample(hist, n) + hist.sort() + perfdb[func_id] = hist + helper.save_performance_data(func_id, perfdb) + return value + + return wrapper_perf_monitor + + +def check_all_methods_for_perf_regressions(prefix='test_'): + """Decorate unittests with this to pay attention to the perf of the + testcode and flag perf regressions. e.g. + + import pyutils.unittest_utils as uu + + @uu.check_all_methods_for_perf_regressions() + class TestMyClass(unittest.TestCase): + + def test_some_part_of_my_class(self): + ... + + """ + + def decorate_the_testcase(cls): + if issubclass(cls, unittest.TestCase): + for name, m in inspect.getmembers(cls, inspect.isfunction): + if name.startswith(prefix): + setattr(cls, name, check_method_for_perf_regressions(m)) + logger.debug('Wrapping %s:%s.', cls.__name__, name) + return cls + + return decorate_the_testcase + + +class RecordStdout(contextlib.AbstractContextManager): + """ + Record what is emitted to stdout. + + >>> with RecordStdout() as record: + ... print("This is a test!") + >>> print({record().readline()}) + {'This is a test!\\n'} + >>> record().close() + """ + + def __init__(self) -> None: + super().__init__() + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder: Optional[contextlib.redirect_stdout] = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stdout(self.destination) + assert self.recorder is not None + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> Literal[False]: + assert self.recorder is not None + self.recorder.__exit__(*args) + self.destination.seek(0) + return False + + +class RecordStderr(contextlib.AbstractContextManager): + """ + Record what is emitted to stderr. + + >>> import sys + >>> with RecordStderr() as record: + ... print("This is a test!", file=sys.stderr) + >>> print({record().readline()}) + {'This is a test!\\n'} + >>> record().close() + """ + + def __init__(self) -> None: + super().__init__() + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.recorder: Optional[contextlib.redirect_stdout[Any]] = None + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + self.recorder = contextlib.redirect_stderr(self.destination) # type: ignore + assert self.recorder is not None + self.recorder.__enter__() + return lambda: self.destination + + def __exit__(self, *args) -> Literal[False]: + assert self.recorder is not None + self.recorder.__exit__(*args) + self.destination.seek(0) + return False + + +class RecordMultipleStreams(contextlib.AbstractContextManager): + """ + Record the output to more than one stream. 
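+
+    A non-doctest sketch (the streams and text are illustrative):
+
+        with RecordMultipleStreams(sys.stdout, sys.stderr) as record:
+            print('to stdout')
+            print('to stderr', file=sys.stderr)
+        captured = record().read()  # both messages, in write order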
+ """ + + def __init__(self, *files) -> None: + super().__init__() + self.files = [*files] + self.destination = tempfile.SpooledTemporaryFile(mode='r+') + self.saved_writes: List[Callable[..., Any]] = [] + + def __enter__(self) -> Callable[[], tempfile.SpooledTemporaryFile]: + for f in self.files: + self.saved_writes.append(f.write) + f.write = self.destination.write + return lambda: self.destination + + def __exit__(self, *args) -> Literal[False]: + for f in self.files: + f.write = self.saved_writes.pop() + self.destination.seek(0) + return False + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/src/pyutils/unscrambler.py b/src/pyutils/unscrambler.py new file mode 100644 index 0000000..cada2a0 --- /dev/null +++ b/src/pyutils/unscrambler.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""A fast word unscrambler library.""" + +import logging +from typing import Dict, Mapping, Optional + +from pyutils import config, decorator_utils, list_utils +from pyutils.files import file_utils + +cfg = config.add_commandline_args( + f'Unscrambler base library ({__file__})', 'A fast word unscrambler.' +) +cfg.add_argument( + "--unscrambler_default_indexfile", + help="Path to a file of signature -> word index.", + metavar="FILENAME", + default="/usr/share/dict/sparse_index", +) + +logger = logging.getLogger(__name__) + +letters_bits = 32 +letters_mask = 2**letters_bits - 1 + +fprint_bits = 52 +fprint_mask = (2**fprint_bits - 1) << letters_bits + +fprint_feature_bit = { + 'e': 0, + 'i': 2, + 'a': 4, + 'o': 6, + 'r': 8, + 'n': 10, + 't': 12, + 's': 14, + 'l': 16, + 'c': 18, + 'u': 20, + 'p': 22, + 'm': 24, + 'd': 26, + 'h': 28, + 'y': 30, + 'g': 32, + 'b': 34, + 'f': 36, + 'v': 38, + 'k': 40, + 'w': 42, + 'z': 44, + 'x': 46, + 'q': 48, + 'j': 50, +} + +letter_sigs = { + 'a': 1789368711, + 'b': 3146859322, + 'c': 43676229, + 'd': 3522623596, + 'e': 3544234957, + 'f': 3448207591, + 'g': 1282648386, + 'h': 3672791226, + 'i': 1582316135, + 'j': 4001984784, + 'k': 831769172, + 'l': 1160692746, + 'm': 2430986565, + 'n': 1873586768, + 'o': 694443915, + 'p': 1602297017, + 'q': 533722196, + 'r': 3754550193, + 's': 1859447115, + 't': 1121373020, + 'u': 2414108708, + 'v': 2693866766, + 'w': 748799881, + 'x': 2627529228, + 'y': 2376066489, + 'z': 802338724, +} + + +class Unscrambler(object): + """A class that unscrambles words quickly by computing a signature + (sig) for the word based on its position independent letter + population and then using a pregenerated index to look up known + words the same set of letters. + + Note that each instance of Unscrambler caches its index to speed + up lookups number 2..N; careless reinstantiation will by slower. + + Sigs are designed to cluster similar words near each other so both + lookup methods support a "fuzzy match" argument that can be set to + request similar words that do not match exactly in addition to any + exact matches. + """ + + def __init__(self, indexfile: Optional[str] = None): + """ + Constructs an unscrambler. + + Args: + indexfile: overrides the default indexfile location if provided + """ + + # Cached index per instance. 
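+        # The index file contains one "hexsig+word" line per signature
+        # (see repopulate() below).  Split it into two parallel,
+        # sig-sorted lists so lookup_by_sig() can binary search self.sigs.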
+ self.sigs = [] + self.words = [] + + filename = Unscrambler.get_indexfile(indexfile) + with open(filename, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + (fsig, word) = line.split('+') + isig = int(fsig, 16) + self.sigs.append(isig) + self.words.append(word) + + @staticmethod + def get_indexfile(indexfile: Optional[str]) -> str: + """Returns the current indexfile location.""" + if indexfile is None: + if 'unscrambler_default_indexfile' in config.config: + indexfile = config.config['unscrambler_default_indexfile'] + else: + indexfile = "/usr/share/dict/sparse_index" + else: + assert file_utils.file_is_readable(indexfile), f"Can't read {indexfile}" + return indexfile + + # 52 bits + @staticmethod + def _compute_word_fingerprint(population: Mapping[str, int]) -> int: + fp = 0 + for pair in sorted(population.items(), key=lambda x: x[1], reverse=True): + letter = pair[0] + if letter in fprint_feature_bit: + count = min(pair[1], 3) + shift = fprint_feature_bit[letter] + s = count << shift + fp |= s + return fp << letters_bits + + # 32 bits + @staticmethod + def _compute_word_letter_sig( + lsigs: Mapping[str, int], + word: str, + population: Mapping[str, int], + ) -> int: + sig = 0 + for pair in sorted(population.items(), key=lambda x: x[1], reverse=True): + letter = pair[0] + if letter not in lsigs: + continue + s = lsigs[letter] + count = pair[1] + if count > 1: + s <<= count + s |= count + s &= letters_mask + sig ^= s + length = min(len(word), 31) + sig ^= length << 8 + sig &= letters_mask + return sig + + # 52 + 32 bits + @staticmethod + @decorator_utils.memoized + def compute_word_sig(word: str) -> int: + """Given a word, compute its signature for subsequent lookup + operations. Signatures are computed based on the letters in + the word and their frequencies. We try to cluster "similar" + words close to each other in the signature space. + + Args: + word: the word to compute a signature for + + Returns: + The word's signature. + + >>> train = Unscrambler.compute_word_sig('train') + >>> train + 23178969883741 + + >>> retain = Unscrambler.compute_word_sig('retrain') + >>> retain + 24282502197479 + + >>> retain - train + 1103532313738 + + """ + population = list_utils.population_counts(word) + fprint = Unscrambler._compute_word_fingerprint(population) + letter_sig = Unscrambler._compute_word_letter_sig(letter_sigs, word, population) + assert fprint & letter_sig == 0 + sig = fprint | letter_sig + return sig + + @staticmethod + def repopulate( + dictfile: str = '/usr/share/dict/words', + indexfile: str = '/usr/share/dict/sparse_index', + ) -> None: + """ + Repopulates the indexfile. + + .. warning:: + + Before calling this method, change letter_sigs from the + default above unless you want to populate the same exact + files. + """ + words_by_sigs: Dict[int, str] = {} + seen = set() + with open(dictfile, "r") as f: + for word in f: + word = word.replace('\n', '') + word = word.lower() + sig = Unscrambler.compute_word_sig(word) + logger.debug("%s => 0x%x", word, sig) + if word in seen: + continue + seen.add(word) + if sig in words_by_sigs: + words_by_sigs[sig] += f",{word}" + else: + words_by_sigs[sig] = word + with open(indexfile, 'w') as f: + for sig in sorted(words_by_sigs.keys()): + word = words_by_sigs[sig] + print(f'0x{sig:x}+{word}', file=f) + + def lookup(self, word: str, *, window_size: int = 5) -> Dict[str, bool]: + """Looks up a potentially scrambled word optionally including near + "fuzzy" matches. 
+ + Args: + word: the word to lookup + window_size: the number of nearby fuzzy matches to return + + Returns: + A dict of word -> bool containing unscrambled words with (close + to or precisely) the same letters as the input word. The bool + values in this dict indicate whether the key word is an exact + or near match. The count of entries in this dict is controlled + by the window_size param. + + >>> u = Unscrambler() + >>> u.lookup('eanycleocipd', window_size=0) + {'encyclopedia': True} + """ + sig = Unscrambler.compute_word_sig(word) + return self.lookup_by_sig(sig, window_size=window_size) + + def lookup_by_sig(self, sig: int, *, window_size: int = 5) -> Dict[str, bool]: + """Looks up a word that has already been translated into a signature by + a previous call to Unscrambler.compute_word_sig. Optionally returns + near "fuzzy" matches. + + Args: + sig: the signature of the word to lookup (see :meth:`compute_word_sig` + to generate these signatures). + window_size: the number of nearby fuzzy matches to return + + Returns: + A dict of word -> bool containing unscrambled words with (close + to or precisely) the same letters as the input word. The bool + values in this dict indicate whether the key word is an exact + or near match. The count of entries in this dict is controlled + by the window_size param. + + >>> sig = Unscrambler.compute_word_sig('sunepsapetuargiarin') + >>> sig + 18491949645300288339 + + >>> u = Unscrambler() + >>> u.lookup_by_sig(sig) + {'pupigerous': False, 'pupigenous': False, 'unpurposing': False, 'superpurgation': False, 'unsupporting': False, 'superseptuaginarian': True, 'purpurogallin': False, 'scuppaug': False, 'purpurigenous': False, 'purpurogenous': False, 'proppage': False} + """ + ret = {} + (_, location) = list_utils.binary_search(self.sigs, sig) + start = location - window_size + start = max(start, 0) + end = location + 1 + window_size + end = min(end, len(self.words)) + + for x in range(start, end): + word = self.words[x] + fsig = self.sigs[x] + if window_size > 0 or (fsig == sig): + ret[word] = fsig == sig + return ret + + +# +# To repopulate, change letter_sigs and then call Unscrambler.repopulate. +# See notes above. See also ~/bin/unscramble.py --populate_destructively. 
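+# An illustrative invocation (these paths are the defaults above):
+#
+#   Unscrambler.repopulate('/usr/share/dict/words',
+#                          '/usr/share/dict/sparse_index')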
+# + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/src/pyutils/zookeeper.py b/src/pyutils/zookeeper.py new file mode 100644 index 0000000..0f5d55e --- /dev/null +++ b/src/pyutils/zookeeper.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# © Copyright 2022, Scott Gasch + +"""This is a module for making it easier to deal with Zookeeper / Kazoo.""" + + +import datetime +import functools +import logging +import os +import platform +import sys +import threading +from typing import Any, Callable, Optional + +from kazoo.client import KazooClient +from kazoo.exceptions import CancelledError +from kazoo.protocol.states import KazooState +from kazoo.recipe.lease import NonBlockingLease + +from pyutils import argparse_utils, config +from pyutils.files import file_utils + +logger = logging.getLogger(__name__) + +cfg = config.add_commandline_args( + f'Zookeeper ({__file__})', + 'Args related python-zookeeper interactions', +) +cfg.add_argument( + '--zookeeper_nodes', + type=str, + default=None, + help='Comma separated host:port or ip:port address(es)', +) +cfg.add_argument( + '--zookeeper_client_cert_path', + type=argparse_utils.valid_filename, + default=None, + metavar='FILENAME', + help='Path to file containing client certificate.', +) +cfg.add_argument( + '--zookeeper_client_passphrase', + type=str, + default=None, + metavar='PASSPHRASE', + help='Pass phrase for unlocking the client certificate.', +) + + +# On module load, grab what we presume to be our process' program name. +# This is used, by default, to construct internal zookeeper paths (e.g. +# to identify a lease or election). +PROGRAM_NAME: str = os.path.basename(sys.argv[0]) + + +def get_started_zk_client() -> KazooClient: + zk = KazooClient( + hosts=config.config['zookeeper_nodes'], + use_ssl=True, + verify_certs=False, + keyfile=config.config['zookeeper_client_cert_path'], + keyfile_password=config.config['zookeeper_client_passphrase'], + certfile=config.config['zookeeper_client_cert_path'], + ) + zk.start() + logger.debug('We have an active zookeeper connection.') + return zk + + +class RenewableReleasableLease(NonBlockingLease): + """This is a hacky subclass of kazoo.recipe.lease.NonBlockingLease + that adds some behaviors: + + + Ability to renew the lease if it's already held without + going through the effort of reobtaining the same lease + name. + + + Ability to release the lease if it's held and not yet + expired. + + It also is more picky than the base class in terms of when it + evaluates to "True" (indicating that the lease is held); it will + begin to evaluate to "False" as soon as the lease has expired even + if you used to hold it. This means client code should be aware + that the lease can disappear (expire) while held and it also means + that the performance of evaulating the lease (i.e. if lease:) + requires a round trip to zookeeper every time. + + Note that it is not valid to release the lease more than once + (since you no longer have it the second time). The code ignores + the 2nd..nth attempt. It's also not possible to reobtain an + expired or released lease by calling renew. Go create a new lease + object at that point. Finally, note that when you renew the lease + it will evaluate to False briefly as it is reobtained. 
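+
+    A rough (non-doctest) usage sketch; the path, duration, and
+    identifier are illustrative:
+
+        zk = get_started_zk_client()
+        lease = RenewableReleasableLease(
+            zk, '/leases/myjob', datetime.timedelta(minutes=5), 'contender1'
+        )
+        if lease:
+            # ...do work while holding the lease...
+            lease.try_renew(datetime.timedelta(minutes=5))
+            lease.release()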
+ """ + + def __init__( + self, + client: KazooClient, + path: str, + duration: datetime.timedelta, + identifier: str = None, + utcnow=datetime.datetime.utcnow, + ): + super().__init__(client, path, duration, identifier, utcnow) + self.client = client + self.path = path + self.identifier = identifier + self.utcnow = utcnow + + def release(self) -> bool: + """Release the lease, if it's presently being held. + + Returns: + True if the lease was successfully released, + False otherwise. + """ + self.client.ensure_path(self.path) + holder_path = self.path + "/lease_holder" + lock = self.client.Lock(self.path, self.identifier) + try: + with lock: + if not self._is_lease_held_pre_locked(): + logger.debug("Can't release lease; I don't have it!") + return False + + now = self.utcnow() + if self.client.exists(holder_path): + self.client.delete(holder_path) + end_lease = now.strftime(self._date_format) + + # Release by moving end to now. + data = { + 'version': self._version, + 'holder': self.identifier, + 'end': end_lease, + } + self.client.create(holder_path, self._encode(data)) + self.obtained = False + logger.debug('Successfully released lease') + return True + + except CancelledError as e: + logger.debug('Exception %s in zookeeper?', e) + return False + + def try_renew(self, duration: datetime.timedelta) -> bool: + """Attempt to renew a lease that is currently held. Note that + this will cause self to evaluate to False briefly as the lease + is renewed. + + Args: + duration: the amount of additional time to add to the + current lease expiration. + + Returns: + True if the lease was successfully renewed, + False otherwise. + """ + + if not self.obtained: + return False + self.obtained = False + self._attempt_obtaining( + self.client, self.path, duration, self.identifier, self.utcnow + ) + return self.obtained + + def _is_lease_held_pre_locked(self) -> bool: + self.client.ensure_path(self.path) + holder_path = self.path + "/lease_holder" + now = self.utcnow() + if self.client.exists(holder_path): + raw, _ = self.client.get(holder_path) + data = self._decode(raw) + if data["version"] != self._version: + return False + current_end = datetime.datetime.strptime(data['end'], self._date_format) + if data['holder'] == self.identifier and now <= current_end: + logger.debug('Yes, we hold the lease and it isn\'t expired.') + return True + return False + + def __bool__(self): + """Note that this implementation differs from that of the base + class in that it probes zookeeper to ensure that the lease is + not yet expired and is therefore more expensive. + """ + + if not self.obtained: + return False + lock = self.client.Lock(self.path, self.identifier) + try: + with lock: + ret = self._is_lease_held_pre_locked() + except CancelledError: + return False + return ret + + +def obtain_lease( + f: Optional[Callable] = None, + *, + lease_id: str = PROGRAM_NAME, + contender_id: str = platform.node(), + duration: datetime.timedelta = datetime.timedelta(minutes=5), + also_pass_lease: bool = False, + also_pass_zk_client: bool = False, +): + """Obtain an exclusive lease identified by the lease_id name + before invoking a function or skip invoking the function if the + lease cannot be obtained. + + Note that we use a hacky "RenewableReleasableLease" and not the + kazoo NonBlockingLease because the former allows us to release the + lease when the user code returns whereas the latter does not. 
+
+    Args:
+        lease_id: string identifying the lease to obtain
+        contender_id: string identifying who's attempting to obtain
+        duration: how long should the lease be held, if obtained?
+        also_pass_lease: pass the lease into the user function
+        also_pass_zk_client: pass our zk client into the user function
+
+    Returns:
+        Whatever the decorated function returns, or None if the lease
+        could not be obtained (in which case the function is not
+        invoked at all).
+
+    >>> @obtain_lease(
+    ...     lease_id='zookeeper_doctest',
+    ...     duration=datetime.timedelta(seconds=5),
+    ... )
+    ... def f(name: str) -> int:
+    ...     print(f'Hello, {name}')
+    ...     return 123
+
+    >>> f('Scott')
+    Hello, Scott
+    123
+
+    """
+    if not lease_id.startswith('/leases/'):
+        lease_id = f'/leases/{lease_id}'
+    lease_id = file_utils.fix_multiple_slashes(lease_id)
+
+    def wrapper(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper2(*args, **kwargs) -> Optional[Any]:
+            zk = get_started_zk_client()
+            logger.debug(
+                'Trying to obtain %s for contender %s now...',
+                lease_id,
+                contender_id,
+            )
+            lease = RenewableReleasableLease(
+                zk,
+                lease_id,
+                duration,
+                contender_id,
+            )
+            if lease:
+                logger.debug(
+                    'Successfully obtained %s for contender %s; invoking user function.',
+                    lease_id,
+                    contender_id,
+                )
+                if also_pass_zk_client:
+                    args = (*args, zk)
+                if also_pass_lease:
+                    args = (*args, lease)
+                ret = func(*args, **kwargs)
+
+                # We don't care if this release operation succeeds;
+                # there are legitimate cases where it will fail such
+                # as when the user code has already voluntarily
+                # released the lease.
+                lease.release()
+            else:
+                logger.debug(
+                    'Failed to obtain %s for contender %s, shutting down.',
+                    lease_id,
+                    contender_id,
+                )
+                ret = None
+            logger.debug('Shutting down zookeeper client.')
+            zk.stop()
+            return ret
+
+        return wrapper2
+
+    if f is None:
+        return wrapper
+    else:
+        return wrapper(f)
+
+
+def run_for_election(
+    f: Optional[Callable] = None,
+    *,
+    election_id: str = PROGRAM_NAME,
+    contender_id: str = platform.node(),
+    also_pass_zk_client: bool = False,
+):
+    """Run as a contender for a leader election.  If/when we become
+    the leader, invoke the user's function.
+
+    The user's function will be executed on a new thread and must
+    accept a "stop processing" event that it must check regularly.
+    This event will be set automatically by the wrapper in the event
+    that we lose connection to zookeeper (and hence are no longer
+    confident that we are still the leader).
+
+    The user's function may return at any time, which will cause
+    the wrapper to also return and effectively cede leadership.
+
+    Because the user's code is run on a separate thread, it may not
+    return anything; whatever it returns will be dropped.
+
+    Args:
+        election_id: global string identifier for the election
+        contender_id: string identifying who is running for leader
+        also_pass_zk_client: pass the zk client into the user code
+
+    >>> @run_for_election(
+    ...     election_id='zookeeper_doctest',
+    ...     also_pass_zk_client=True
+    ... )
+    ... def g(name: str, zk: KazooClient, stop_now: threading.Event):
+    ...     import time
+    ...     count = 0
+    ...     while True:
+    ...         print(f"Hello, {name}, I'm the leader.")
+    ...         if stop_now.is_set():
+    ...             print("Oops, not anymore?!")
+    ...             return
+    ...         time.sleep(0.1)
+    ...         count += 1
+    ...         if count >= 3:
+    ...             print("I'm sick of being leader.")
+    ...             return
+
+    >>> g("Scott")
+    Hello, Scott, I'm the leader.
+    Hello, Scott, I'm the leader.
+    Hello, Scott, I'm the leader.
+    I'm sick of being leader.
+ + """ + if not election_id.startswith('/elections/'): + election_id = f'/elections/{election_id}' + election_id = file_utils.fix_multiple_slashes(election_id) + + class wrapper: + """Helper wrapper class.""" + + def __init__(self, func: Callable) -> None: + functools.update_wrapper(self, func) + self.func = func + self.zk = get_started_zk_client() + self.stop_event = threading.Event() + self.stop_event.clear() + + def zk_listener(self, state: KazooState) -> None: + logger.debug('Listener received state %s.', state) + if state != KazooState.CONNECTED: + logger.debug( + 'Bad connection to zookeeper (state=%s); bailing out.', + state, + ) + self.stop_event.set() + + def runit(self, *args, **kwargs) -> None: + # Possibly augment args if requested; always pass stop_event + if also_pass_zk_client: + args = (*args, self.zk) + args = (*args, self.stop_event) + + thread = threading.Thread( + target=self.func, + args=args, + kwargs=kwargs, + ) + logger.debug( + 'Invoking user code on separate thread: %s', + thread.getName(), + ) + thread.start() + + # Periodically poll the zookeeper state (fail safe for + # listener) and the state of the child thread. + while True: + state = self.zk.client_state + if state != KazooState.CONNECTED: + logger.error( + 'Bad connection to zookeeper (state=%s); bailing out.', + state, + ) + self.stop_event.set() + logger.debug('Waiting for user thread to tear down...') + thread.join() + logger.debug('User thread exited after our notification.') + return + + thread.join(timeout=5.0) + if not thread.is_alive(): + logger.info('User thread exited on its own.') + return + + def __call__(self, *args, **kwargs): + election = self.zk.Election(election_id, contender_id) + self.zk.add_listener(self.zk_listener) + election.run(self.runit, *args, **kwargs) + self.zk.stop() + + if f is None: + return wrapper + else: + return wrapper(f) + + +if __name__ == '__main__': + import doctest + + doctest.testmod() diff --git a/tests/.coveragerc b/tests/.coveragerc new file mode 100644 index 0000000..07eaf71 --- /dev/null +++ b/tests/.coveragerc @@ -0,0 +1,2 @@ +[run] +parallel = true diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..88f79a4 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1,2 @@ +.coverage +test_output/* diff --git a/tests/README b/tests/README new file mode 100644 index 0000000..4553b73 --- /dev/null +++ b/tests/README @@ -0,0 +1,29 @@ + +This directory contains the (non-doctest) testcode for pyutils (i.e. unit tests +and integration tests). It also contains a couple of helpers to run the tests. + +The easiest way to run the tests is, from within this tests/ directory, run: + + ./run_tests_serially.sh -a + +If you only want to run a subset of the tests (e.g. all doctests only) run: + + ./run_tests_serially.sh -d + +As you can tell from the name, this shell script runs the tests in serial. +If you want to go faster (and put more load on your machine), try: + + ./run_tests.py --all + +Or: + + ./run_tests.py -d + +Both of these runners store test output under ./test_output. + +Both of them can optionally use coverage (pip install coverage) to generate a +code coverage report: + + ./run_tests.py --all --coverage + +I use ./run_tests.py --all --coverage as a .git/hooks/pre-commit-hook. 
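+
+For example, a minimal .git/hooks/pre-commit (a sketch; adjust the cd
+path to wherever your clone lives) could be:
+
+    #!/bin/sh
+    cd /path/to/pyutils/tests && exec ./run_tests.py --all --coverage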
diff --git a/tests/collectionz/shared_dict_test.py b/tests/collectionz/shared_dict_test.py new file mode 100755 index 0000000..0a684f4 --- /dev/null +++ b/tests/collectionz/shared_dict_test.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""shared_dict unittest.""" + +import random +import unittest + +from pyutils import unittest_utils +from pyutils.collectionz.shared_dict import SharedDict +from pyutils.parallelize import parallelize as p +from pyutils.parallelize import smart_future + + +class SharedDictTest(unittest.TestCase): + @p.parallelize(method=p.Method.PROCESS) + def doit(self, n: int, dict_name: str, parent_lock_id: int): + assert id(SharedDict.LOCK) == parent_lock_id + d = SharedDict(dict_name, None) + try: + msg = f'Hello from shard {n}' + for x in range(0, 1000): + d[n] = msg + self.assertTrue(n in d) + self.assertEqual(msg, d[n]) + y = d.get(random.randrange(0, 99), None) + return n + finally: + d.close() + + def test_basic_operations(self): + dict_name = 'test_shared_dict' + d = SharedDict(dict_name, 4096) + try: + self.assertEqual(dict_name, d.get_name()) + results = [] + for n in range(100): + f = self.doit(n, d.get_name(), id(SharedDict.LOCK)) + results.append(f) + smart_future.wait_all(results) + for f in results: + self.assertTrue(f.wrapped_future.done()) + for k in d: + self.assertEqual(d[k], f'Hello from shard {k}') + assert len(d) == 100 + finally: + d.close() + d.cleanup() + + @p.parallelize(method=p.Method.PROCESS) + def add_one(self, name: str, expected_lock_id: int): + d = SharedDict(name) + self.assertEqual(id(SharedDict.LOCK), expected_lock_id) + try: + for x in range(1000): + with SharedDict.LOCK: + d["sum"] += 1 + finally: + d.close() + + def test_locking_works(self): + dict_name = 'test_shared_dict_lock' + d = SharedDict(dict_name, 4096) + try: + d["sum"] = 0 + results = [] + for n in range(10): + f = self.add_one(d.get_name(), id(SharedDict.LOCK)) + results.append(f) + smart_future.wait_all(results) + self.assertEqual(10000, d["sum"]) + finally: + d.close() + d.cleanup() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/color_vars.sh b/tests/color_vars.sh new file mode 100644 index 0000000..55e3436 --- /dev/null +++ b/tests/color_vars.sh @@ -0,0 +1,104 @@ +# https://en.wikipedia.org/wiki/ANSI_escape_code +# https://coolors.co/aa4465-ec7632-dcd5a0-f9cb40-875f00-698c6e-0b3820-5fafff-2f5d9d-7f5e97 + +# Sets shell environment variables to code for ANSI colors. 
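+# +# Example usage (an illustrative sketch; assumes this file has been sourced): +# +# echo -e "${RED}something bad happened${NC}" +# echo -e "${BOLD}${ON_GREEN}all tests passed${NC}"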
+ +# Attributes +export NORMAL='\e[0m' +export BOLD='\e[1m' +export ITALICS='\e[3m' +export UNDERLINE='\e[4m' +export STRIKETHROUGH='\e[9m' + +# Foreground colors +export BLACK='\e[30m' +export BROWN='\e[38;2;135;95;0m' #875F00 +export BRIGHT_RED='\e[38;2;175;0;0m' +export LIGHT_RED='\e[38;2;170;68;101m' #AA4465 +export RED='\e[31m' +export DARK_RED='\e[38;5;52m' +export PINK='\e[38;2;231;90;124m' #E75A7C +export ORANGE='\e[38;2;236;118;50m' #EC7632 +export YELLOW='\e[38;2;249;203;64m' #F9CB40 +export GOLD='\e[38;5;94m' +export GREEN='\e[38;2;105;140;110m' #698C6E +export DARK_GREEN='\e[38;2;10;51;29m' #0A331D +export TEAL='\e[38;5;45m' +export CYAN='\e[38;2;95;175;255m' #5FAFFF +export BLUE='\e[38;2;47;93;157m' #2F5D9D +export DARK_BLUE='\e[38;5;18m' +export MAGENTA='\e[38;2;170;68;101m' #AA4465 +export DARK_MAGENTA='\e[38;5;63m' +export PURPLE='\e[38;2;127;94;151m' #7F5E97 +export ON_PURPLE='\e[48;2;127;94;151m' #7F5E97 +export DARK_PURPLE='\e[38;5;56m' +export WHITE='\e[37m' +export LIGHT_GRAY='\e[38;2;25;25;25m' +export LGRAY='\e[38;2;25;25;25m' +export GRAY='\e[30m' + +# 8-bit: As 256-color lookup tables became common on graphic cards, +# escape sequences were added to select from a pre-defined set of 256 +# colors: +# +# ESC[ 38;5;⟨n⟩ m Select foreground color +# ESC[ 48;5;⟨n⟩ m Select background color +# 0- 7: standard colors (as in ESC [ 30–37 m) +# 8- 15: high intensity colors (as in ESC [ 90–97 m) +# 16-231: 6 × 6 × 6 cube (216 colors): 16 + 36 × r + 6 × g + b (0 ≤ r, g, b ≤ 5) +# 232-255: grayscale from black to white in 24 steps +function make_fg_color() { + if [ $# -ge 3 ]; then + if [ "$1" -gt "5" ] || [ "$2" -gt "5" ] || [ "$3" -gt "5" ]; then + echo -ne "\e[38;2;$1;$2;$3m" + else + N=$(( 16 + 36 * $1 + 6 * $2 + $3 )) + echo -ne "\e[38;5;${N}m" + fi + elif [ $# -eq 2 ]; then + N=$(( 16 + 36 * 0 + 6 * $1 + $2 )) + echo -ne "\e[38;5;${N}m" + elif [ $# -eq 1 ]; then + echo -ne "\e[38;5;$1m" + else + echo -ne ${LIGHT_GRAY} + fi +} + +function make_bg_color() { + if [ $# -ge 3 ]; then + if [ "$1" -gt "5" ] || [ "$2" -gt "5" ] || [ "$3" -gt "5" ]; then + echo -ne "\e[48;2;$1;$2;$3m" + else + N=$(( 16 + 36 * $1 + 6 * $2 + $3 )) + echo -ne "\e[48;5;${N}m" + fi + elif [ $# -eq 2 ]; then + N=$(( 16 + 36 * 0 + 6 * $1 + $2 )) + echo -ne "\e[48;5;${N}m" + elif [ $# -eq 1 ]; then + echo -ne "\e[48;5;$1m" + else + echo -ne ${ON_GRAY} + fi +} + +# Backgrounds +export ON_BLACK='\e[40m' +export ON_LGRAY='\e[48;2;25;25;25m' +export ON_RED='\e[41m' +export ON_ORANGE='\e[48;2;236;118;50m' +export ON_GREEN='\e[48;2;105;140;110m' +export ON_YELLOW='\e[48;2;249;203;64m' +export ON_BLUE='\e[48;2;47;93;157m' +export ON_MAGENTA='\e[48;2;170;68;101m' +export ON_CYAN='\e[48;2;95;175;255m' +export ON_WHITE='\e[47m' +export ON_DARK_PURPLE='\e[48;5;56m' + +# Cursor color +export CURSOR='\e[38;5;208m' # ORANGY +export ON_CURSOR='\e[48;5;208m' + +# Reset sequence +export NC="\e[0m" # color reset diff --git a/tests/compress/letter_compress_test.py b/tests/compress/letter_compress_test.py new file mode 100755 index 0000000..3304dfa --- /dev/null +++ b/tests/compress/letter_compress_test.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""letter_compress unittest.""" + +import math +import random +import unittest + +from pyutils import bootstrap +from pyutils import unittest_utils as uu +from pyutils.compress import letter_compress + + +class TestLetterCompress(unittest.TestCase): + def test_with_random_strings(self): + alphabet = 'abcdefghijklmnopqrstuvwxyz .,"-' + for n 
in range(20): + message = "" + for letter in random.choices(alphabet, k=random.randrange(10, 5000)): + message += letter + mlen = len(message) + compressed = letter_compress.compress(message) + clen = len(compressed) + self.assertEqual(math.ceil(mlen * 5.0 / 8.0), clen) + decompressed = letter_compress.decompress(compressed) + self.assertEqual( + decompressed, message, f'The bad message string was "{message}"' + ) + + +if __name__ == '__main__': + bootstrap.initialize(unittest.main)() diff --git a/tests/datetimez/dateparse_utils_test.py b/tests/datetimez/dateparse_utils_test.py new file mode 100755 index 0000000..93c7b96 --- /dev/null +++ b/tests/datetimez/dateparse_utils_test.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""dateparse_utils unittest.""" + +import datetime +import random +import re +import unittest + +import pytz + +import pyutils.datetimez.dateparse_utils as du +import pyutils.unittest_utils as uu + +parsable_expressions = [ + ('today', datetime.datetime(2021, 7, 2)), + ('tomorrow', datetime.datetime(2021, 7, 3)), + ('yesterday', datetime.datetime(2021, 7, 1)), + ('21:30', datetime.datetime(2021, 7, 2, 21, 30, 0, 0)), + ('12:01am', datetime.datetime(2021, 7, 2, 0, 1, 0, 0)), + ('12:02p', datetime.datetime(2021, 7, 2, 12, 2, 0, 0)), + ('0:03', datetime.datetime(2021, 7, 2, 0, 3, 0, 0)), + ('last wednesday', datetime.datetime(2021, 6, 30)), + ('this wed', datetime.datetime(2021, 7, 7)), + ('next wed', datetime.datetime(2021, 7, 14)), + ('this coming tues', datetime.datetime(2021, 7, 6)), + ('this past monday', datetime.datetime(2021, 6, 28)), + ('4 days ago', datetime.datetime(2021, 6, 28)), + ('4 mondays ago', datetime.datetime(2021, 6, 7)), + ('4 months ago', datetime.datetime(2021, 3, 2)), + ('3 days back', datetime.datetime(2021, 6, 29)), + ('13 weeks from now', datetime.datetime(2021, 10, 1)), + ('1 year from now', datetime.datetime(2022, 7, 2)), + ('4 weeks from now', datetime.datetime(2021, 7, 30)), + ('3 saturdays ago', datetime.datetime(2021, 6, 12)), + ('4 months from today', datetime.datetime(2021, 11, 2)), + ('4 years from yesterday', datetime.datetime(2025, 7, 1)), + ('4 weeks from tomorrow', datetime.datetime(2021, 7, 31)), + ('april 15, 2005', datetime.datetime(2005, 4, 15)), + ('april 14', datetime.datetime(2021, 4, 14)), + ('9:30am on last wednesday', datetime.datetime(2021, 6, 30, 9, 30)), + ('2005/apr/15', datetime.datetime(2005, 4, 15)), + ('2005 apr 15', datetime.datetime(2005, 4, 15)), + ('the 1st wednesday in may', datetime.datetime(2021, 5, 5)), + ('last sun of june', datetime.datetime(2021, 6, 27)), + ('this Easter', datetime.datetime(2021, 4, 4)), + ('last christmas', datetime.datetime(2020, 12, 25)), + ('last Xmas', datetime.datetime(2020, 12, 25)), + ('xmas, 1999', datetime.datetime(1999, 12, 25)), + ('next mlk day', datetime.datetime(2022, 1, 17)), + ('Halloween, 2020', datetime.datetime(2020, 10, 31)), + ('5 work days after independence day', datetime.datetime(2021, 7, 12)), + ('50 working days from last wed', datetime.datetime(2021, 9, 10)), + ('25 working days before columbus day', datetime.datetime(2021, 9, 3)), + ('today +1 week', datetime.datetime(2021, 7, 9)), + ('sunday -3 weeks', datetime.datetime(2021, 6, 13)), + ('4 weeks before xmas, 1999', datetime.datetime(1999, 11, 27)), + ('3 days before new years eve, 2000', datetime.datetime(2000, 12, 28)), + ('july 4th', datetime.datetime(2021, 7, 4)), + ('the ides of march', datetime.datetime(2021, 3, 15)), + ('the nones of april', 
datetime.datetime(2021, 4, 5)), + ('the kalends of may', datetime.datetime(2021, 5, 1)), + ('9/11/2001', datetime.datetime(2001, 9, 11)), + ('4 sundays before veterans\' day', datetime.datetime(2021, 10, 17)), + ('xmas eve', datetime.datetime(2021, 12, 24)), + ('this friday at 5pm', datetime.datetime(2021, 7, 9, 17, 0, 0)), + ('presidents day', datetime.datetime(2021, 2, 15)), + ('memorial day, 1921', datetime.datetime(1921, 5, 30)), + ('today -4 wednesdays', datetime.datetime(2021, 6, 9)), + ('thanksgiving', datetime.datetime(2021, 11, 25)), + ('2 sun in jun', datetime.datetime(2021, 6, 13)), + ('easter -40 days', datetime.datetime(2021, 2, 23)), + ('easter +39 days', datetime.datetime(2021, 5, 13)), + ('2nd Sunday in May, 2022', datetime.datetime(2022, 5, 8)), + ('1st tuesday in nov, 2024', datetime.datetime(2024, 11, 5)), + ( + '2 days before last xmas at 3:14:15.92a', + datetime.datetime(2020, 12, 23, 3, 14, 15, 92), + ), + ( + '3 weeks after xmas, 1995 at midday', + datetime.datetime(1996, 1, 15, 12, 0, 0), + ), + ( + '4 months before easter, 1992 at midnight', + datetime.datetime(1991, 12, 19), + ), + ( + '5 months before halloween, 1995 at noon', + datetime.datetime(1995, 5, 31, 12), + ), + ('4 days before last wednesday', datetime.datetime(2021, 6, 26)), + ('44 months after today', datetime.datetime(2025, 3, 2)), + ('44 years before today', datetime.datetime(1977, 7, 2)), + ('44 weeks ago', datetime.datetime(2020, 8, 28)), + ('15 minutes to 3am', datetime.datetime(2021, 7, 2, 2, 45)), + ('quarter past 4pm', datetime.datetime(2021, 7, 2, 16, 15)), + ('half past 9', datetime.datetime(2021, 7, 2, 9, 30)), + ('4 seconds to midnight', datetime.datetime(2021, 7, 1, 23, 59, 56)), + ( + '4 seconds to midnight, tomorrow', + datetime.datetime(2021, 7, 2, 23, 59, 56), + ), + ('2021/apr/15T21:30:44.55', datetime.datetime(2021, 4, 15, 21, 30, 44, 55)), + ( + '2021/apr/15 at 21:30:44.55', + datetime.datetime(2021, 4, 15, 21, 30, 44, 55), + ), + ( + '2021/4/15 at 21:30:44.55', + datetime.datetime(2021, 4, 15, 21, 30, 44, 55), + ), + ( + '2021/04/15 at 21:30:44.55', + datetime.datetime(2021, 4, 15, 21, 30, 44, 55), + ), + ( + '2021/04/15 at 21:30:44.55Z', + datetime.datetime(2021, 4, 15, 21, 30, 44, 55, tzinfo=pytz.timezone('UTC')), + ), + ( + '2021/04/15 at 21:30:44.55EST', + datetime.datetime(2021, 4, 15, 21, 30, 44, 55, tzinfo=pytz.timezone('EST')), + ), + ( + '13 days after last memorial day at 12 seconds before 4pm', + datetime.datetime(2020, 6, 7, 15, 59, 48), + ), + ( + ' 2 days before yesterday at 9am ', + datetime.datetime(2021, 6, 29, 9), + ), + ('-3 days before today', datetime.datetime(2021, 7, 5)), + ( + '3 days before yesterday at midnight EST', + datetime.datetime(2021, 6, 28, tzinfo=pytz.timezone('EST')), + ), +] + + +class TestDateparseUtils(unittest.TestCase): + @uu.check_method_for_perf_regressions + def test_dateparsing(self): + dp = du.DateParser(override_now_for_test_purposes=datetime.datetime(2021, 7, 2)) + + for (txt, expected_dt) in parsable_expressions: + try: + actual_dt = dp.parse(txt) + self.assertIsNotNone(actual_dt) + self.assertEqual( + actual_dt, + expected_dt, + f'"{txt}", got "{actual_dt}" while expecting "{expected_dt}"', + ) + except du.ParseException: + self.fail(f'Expected "{txt}" to parse successfully.') + + def test_whitespace_handling(self): + dp = du.DateParser(override_now_for_test_purposes=datetime.datetime(2021, 7, 2)) + + for (txt, expected_dt) in parsable_expressions: + try: + txt = f' {txt} ' + i = random.randint(2, 5) + replacement = ' ' * i + txt 
= re.sub(r'\s', replacement, txt) + actual_dt = dp.parse(txt) + self.assertIsNotNone(actual_dt) + self.assertEqual( + actual_dt, + expected_dt, + f'"{txt}", got "{actual_dt}" while expecting "{expected_dt}"', + ) + except du.ParseException: + self.fail(f'Expected "{txt}" to parse successfully.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/decorator_utils_test.py b/tests/decorator_utils_test.py new file mode 100755 index 0000000..48c06d7 --- /dev/null +++ b/tests/decorator_utils_test.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""decorator_utils unittest such as it is.""" + +import unittest + +from pyutils import decorator_utils as du +from pyutils import unittest_utils as uu + + +class TestDecorators(unittest.TestCase): + def test_singleton(self): + @du.singleton + class FooBar: + pass + + x = FooBar() + y = FooBar() + self.assertTrue(x is y) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dict_utils_test.py b/tests/dict_utils_test.py new file mode 100755 index 0000000..ab3e04b --- /dev/null +++ b/tests/dict_utils_test.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""dict_utils unittest.""" + +import unittest + +from pyutils import dict_utils as du +from pyutils import unittest_utils # Needed for --unittests_ignore_perf flag + + +class TestDictUtils(unittest.TestCase): + def test_init_or_inc(self): + d = {} + du.init_or_inc(d, 'a') + du.init_or_inc(d, 'b') + du.init_or_inc(d, 'a') + du.init_or_inc(d, 'b') + du.init_or_inc(d, 'c') + du.init_or_inc(d, 'c') + du.init_or_inc(d, 'd') + du.init_or_inc(d, 'e') + du.init_or_inc(d, 'a') + du.init_or_inc(d, 'b') + e = {'a': 3, 'b': 3, 'c': 2, 'd': 1, 'e': 1} + self.assertEqual(d, e) + + def test_shard_coalesce(self): + d = {'a': 3, 'b': 3, 'c': 2, 'd': 1, 'e': 1} + shards = du.shard(d, 2) + merged = du.coalesce(shards) + self.assertEqual(d, merged) + + def test_item_with_max_value(self): + d = {'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 1} + self.assertEqual('a', du.item_with_max_value(d)[0]) + self.assertEqual(4, du.item_with_max_value(d)[1]) + self.assertEqual('a', du.key_with_max_value(d)) + self.assertEqual(4, du.max_value(d)) + + def test_item_with_min_value(self): + d = {'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 0} + self.assertEqual('e', du.item_with_min_value(d)[0]) + self.assertEqual(0, du.item_with_min_value(d)[1]) + self.assertEqual('e', du.key_with_min_value(d)) + self.assertEqual(0, du.min_value(d)) + + def test_min_max_key(self): + d = {'a': 4, 'b': 3, 'c': 2, 'd': 1, 'e': 0} + self.assertEqual('a', du.min_key(d)) + self.assertEqual('e', du.max_key(d)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/exec_utils_test.py b/tests/exec_utils_test.py new file mode 100755 index 0000000..0af0cb0 --- /dev/null +++ b/tests/exec_utils_test.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""exec_utils unittest.""" + +import subprocess +import unittest + +from pyutils import exec_utils, unittest_utils + + +class TestExecUtils(unittest.TestCase): + def test_cmd_showing_output(self): + with unittest_utils.RecordStdout() as record: + ret = exec_utils.cmd_showing_output('/usr/bin/printf hello') + self.assertEqual('hello', record().readline()) + self.assertEqual(0, ret) + record().close() + + def test_cmd_showing_output_with_timeout(self): + try: + exec_utils.cmd_showing_output('sleep 10', timeout_seconds=0.1) + except subprocess.TimeoutExpired: + pass + else: + 
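# Control reaches this branch only if cmd_showing_output returned without raising TimeoutExpired; that is a test failure. +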
self.fail('Expected a TimeoutExpired exception, didn\'t see one.') + + def test_cmd_showing_output_fails(self): + with unittest_utils.RecordStdout() as record: + ret = exec_utils.cmd_showing_output('/usr/bin/printf hello && false') + self.assertEqual('hello', record().readline()) + self.assertEqual(1, ret) + record().close() + + def test_cmd_in_background(self): + p = exec_utils.cmd_in_background('sleep 100') + self.assertIsNone(p.poll()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/logging_utils_test.py b/tests/logging_utils_test.py new file mode 100755 index 0000000..79f0aad --- /dev/null +++ b/tests/logging_utils_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""logging_utils unittest.""" + +import os +import sys +import tempfile +import unittest + +from pyutils import logging_utils as lutils +from pyutils import string_utils as sutils +from pyutils import unittest_utils as uu + + +class TestLoggingUtils(unittest.TestCase): + def test_output_context(self): + unique_suffix = sutils.generate_uuid(True) + filename = f'/tmp/logging_utils_test.{unique_suffix}' + secret_message = f'This is a test, {unique_suffix}.' + + with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile1: + with uu.RecordStdout() as record: + with lutils.OutputMultiplexerContext( + lutils.OutputMultiplexer.Destination.FILENAMES + | lutils.OutputMultiplexer.Destination.FILEHANDLES + | lutils.OutputMultiplexer.Destination.LOG_INFO, + filenames=[filename, '/dev/null'], + handles=[tmpfile1, sys.stdout], + ) as mplex: + mplex.print(secret_message, end='') + + # Make sure it was written to the filename. + with open(filename, 'r') as rf: + self.assertEqual(rf.readline(), secret_message) + os.unlink(filename) + + # Make sure it was written to stdout. + tmp = record().readline() + self.assertEqual(tmp, secret_message) + + # Make sure it was written to the filehandle.
+ tmpfile1.seek(0) + tmp = tmpfile1.readline() + self.assertEqual(tmp, secret_message) + + def test_record_streams(self): + with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record: + print("This is a test!") + print("This is one too.", file=sys.stderr) + self.assertEqual( + record().readlines(), ["This is a test!\n", "This is one too.\n"] + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parallelize/parallelize_itest.py b/tests/parallelize/parallelize_itest.py new file mode 100755 index 0000000..80c19ca --- /dev/null +++ b/tests/parallelize/parallelize_itest.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""parallelize unittest.""" + +import logging +import sys + +from pyutils import bootstrap, decorator_utils +from pyutils.parallelize import executors +from pyutils.parallelize import parallelize as p +from pyutils.parallelize import smart_future + +logger = logging.getLogger(__name__) + + +@p.parallelize(method=p.Method.THREAD) +def compute_factorial_thread(n): + total = 1 + for x in range(2, n): + total *= x + return total + + +@p.parallelize(method=p.Method.PROCESS) +def compute_factorial_process(n): + total = 1 + for x in range(2, n): + total *= x + return total + + +@p.parallelize(method=p.Method.REMOTE) +def compute_factorial_remote(n): + total = 1 + for x in range(2, n): + total *= x + return total + + +@decorator_utils.timed +def test_thread_parallelization() -> None: + results = [] + for _ in range(50): + f = compute_factorial_thread(_) + results.append(f) + smart_future.wait_all(results) + for future in results: + print(f'Thread: {future}') + texecutor = executors.DefaultExecutors().thread_pool() + texecutor.shutdown() + + +@decorator_utils.timed +def test_process_parallelization() -> None: + results = [] + for _ in range(50): + results.append(compute_factorial_process(_)) + for future in smart_future.wait_any(results): + print(f'Process: {future}') + pexecutor = executors.DefaultExecutors().process_pool() + pexecutor.shutdown() + + +@decorator_utils.timed +def test_remote_parallelization() -> None: + results = [] + for _ in range(10): + results.append(compute_factorial_remote(_)) + for result in smart_future.wait_any(results): + print(result) + rexecutor = executors.DefaultExecutors().remote_pool() + rexecutor.shutdown() + + +@bootstrap.initialize +def main() -> None: + test_thread_parallelization() + test_process_parallelization() + test_remote_parallelization() + sys.exit(0) + + +if __name__ == '__main__': + try: + main() + except Exception as e: + logger.exception(e) + sys.exit(1) diff --git a/tests/parallelize/thread_utils_test.py b/tests/parallelize/thread_utils_test.py new file mode 100755 index 0000000..9eea538 --- /dev/null +++ b/tests/parallelize/thread_utils_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""thread_utils unittest.""" + +import threading +import time +import unittest + +from pyutils import unittest_utils +from pyutils.parallelize import thread_utils + + +class TestThreadUtils(unittest.TestCase): + invocation_count = 0 + + @thread_utils.background_thread + def background_thread(self, a: int, b: str, stop_event: threading.Event) -> None: + while not stop_event.is_set(): + self.assertEqual(123, a) + self.assertEqual('abc', b) + time.sleep(0.1) + + def test_background_thread(self): + (thread, event) = self.background_thread(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + event.set() + thread.join() + 
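# join() has returned, so the background thread must have exited by now. +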
self.assertFalse(thread.is_alive()) + + @thread_utils.periodically_invoke(period_sec=0.3, stop_after=3) + def periodic_invocation_target(self, a: int, b: str): + self.assertEqual(123, a) + self.assertEqual('abc', b) + TestThreadUtils.invocation_count += 1 + + def test_periodically_invoke_with_limit(self): + TestThreadUtils.invocation_count = 0 + (thread, event) = self.periodic_invocation_target(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + self.assertEqual(3, TestThreadUtils.invocation_count) + self.assertFalse(thread.is_alive()) + + @thread_utils.periodically_invoke(period_sec=0.1, stop_after=None) + def forever_periodic_invocation_target(self, a: int, b: str): + self.assertEqual(123, a) + self.assertEqual('abc', b) + TestThreadUtils.invocation_count += 1 + + def test_periodically_invoke_runs_forever(self): + TestThreadUtils.invocation_count = 0 + (thread, event) = self.forever_periodic_invocation_target(123, 'abc') + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + self.assertTrue(thread.is_alive()) + time.sleep(1.0) + event.set() + thread.join() + self.assertFalse(thread.is_alive()) + self.assertTrue(TestThreadUtils.invocation_count >= 19) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run_tests.py b/tests/run_tests.py new file mode 100755 index 0000000..025f2d2 --- /dev/null +++ b/tests/run_tests.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 + +""" +A smart, fast test runner. Used in a git pre-commit hook. +""" + +import logging +import os +import re +import subprocess +import threading +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +from overrides import overrides + +from pyutils import ansi, bootstrap, config, exec_utils, text_utils +from pyutils.files import file_utils +from pyutils.parallelize import parallelize as par +from pyutils.parallelize import smart_future, thread_utils + +logger = logging.getLogger(__name__) +args = config.add_commandline_args(f'({__file__})', f'Args related to {__file__}') +args.add_argument('--unittests', '-u', action='store_true', help='Run unittests.') +args.add_argument('--doctests', '-d', action='store_true', help='Run doctests.') +args.add_argument( + '--integration', '-i', action='store_true', help='Run integration tests.' +) +args.add_argument( + '--all', + '-a', + action='store_true', + help='Run unittests, doctests and integration tests. Equivalent to -u -d -i', +) +args.add_argument( + '--coverage', + '-c', + action='store_true', + help='Run tests and capture code coverage data', +) + +HOME = os.environ['HOME'] + +# These tests will be run twice in --coverage mode: once to collect +# code coverage data and then again with coverage disabled. This is +# because they pay attention to code performance, which is adversely +# affected by coverage instrumentation. +PERF_SENSITIVE_TESTS = set(['string_utils_test.py']) +TESTS_TO_SKIP = set(['zookeeper_test.py', 'run_tests.py']) + +ROOT = ".."
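+ +# Example invocations (from within the tests/ directory): +# +# ./run_tests.py --all --coverage +# ./run_tests.py --unittests --doctests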
+ + +@dataclass +class TestingParameters: + halt_on_error: bool + """Should we stop as soon as one error has occurred?""" + + halt_event: threading.Event + """An event that, when set, indicates to stop ASAP.""" + + +@dataclass +class TestToRun: + name: str + """The name of the test""" + + kind: str + """The kind of the test""" + + cmdline: str + """The command line to execute""" + + +@dataclass +class TestResults: + name: str + """The name of this test / set of tests.""" + + tests_executed: List[str] + """Tests that were executed.""" + + tests_succeeded: List[str] + """Tests that succeeded.""" + + tests_failed: List[str] + """Tests that failed.""" + + tests_timed_out: List[str] + """Tests that timed out.""" + + def __add__(self, other): + self.tests_executed.extend(other.tests_executed) + self.tests_succeeded.extend(other.tests_succeeded) + self.tests_failed.extend(other.tests_failed) + self.tests_timed_out.extend(other.tests_timed_out) + return self + + __radd__ = __add__ + + def __repr__(self) -> str: + out = f'{self.name}: ' + out += f'{ansi.fg("green")}' + out += f'{len(self.tests_succeeded)}/{len(self.tests_executed)} passed' + out += f'{ansi.reset()}.\n' + + if len(self.tests_failed) > 0: + out += f' ..{ansi.fg("red")}' + out += f'{len(self.tests_failed)} tests failed' + out += f'{ansi.reset()}:\n' + for test in self.tests_failed: + out += f' {test}\n' + out += '\n' + + if len(self.tests_timed_out) > 0: + out += f' ..{ansi.fg("yellow")}' + out += f'{len(self.tests_timed_out)} tests timed out' + out += f'{ansi.reset()}:\n' + for test in self.tests_timed_out: + out += f' {test}\n' + out += '\n' + return out + + +class TestRunner(ABC, thread_utils.ThreadWithReturnValue): + """A base class for something that runs a test.""" + + def __init__(self, params: TestingParameters): + """Create a TestRunner. + + Args: + params: Test running parameters.
+ + """ + super().__init__(self, target=self.begin, args=[params]) + self.params = params + self.test_results = TestResults( + name=self.get_name(), + tests_executed=[], + tests_succeeded=[], + tests_failed=[], + tests_timed_out=[], + ) + self.tests_started = 0 + self.lock = threading.Lock() + + @abstractmethod + def get_name(self) -> str: + """The name of this test collection.""" + pass + + def get_status(self) -> Tuple[int, TestResults]: + """Ask the TestRunner for its status.""" + with self.lock: + return (self.tests_started, self.test_results) + + @abstractmethod + def begin(self, params: TestingParameters) -> TestResults: + """Start execution.""" + pass + + +class TemplatedTestRunner(TestRunner, ABC): + """A TestRunner that has a recipe for executing the tests.""" + + @abstractmethod + def identify_tests(self) -> List[TestToRun]: + """Return a list of tuples (test, cmdline) that should be executed.""" + pass + + @abstractmethod + def run_test(self, test: TestToRun) -> TestResults: + """Run a single test and return its TestResults.""" + pass + + def check_for_abort(self): + """Periodically caled to check to see if we need to stop.""" + + if self.params.halt_event.is_set(): + logger.debug('Thread %s saw halt event; exiting.', self.get_name()) + raise Exception("Kill myself!") + if self.params.halt_on_error: + if len(self.test_results.tests_failed) > 0: + logger.error( + 'Thread %s saw abnormal results; exiting.', self.get_name() + ) + raise Exception("Kill myself!") + + def persist_output(self, test: TestToRun, message: str, output: str) -> None: + """Called to save the output of a test run.""" + + dest = f'{test.name}-output.txt' + with open(f'./test_output/{dest}', 'w') as wf: + print(message, file=wf) + print('-' * len(message), file=wf) + wf.write(output) + + def execute_commandline( + self, + test: TestToRun, + *, + timeout: float = 120.0, + ) -> TestResults: + """Execute a particular commandline to run a test.""" + + try: + output = exec_utils.cmd( + test.cmdline, + timeout_seconds=timeout, + ) + self.persist_output( + test, f'{test.name} ({test.cmdline}) succeeded.', output + ) + logger.debug( + '%s: %s (%s) succeeded', self.get_name(), test.name, test.cmdline + ) + return TestResults(test.name, [test.name], [test.name], [], []) + except subprocess.TimeoutExpired as e: + msg = f'{self.get_name()}: {test.name} ({test.cmdline}) timed out after {e.timeout:.1f} seconds.' + logger.error(msg) + logger.debug( + '%s: %s output when it timed out: %s', + self.get_name(), + test.name, + e.output, + ) + self.persist_output(test, msg, e.output.decode('utf-8')) + return TestResults( + test.name, + [test.name], + [], + [], + [test.name], + ) + except subprocess.CalledProcessError as e: + msg = f'{self.get_name()}: {test.name} ({test.cmdline}) failed; exit code {e.returncode}' + logger.error(msg) + logger.debug( + '%s: %s output when it failed: %s', self.get_name(), test.name, e.output + ) + self.persist_output(test, msg, e.output.decode('utf-8')) + return TestResults( + test.name, + [test.name], + [], + [test.name], + [], + ) + + @overrides + def begin(self, params: TestingParameters) -> TestResults: + logger.debug('Thread %s started.', self.get_name()) + interesting_tests = self.identify_tests() + logger.debug( + '%s: Identified %d tests to be run.', + self.get_name(), + len(interesting_tests), + ) + + # Note: because of @parallelize on run_tests it actually + # returns a SmartFuture with a TestResult inside of it. + # That's the reason for this Any business. 
+ running: List[Any] = [] + for test_to_run in interesting_tests: + running.append(self.run_test(test_to_run)) + logger.debug( + '%s: Test %s started in the background.', + self.get_name(), + test_to_run.name, + ) + self.tests_started += 1 + + for future in smart_future.wait_any(running): + self.check_for_abort() + result = future._resolve() + logger.debug('Test %s finished.', result.name) + self.test_results += result + + logger.debug('Thread %s finished.', self.get_name()) + return self.test_results + + +class UnittestTestRunner(TemplatedTestRunner): + """Run all known Unittests.""" + + @overrides + def get_name(self) -> str: + return "Unittests" + + @overrides + def identify_tests(self) -> List[TestToRun]: + ret = [] + for test in file_utils.get_matching_files_recursive(ROOT, '*_test.py'): + basename = file_utils.without_path(test) + if basename in TESTS_TO_SKIP: + continue + if config.config['coverage']: + ret.append( + TestToRun( + name=basename, + kind='unittest capturing coverage', + cmdline=f'coverage run --source ../src {test} --unittests_ignore_perf 2>&1', + ) + ) + if basename in PERF_SENSITIVE_TESTS: + ret.append( + TestToRun( + name=basename, + kind='unittest w/o coverage to record perf', + cmdline=f'{test} 2>&1', + ) + ) + else: + ret.append( + TestToRun( + name=basename, + kind='unittest', + cmdline=f'{test} 2>&1', + ) + ) + return ret + + @par.parallelize + def run_test(self, test: TestToRun) -> TestResults: + return self.execute_commandline(test) + + +class DoctestTestRunner(TemplatedTestRunner): + """Run all known Doctests.""" + + @overrides + def get_name(self) -> str: + return "Doctests" + + @overrides + def identify_tests(self) -> List[TestToRun]: + ret = [] + out = exec_utils.cmd(f'grep -lR "^ *import doctest" {ROOT}/*') + for test in out.split('\n'): + if re.match(r'.*\.py$', test): + basename = file_utils.without_path(test) + if basename in TESTS_TO_SKIP: + continue + if config.config['coverage']: + ret.append( + TestToRun( + name=basename, + kind='doctest capturing coverage', + cmdline=f'coverage run --source ../src {test} 2>&1', + ) + ) + if basename in PERF_SENSITIVE_TESTS: + ret.append( + TestToRun( + name=basename, + kind='doctest w/o coverage to record perf', + cmdline=f'python3 {test} 2>&1', + ) + ) + else: + ret.append( + TestToRun( + name=basename, + kind='doctest', + cmdline=f'python3 {test} 2>&1', + ) + ) + return ret + + @par.parallelize + def run_test(self, test: TestToRun) -> TestResults: + return self.execute_commandline(test) + + +class IntegrationTestRunner(TemplatedTestRunner): + """Run all known Integration tests.""" + + @overrides + def get_name(self) -> str: + return "Integration Tests" + + @overrides + def identify_tests(self) -> List[TestToRun]: + ret = [] + for test in file_utils.get_matching_files_recursive(ROOT, '*_itest.py'): + basename = file_utils.without_path(test) + if basename in TESTS_TO_SKIP: + continue + if config.config['coverage']: + ret.append( + TestToRun( + name=basename, + kind='integration test capturing coverage', + cmdline=f'coverage run --source ../src {test} 2>&1', + ) + ) + if basename in PERF_SENSITIVE_TESTS: + ret.append( + TestToRun( + name=basename, + kind='integration test w/o coverage to capture perf', + cmdline=f'{test} 2>&1', + ) + ) + else: + ret.append( + TestToRun( + name=basename, kind='integration test', cmdline=f'{test} 2>&1' + ) + ) + return ret + + @par.parallelize + def run_test(self, test: TestToRun) -> TestResults: + return self.execute_commandline(test) + + +def test_results_report(results: Dict[str,
TestResults]) -> int: + """Give a final report about the tests that were run.""" + total_problems = 0 + for result in results.values(): + print(result, end='') + total_problems += len(result.tests_failed) + total_problems += len(result.tests_timed_out) + + if total_problems > 0: + print('Reminder: look in ./test_output to view test output logs') + return total_problems + + +def code_coverage_report(): + """Give a final code coverage report.""" + text_utils.header('Code Coverage') + exec_utils.cmd('coverage combine .coverage*') + out = exec_utils.cmd( + 'coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover' + ) + print(out) + print( + """To recall this report w/o re-running the tests: + + $ coverage report --omit=config-3.*.py,*_test.py,*_itest.py --sort=-cover + +...from the 'tests' directory. Note that subsequent calls to +run_tests.py with --coverage will clobber previous results. See: + + https://coverage.readthedocs.io/en/6.2/ +""" + ) + + +@bootstrap.initialize +def main() -> Optional[int]: + saw_flag = False + halt_event = threading.Event() + threads: List[TestRunner] = [] + + halt_event.clear() + params = TestingParameters( + halt_on_error=True, + halt_event=halt_event, + ) + + if config.config['coverage']: + logger.debug('Clearing existing coverage data via "coverage erase".') + exec_utils.cmd('coverage erase') + + if config.config['unittests'] or config.config['all']: + saw_flag = True + threads.append(UnittestTestRunner(params)) + if config.config['doctests'] or config.config['all']: + saw_flag = True + threads.append(DoctestTestRunner(params)) + if config.config['integration'] or config.config['all']: + saw_flag = True + threads.append(IntegrationTestRunner(params)) + + if not saw_flag: + config.print_usage() + print('ERROR: one of --unittests, --doctests or --integration is required.') + return 1 + + for thread in threads: + thread.start() + + results: Dict[str, TestResults] = {} + while len(results) != len(threads): + started = 0 + done = 0 + failed = 0 + + for thread in threads: + (s, tr) = thread.get_status() + started += s + failed += len(tr.tests_failed) + len(tr.tests_timed_out) + done += failed + len(tr.tests_succeeded) + if not thread.is_alive(): + tid = thread.name + if tid not in results: + result = thread.join() + if result: + results[tid] = result + if len(result.tests_failed) > 0: + logger.error( + 'Thread %s returned abnormal results; killing the others.', + tid, + ) + halt_event.set() + + if started > 0: + percent_done = done / started * 100.0 + else: + percent_done = 0.0 + + if failed == 0: + color = ansi.fg('green') + else: + color = ansi.fg('red') + + if percent_done < 100.0: + print( + text_utils.bar_graph_string( + done, + started, + text=text_utils.BarGraphText.FRACTION, + width=80, + fgcolor=color, + ), + end='\r', + flush=True, + ) + time.sleep(0.5) + + print(f'{ansi.clear_line()}Final Report:') + if config.config['coverage']: + code_coverage_report() + total_problems = test_results_report(results) + return total_problems + + +if __name__ == '__main__': + main() diff --git a/tests/run_tests_serially.sh b/tests/run_tests_serially.sh new file mode 100755 index 0000000..d9c8590 --- /dev/null +++ b/tests/run_tests_serially.sh @@ -0,0 +1,198 @@ +#!/bin/bash + +# Run tests in serial. Invoke from within tests/ directory. + +ROOT=..
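+ +# Example invocation (see usage() below for all flags): +# +# ./run_tests_serially.sh -u -d --coverage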
+DOCTEST=0 +UNITTEST=0 +INTEGRATION=0 +FAILURES=0 +TESTS_RUN=0 +COVERAGE=0 +PERF_TESTS=("string_utils_test.py") + + +if [ -f color_vars.sh ]; then + source color_vars.sh +fi + + +dup() { + if [ $# -ne 2 ]; then + echo "Usage: dup <string> <count>" + return + fi + local times=$(seq 1 $2) + for x in ${times}; do + echo -n "$1" + done +} + +make_header() { + if [ $# -ne 2 ]; then + echo "Usage: make_header <title> <color>" + return + fi + local title="$1" + local title_len=${#title} + title_len=$((title_len + 4)) + local width=70 + local left=4 + local right=$(($width-($title_len+$left))) + local color="$2" + dup '-' $left + echo -ne "[ ${color}${title}${NC} ]" + dup '-' $right + echo +} + +function usage() { + echo "Usage: $0 [-a]|[-i][-u][-d] [--coverage]" + echo + echo "Runs tests under ${ROOT}. Options control which test types:" + echo + echo " -a | --all . . . . . . . . . . . . Run all types of tests" + echo " -d | --doctests . . . . . . . . . Run doctests" + echo " -u | --unittests . . . . . . . . . Run unittests" + echo " -i | --integration . . . . . . . . Run integration tests" + echo " --coverage . . . . . . . . . . . . Capture code coverage data" + echo + exit 1 +} + +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + -a|--all) + DOCTEST=1 + UNITTEST=1 + INTEGRATION=1 + ;; + -d|--doctests) + DOCTEST=1 + ;; + -u|--unittests) + UNITTEST=1 + ;; + -i|--integration) + INTEGRATION=1 + ;; + --coverage) + COVERAGE=1 + ;; + *) # unknown option + echo "Argument $key was not recognized." + echo + usage + exit 1 + ;; + esac + shift +done + +if [ $(expr ${DOCTEST} + ${UNITTEST} + ${INTEGRATION}) -eq 0 ]; then + usage + exit 2 +fi + +if [ ${COVERAGE} -eq 1 ]; then + coverage erase +fi + +FAILED_TESTS="" +if [ ${DOCTEST} -eq 1 ]; then + for doctest in $(find ${ROOT} -name "*.py" -exec grep -l "import doctest" {} \;); do + BASE=$(basename ${doctest}) + HDR="${BASE} (doctest)" + make_header "${HDR}" "${CYAN}" + if [ ${COVERAGE} -eq 1 ]; then + coverage run --source ../src ${doctest} >./test_output/${BASE}-output.txt 2>&1 + else + python3 ${doctest} >./test_output/${BASE}-output.txt 2>&1 + fi + TESTS_RUN=$((TESTS_RUN+1)) + FAILED=$( grep -c '\*\*\*Test Failed\*\*\*' ./test_output/${BASE}-output.txt ) + if [ $FAILED == 0 ]; then + echo "OK" + else + echo -e "${FAILED}" + FAILURES=$((FAILURES+1)) + FAILED_TESTS="${FAILED_TESTS},${BASE} (python3 ${doctest})" + fi + done +fi + +if [ ${UNITTEST} -eq 1 ]; then + for test in $(find ${ROOT} -name "*_test.py" -print); do + BASE=$(basename ${test}) + HDR="${BASE} (unittest)" + make_header "${HDR}" "${GREEN}" + if [ ${COVERAGE} -eq 1 ]; then + coverage run --source ../src ${test} --unittests_ignore_perf >./test_output/${BASE}-output.txt 2>&1 + if [[ " ${PERF_TESTS[*]} " =~ " ${BASE} " ]]; then + echo "(re-running w/o coverage to record perf results)." + ${test} + fi + else + ${test} >./test_output/${BASE}-output.txt 2>&1 + fi + if [ $? -eq 0 ]; then + echo "OK" + else + FAILURES=$((FAILURES+1)) + FAILED_TESTS="${FAILED_TESTS},${BASE} (python3 ${test})" + fi + TESTS_RUN=$((TESTS_RUN+1)) + done +fi + +if [ ${INTEGRATION} -eq 1 ]; then + for test in $(find ${ROOT} -name "*_itest.py" -print); do + BASE=$(basename ${test}) + HDR="${BASE} (integration test)" + make_header "${HDR}" "${ORANGE}" + if [ ${COVERAGE} -eq 1 ]; then + coverage run --source ../src ${test} >./test_output/${BASE}-output.txt 2>&1 + else + ${test} >./test_output/${BASE}-output.txt 2>&1 + fi + if [ $?
-eq 0 ]; then + echo "OK" + else + FAILURES=$((FAILURES+1)) + FAILED_TESTS="${FAILED_TESTS},${BASE} (python3 ${test})" + fi + TESTS_RUN=$((TESTS_RUN+1)) + done +fi + +if [ ${COVERAGE} -eq 1 ]; then + make_header "Code Coverage Report" "${GREEN}" + coverage combine .coverage* + coverage report --omit=config-3.9.py,*_test.py,*_itest.py --sort=-cover + echo + echo "To recall this report w/o re-running the tests:" + echo + echo " $ coverage report --omit=config-3.9.py,*_test.py,*_itest.py --sort=-cover" + echo + echo "...from the 'tests' directory. Note that subsequent calls to " + echo "run_tests_serially.sh with --coverage will clobber previous results. See:" + echo + echo " https://coverage.readthedocs.io/en/6.2/" + echo +fi + +if [ ${FAILURES} -ne 0 ]; then + FAILED_TESTS=$(echo ${FAILED_TESTS} | sed 's/^,/__/g') + FAILED_TESTS=$(echo ${FAILED_TESTS} | sed 's/,/\n__/g') + if [ ${FAILURES} -eq 1 ]; then + echo -e "${RED}There was ${FAILURES}/${TESTS_RUN} failure:" + else + echo -e "${RED}There were ${FAILURES}/${TESTS_RUN} failures:" + fi + echo "${FAILED_TESTS}" + echo -e "${NC}" + exit ${FAILURES} +else + echo -e "${BLACK}${ON_GREEN}All (${TESTS_RUN}) test(s) passed.${NC}" + exit 0 +fi diff --git a/tests/security/acl_test.py b/tests/security/acl_test.py new file mode 100755 index 0000000..8055d2d --- /dev/null +++ b/tests/security/acl_test.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""acl unittest.""" + +import re +import unittest + +from pyutils import unittest_utils # Needed for --unittests_ignore_perf flag +from pyutils.security import acl + + +class TestSimpleACL(unittest.TestCase): + def test_set_based_acl(self): + even = acl.SetBasedACL( + allow_set=set([2, 4, 6, 8, 10]), + deny_set=set([1, 3, 5, 7, 9]), + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + self.assertTrue(even(2)) + self.assertFalse(even(3)) + self.assertFalse(even(-4)) + + def test_wildcard_based_acl(self): + a_or_b = acl.StringWildcardBasedACL( + allowed_patterns=['a*', 'b*'], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + self.assertTrue(a_or_b('aardvark')) + self.assertTrue(a_or_b('baboon')) + self.assertFalse(a_or_b('cheetah')) + + def test_re_based_acl(self): + weird = acl.StringREBasedACL( + denied_regexs=[re.compile('^a.*a$'), re.compile('^b.*b$')], + order_to_check_allow_deny=acl.Order.DENY_ALLOW, + default_answer=True, + ) + self.assertTrue(weird('aardvark')) + self.assertFalse(weird('anaconda')) + self.assertFalse(weird('blackneb')) + self.assertTrue(weird('crow')) + + def test_compound_acls_disjunction(self): + a_b_c = acl.StringWildcardBasedACL( + allowed_patterns=['a*', 'b*', 'c*'], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + c_d_e = acl.StringWildcardBasedACL( + allowed_patterns=['c*', 'd*', 'e*'], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + disjunction = acl.AnyCompoundACL( + subacls=[a_b_c, c_d_e], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + self.assertTrue(disjunction('aardvark')) + self.assertTrue(disjunction('caribou')) + self.assertTrue(disjunction('eagle')) + self.assertFalse(disjunction('newt')) + + def test_compound_acls_conjunction(self): + a_b_c = acl.StringWildcardBasedACL( + allowed_patterns=['a*', 'b*', 'c*'], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + c_d_e = acl.StringWildcardBasedACL( + allowed_patterns=['c*', 'd*', 'e*'],
order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + conjunction = acl.AllCompoundACL( + subacls=[a_b_c, c_d_e], + order_to_check_allow_deny=acl.Order.ALLOW_DENY, + default_answer=False, + ) + self.assertFalse(conjunction('aardvark')) + self.assertTrue(conjunction('caribou')) + self.assertTrue(conjunction('condor')) + self.assertFalse(conjunction('eagle')) + self.assertFalse(conjunction('newt')) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/string_utils_test.py b/tests/string_utils_test.py new file mode 100755 index 0000000..5aeb33d --- /dev/null +++ b/tests/string_utils_test.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""string_utils unittest.""" + +import unittest + +from pyutils import bootstrap +from pyutils import string_utils as su +from pyutils import unittest_utils as uu +from pyutils.ansi import bg, fg, reset + + +@uu.check_all_methods_for_perf_regressions() +class TestStringUtils(unittest.TestCase): + def test_is_none_or_empty(self): + self.assertTrue(su.is_none_or_empty(None)) + self.assertTrue(su.is_none_or_empty("")) + self.assertTrue(su.is_none_or_empty("\n")) + self.assertTrue(su.is_none_or_empty(' ')) + self.assertTrue(su.is_none_or_empty(' \n \r \t ')) + self.assertFalse(su.is_none_or_empty("Covfefe")) + self.assertFalse(su.is_none_or_empty("1234")) + + def test_is_string(self): + self.assertTrue(su.is_string("test")) + self.assertTrue(su.is_string("")) + self.assertFalse(su.is_string(bytes(0x1234))) + self.assertFalse(su.is_string(1234)) + + def test_is_empty_string(self): + self.assertTrue(su.is_empty_string('')) + self.assertTrue(su.is_empty_string(' \t\t \n \r ')) + self.assertFalse(su.is_empty_string(' this is a test ')) + self.assertFalse(su.is_empty_string(22)) + + def test_is_full_string(self): + self.assertFalse(su.is_full_string('')) + self.assertFalse(su.is_full_string(' \t\t \n \r ')) + self.assertTrue(su.is_full_string(' this is a test ')) + self.assertFalse(su.is_full_string(22)) + + def test_is_number(self): + self.assertTrue(su.is_number("1234")) + self.assertTrue(su.is_number("-1234")) + self.assertTrue(su.is_number("1234.55")) + self.assertTrue(su.is_number("-1234.55")) + self.assertTrue(su.is_number("+1234")) + self.assertTrue(su.is_number("+1234.55")) + self.assertTrue(su.is_number("-0.8485996602e10")) + self.assertTrue(su.is_number("-0.8485996602E10")) + self.assertFalse(su.is_number("-0.8485996602t10")) + self.assertFalse(su.is_number(" 1234 ")) + self.assertFalse(su.is_number(" 1234")) + self.assertFalse(su.is_number("1234 ")) + self.assertFalse(su.is_number("fifty")) + + def test_is_integer_number(self): + self.assertTrue(su.is_integer_number("1234")) + self.assertTrue(su.is_integer_number("-1234")) + self.assertFalse(su.is_integer_number("1234.55")) + self.assertFalse(su.is_integer_number("-1234.55")) + self.assertTrue(su.is_integer_number("+1234")) + self.assertTrue(su.is_integer_number("0x1234")) + self.assertTrue(su.is_integer_number("0xdeadbeef")) + self.assertFalse(su.is_integer_number("+1234.55")) + self.assertTrue(su.is_octal_integer_number("+0o777")) + self.assertFalse(su.is_integer_number("-0.8485996602e10")) + self.assertFalse(su.is_integer_number("-0.8485996602E10")) + self.assertFalse(su.is_integer_number("-0.8485996602t10")) + self.assertFalse(su.is_integer_number(" 1234 ")) + self.assertFalse(su.is_integer_number(" 1234")) + self.assertFalse(su.is_integer_number("1234 ")) + self.assertFalse(su.is_integer_number("fifty")) + + def 
test_is_hexidecimal_integer_number(self): + self.assertTrue(su.is_hexidecimal_integer_number("0x1234")) + self.assertTrue(su.is_hexidecimal_integer_number("0X1234")) + self.assertTrue(su.is_hexidecimal_integer_number("0x1234D")) + self.assertTrue(su.is_hexidecimal_integer_number("0xF1234")) + self.assertTrue(su.is_hexidecimal_integer_number("0xe1234")) + self.assertTrue(su.is_hexidecimal_integer_number("0x1234a")) + self.assertTrue(su.is_hexidecimal_integer_number("0xdeadbeef")) + self.assertTrue(su.is_hexidecimal_integer_number("-0xdeadbeef")) + self.assertTrue(su.is_hexidecimal_integer_number("+0xdeadbeef")) + self.assertFalse(su.is_hexidecimal_integer_number("0xH1234")) + self.assertFalse(su.is_hexidecimal_integer_number("0x1234H")) + self.assertFalse(su.is_hexidecimal_integer_number("nine")) + + def test_is_octal_integer_number(self): + self.assertTrue(su.is_octal_integer_number("0o111")) + self.assertTrue(su.is_octal_integer_number("0O111")) + self.assertTrue(su.is_octal_integer_number("-0o111")) + self.assertTrue(su.is_octal_integer_number("+0o777")) + self.assertFalse(su.is_octal_integer_number("-+0o111")) + self.assertFalse(su.is_octal_integer_number("0o181")) + self.assertFalse(su.is_octal_integer_number("0o1a1")) + self.assertFalse(su.is_octal_integer_number("one")) + + def test_is_binary_integer_number(self): + self.assertTrue(su.is_binary_integer_number("0b10100101110")) + self.assertTrue(su.is_binary_integer_number("+0b10100101110")) + self.assertTrue(su.is_binary_integer_number("-0b10100101110")) + self.assertTrue(su.is_binary_integer_number("0B10100101110")) + self.assertTrue(su.is_binary_integer_number("+0B10100101110")) + self.assertTrue(su.is_binary_integer_number("-0B10100101110")) + self.assertFalse(su.is_binary_integer_number("-0B10100101110 ")) + self.assertFalse(su.is_binary_integer_number(" -0B10100101110")) + self.assertFalse(su.is_binary_integer_number("-0B10100 101110")) + self.assertFalse(su.is_binary_integer_number("0b10100201110")) + self.assertFalse(su.is_binary_integer_number("0b10100101e110")) + self.assertFalse(su.is_binary_integer_number("fred")) + + def test_to_int(self): + self.assertEqual(su.to_int("1234"), 1234) + self.assertEqual(su.to_int("0x1234"), 4660) + self.assertEqual(su.to_int("0o777"), 511) + self.assertEqual(su.to_int("0b111"), 7) + + def test_is_decimal_number(self): + self.assertTrue(su.is_decimal_number('4.3')) + self.assertTrue(su.is_decimal_number('.3')) + self.assertTrue(su.is_decimal_number('0.3')) + self.assertFalse(su.is_decimal_number('3.')) + self.assertTrue(su.is_decimal_number('3.0')) + self.assertTrue(su.is_decimal_number('3.0492949249e20')) + self.assertFalse(su.is_decimal_number('3')) + self.assertFalse(su.is_decimal_number('0x11')) + + def test_strip_escape_sequences(self): + s = f' {fg("red")} this is a test {bg("white")} this is a test {reset()} ' + self.assertEqual( + su.strip_escape_sequences(s), + ' this is a test this is a test ', + ) + s = ' this is another test ' + self.assertEqual(su.strip_escape_sequences(s), s) + + def test_is_url(self): + self.assertTrue(su.is_url("http://host.domain/uri/uri#shard?param=value+s")) + self.assertTrue(su.is_url("http://127.0.0.1/uri/uri#shard?param=value+s")) + self.assertTrue( + su.is_url("http://user:pass@127.0.0.1:81/uri/uri#shard?param=value+s") + ) + self.assertTrue(su.is_url("ftp://127.0.0.1/uri/uri")) + + def test_is_email(self): + self.assertTrue(su.is_email('scott@gasch.org')) + self.assertTrue(su.is_email('scott.gasch@gmail.com')) + 
self.assertFalse(su.is_email('@yahoo.com')) + self.assertFalse(su.is_email('indubidibly')) + self.assertFalse(su.is_email('frank997!!@foobar.excellent')) + + def test_suffix_string_to_number(self): + self.assertEqual(1024, su.suffix_string_to_number('1Kb')) + self.assertEqual(1024 * 1024, su.suffix_string_to_number('1Mb')) + self.assertEqual(1024, su.suffix_string_to_number('1k')) + self.assertEqual(1024, su.suffix_string_to_number('1kb')) + self.assertEqual(None, su.suffix_string_to_number('1Jl')) + self.assertEqual(None, su.suffix_string_to_number('undeniable')) + + def test_number_to_suffix_string(self): + self.assertEqual('1.0Kb', su.number_to_suffix_string(1024)) + self.assertEqual('1.0Mb', su.number_to_suffix_string(1024 * 1024)) + self.assertEqual('123', su.number_to_suffix_string(123)) + + def test_is_credit_card(self): + self.assertTrue(su.is_credit_card('4242424242424242')) + self.assertTrue(su.is_credit_card('5555555555554444')) + self.assertTrue(su.is_credit_card('378282246310005')) + self.assertTrue(su.is_credit_card('6011111111111117')) + self.assertTrue(su.is_credit_card('4000000360000006')) + self.assertFalse(su.is_credit_card('8000000360110099')) + self.assertFalse(su.is_credit_card('')) + + def test_is_camel_case(self): + self.assertFalse(su.is_camel_case('thisisatest')) + self.assertTrue(su.is_camel_case('thisIsATest')) + self.assertFalse(su.is_camel_case('this_is_a_test')) + + def test_is_snake_case(self): + self.assertFalse(su.is_snake_case('thisisatest')) + self.assertFalse(su.is_snake_case('thisIsATest')) + self.assertTrue(su.is_snake_case('this_is_a_test')) + + def test_sprintf_context(self): + with su.SprintfStdout() as buf: + print("This is a test.") + print("This is another one.") + self.assertEqual('This is a test.\nThis is another one.\n', buf()) + + +if __name__ == '__main__': + bootstrap.initialize(unittest.main)() diff --git a/tests/typez/centcount_test.py b/tests/typez/centcount_test.py new file mode 100755 index 0000000..5ba60b1 --- /dev/null +++ b/tests/typez/centcount_test.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""centcount unittest.""" + +import unittest + +from pyutils import unittest_utils +from pyutils.typez.centcount import CentCount + + +class TestCentCount(unittest.TestCase): + def test_basic_utility(self): + amount = CentCount(1.45) + another = CentCount.parse("USD 1.45") + self.assertEqual(amount, another) + + def test_negation(self): + amount = CentCount(1.45) + amount = -amount + self.assertEqual(CentCount(-1.45), amount) + + def test_addition_and_subtraction(self): + amount = CentCount(1.00) + another = CentCount(2.00) + total = amount + another + self.assertEqual(CentCount(3.00), total) + delta = another - amount + self.assertEqual(CentCount(1.00), delta) + neg = amount - another + self.assertEqual(CentCount(-1.00), neg) + neg += another + self.assertEqual(CentCount(1.00), neg) + neg += 1.00 + self.assertEqual(CentCount(2.00), neg) + neg -= 1.00 + self.assertEqual(CentCount(1.00), neg) + x = 1000 - amount + self.assertEqual(CentCount(9.0), x) + + def test_multiplication(self): + amount = CentCount(3.00) + amount *= 3 + self.assertEqual(CentCount(9.00), amount) + with self.assertRaises(TypeError): + another = CentCount(0.33) + amount *= another + + def test_division(self): + amount = CentCount(10.00) + x = amount / 5.0 + self.assertEqual(CentCount(2.00), x) + with self.assertRaises(TypeError): + another = CentCount(1.33) + amount /= another + + def test_equality(self): + usa = CentCount(1.0, 
'USD') + can = CentCount(1.0, 'CAD') + self.assertNotEqual(usa, can) + eh = CentCount(1.0, 'CAD') + self.assertEqual(can, eh) + + def test_comparison(self): + one = CentCount(1.0) + two = CentCount(2.0) + three = CentCount(3.0) + neg_one = CentCount(-1) + self.assertLess(one, two) + self.assertLess(neg_one, one) + self.assertGreater(one, neg_one) + self.assertGreater(three, one) + looney = CentCount(1.0, 'CAD') + with self.assertRaises(TypeError): + print(looney < one) + + def test_strict_mode(self): + one = CentCount(1.0, strict_mode=True) + two = CentCount(2.0, strict_mode=True) + with self.assertRaises(TypeError): + x = one + 2.4 + self.assertEqual(CentCount(3.0), one + two) + with self.assertRaises(TypeError): + x = two - 1.9 + self.assertEqual(CentCount(1.0), two - one) + with self.assertRaises(TypeError): + print(one == 1.0) + self.assertTrue(CentCount(1.0) == one) + with self.assertRaises(TypeError): + print(one < 2.0) + self.assertTrue(one < two) + with self.assertRaises(TypeError): + print(two > 1.0) + self.assertTrue(two > one) + + def test_truncate_and_round(self): + ten = CentCount(10.0) + x = ten * 2 / 3 + x.truncate_fractional_cents() + self.assertEqual(CentCount(6.66), x) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/typez/money_test.py b/tests/typez/money_test.py new file mode 100755 index 0000000..ee1e392 --- /dev/null +++ b/tests/typez/money_test.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# © Copyright 2021-2022, Scott Gasch + +"""money unittest.""" + +import unittest + +from pyutils import unittest_utils +from pyutils.typez.money import Money + + +class TestMoney(unittest.TestCase): + def test_basic_utility(self): + amount = Money(1.45) + another = Money.parse("USD 1.45") + self.assertAlmostEqual(amount.amount, another.amount) + + def test_negation(self): + amount = Money(1.45) + amount = -amount + self.assertAlmostEqual(Money(-1.45).amount, amount.amount) + + def test_addition_and_subtraction(self): + amount = Money(1.00) + another = Money(2.00) + total = amount + another + self.assertEqual(Money(3.00), total) + delta = another - amount + self.assertEqual(Money(1.00), delta) + neg = amount - another + self.assertEqual(Money(-1.00), neg) + neg += another + self.assertEqual(Money(1.00), neg) + neg += 1.00 + self.assertEqual(Money(2.00), neg) + neg -= 1 + self.assertEqual(Money(1.00), neg) + x = 10 - amount + self.assertEqual(Money(9.0), x) + + def test_multiplication(self): + amount = Money(3.00) + amount *= 3 + self.assertEqual(Money(9.00), amount) + with self.assertRaises(TypeError): + another = Money(0.33) + amount *= another + + def test_division(self): + amount = Money(10.00) + x = amount / 5.0 + self.assertEqual(Money(2.00), x) + with self.assertRaises(TypeError): + another = Money(1.33) + amount /= another + + def test_equality(self): + usa = Money(1.0, 'USD') + can = Money(1.0, 'CAD') + self.assertNotEqual(usa, can) + eh = Money(1.0, 'CAD') + self.assertEqual(can, eh) + + def test_comparison(self): + one = Money(1.0) + two = Money(2.0) + three = Money(3.0) + neg_one = Money(-1) + self.assertLess(one, two) + self.assertLess(neg_one, one) + self.assertGreater(one, neg_one) + self.assertGreater(three, one) + looney = Money(1.0, 'CAD') + with self.assertRaises(TypeError): + print(looney < one) + + def test_strict_mode(self): + one = Money(1.0, strict_mode=True) + two = Money(2.0, strict_mode=True) + with self.assertRaises(TypeError): + x = one + 2.4 + self.assertEqual(Money(3.0), one + two) + with self.assertRaises(TypeError): + x = two - 
diff --git a/tests/typez/rate_test.py b/tests/typez/rate_test.py
new file mode 100755
index 0000000..800d360
--- /dev/null
+++ b/tests/typez/rate_test.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""rate unittest."""
+
+import unittest
+
+from pyutils import unittest_utils
+from pyutils.typez.money import Money
+from pyutils.typez.rate import Rate
+
+
+class TestRate(unittest.TestCase):
+    def test_basic_utility(self):
+        my_stock_returns = Rate(percent_change=-20.0)
+        my_portfolio = 1000.0
+        self.assertAlmostEqual(800.0, my_stock_returns.apply_to(my_portfolio))
+
+        my_bond_returns = Rate(percentage=104.5)
+        my_money = Money(500.0)
+        self.assertAlmostEqual(Money(522.5), my_bond_returns.apply_to(my_money))
+
+        my_multiplier = Rate(multiplier=1.72)
+        my_nose_length = 3.2
+        self.assertAlmostEqual(5.504, my_multiplier.apply_to(my_nose_length))
+
+    def test_conversions(self):
+        x = Rate(104.55)
+        s = x.__repr__()
+        y = Rate(s)
+        self.assertAlmostEqual(x, y)
+        f = float(x)
+        z = Rate(f)
+        self.assertAlmostEqual(x, z)
+
+    def test_divide(self):
+        x = Rate(20.0)
+        x /= 2
+        self.assertAlmostEqual(10.0, x)
+        x = Rate(-20.0)
+        x /= 2
+        self.assertAlmostEqual(-10.0, x)
+
+    def test_add(self):
+        x = Rate(5.0)
+        y = Rate(10.0)
+        z = x + y
+        self.assertAlmostEqual(15.0, z)
+        x = Rate(-5.0)
+        x += y
+        self.assertAlmostEqual(5.0, x)
+
+    def test_sub(self):
+        x = Rate(5.0)
+        y = Rate(10.0)
+        z = x - y
+        self.assertAlmostEqual(-5.0, z)
+        z = y - x
+        self.assertAlmostEqual(5.0, z)
+
+    def test_repr(self):
+        x = Rate(percent_change=-50.0)
+        s = x.__repr__(relative=True)
+        self.assertEqual("-50.000%", s)
+        s = x.__repr__()
+        self.assertEqual("+50.000%", s)
+
+
+if __name__ == '__main__':
+    unittest.main()
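Aside: test_basic_utility above implies that Rate normalizes three equivalent spellings to a single multiplier (percent_change=-20 applied to 1000 gives 800, percentage=104.5 applied to 500 gives 522.5). A minimal sketch, not part of the patch, derived only from that arithmetic:

    # Sketch (not part of the patch): derived from test_basic_utility's math.
    from pyutils.typez.rate import Rate

    a = Rate(multiplier=1.5)       # 1.5x the original value
    b = Rate(percentage=150.0)     # 150% of the original value
    c = Rate(percent_change=50.0)  # a +50% change
    assert float(a) == float(b) == float(c) == 1.5
    assert a.apply_to(100.0) == 150.0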
diff --git a/tests/zookeeper_test.py b/tests/zookeeper_test.py
new file mode 100755
index 0000000..332a715
--- /dev/null
+++ b/tests/zookeeper_test.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+
+# © Copyright 2021-2022, Scott Gasch
+
+"""zookeeper unittest."""
+
+import datetime
+import logging
+import time
+import unittest
+
+from pyutils import unittest_utils, zookeeper
+
+logger = logging.getLogger(__name__)
+
+
+class TestZookeeper(unittest.TestCase):
+    @zookeeper.obtain_lease(
+        also_pass_lease=True, duration=datetime.timedelta(minutes=1)
+    )
+    def test_release_lease(self, lease: zookeeper.RenewableReleasableLease):
+        self.assertTrue(lease)
+        self.assertTrue(lease.release())
+        self.assertFalse(lease)
+        self.assertFalse(lease.release())
+        self.assertFalse(lease)
+
+    @zookeeper.obtain_lease(
+        also_pass_lease=True, duration=datetime.timedelta(minutes=1)
+    )
+    def test_renew_lease(self, lease: zookeeper.RenewableReleasableLease):
+        self.assertTrue(lease)
+        self.assertTrue(lease.try_renew(datetime.timedelta(minutes=2)))
+        self.assertTrue(lease)
+        self.assertTrue(lease.release())
+
+    @zookeeper.obtain_lease(
+        also_pass_lease=True,
+        duration=datetime.timedelta(minutes=1),
+    )
+    def test_cant_renew_lease_after_released(
+        self, lease: zookeeper.RenewableReleasableLease
+    ):
+        self.assertTrue(lease)
+        self.assertTrue(lease.release())
+        self.assertFalse(lease)
+        self.assertFalse(lease.try_renew(datetime.timedelta(minutes=2)))
+
+    @zookeeper.obtain_lease(
+        also_pass_lease=True, duration=datetime.timedelta(seconds=5)
+    )
+    def test_lease_expiration(self, lease: zookeeper.RenewableReleasableLease):
+        self.assertTrue(lease)
+        time.sleep(7)
+        self.assertFalse(lease)
+
+    def test_leases_are_exclusive(self):
+        @zookeeper.obtain_lease(
+            contender_id='second',
+            duration=datetime.timedelta(seconds=10),
+        )
+        def i_will_fail_to_get_the_lease():
+            logger.debug("I seem to have gotten the lease, wtf?!?!")
+            self.fail("I should not have gotten the lease?!")
+
+        @zookeeper.obtain_lease(
+            contender_id='first',
+            duration=datetime.timedelta(seconds=10),
+        )
+        def i_will_hold_the_lease():
+            logger.debug("I have the lease.")
+            time.sleep(1)
+            self.assertFalse(i_will_fail_to_get_the_lease())
+
+        i_will_hold_the_lease()
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
2.45.2
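Aside: the lease decorator that zookeeper_test.py exercises is the most opaque API in the patch, so here is a minimal sketch, not part of the patch. It assumes a reachable ZooKeeper ensemble configured the way pyutils.zookeeper expects, and it combines keyword arguments (contender_id, also_pass_lease) that the tests above only pass separately; the contender name is hypothetical.

    # Sketch (not part of the patch): the lease pattern zookeeper_test.py uses.
    import datetime

    from pyutils import zookeeper

    @zookeeper.obtain_lease(
        contender_id='example_worker',  # hypothetical contender name
        duration=datetime.timedelta(minutes=1),
        also_pass_lease=True,
    )
    def do_exclusive_work(lease: zookeeper.RenewableReleasableLease) -> bool:
        # The decorated body runs only if the lease was obtained; the lease
        # object is truthy while held and can be renewed or released early.
        if not lease.try_renew(datetime.timedelta(minutes=2)):
            return False
        # ... work that must not run concurrently elsewhere ...
        return lease.release()

    do_exclusive_work()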