Reduce the doctest lease duration...
[python_utils.git] / ml / quick_label.py
index 05efbaebcee5217b516a3c3fc37421b1e6a982fd..feab67b2e88484132ac80506f5add82b395e35bf 100644 (file)
@@ -2,17 +2,23 @@
 
 # © Copyright 2021-2022, Scott Gasch
 
-"""A helper to facilitate quick manual labeling of ML training data."""
+"""A helper to facilitate quick manual labeling of ML training data.
+
+To use, implement a subclass that implements the QuickLabelHelper
+interface and pass it into the quick_label function.
+
+"""
 
 import logging
 import os
 import sys
-import time
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple
 
+import ansi
 import argparse_utils
 import config
+import input_utils
 
 logger = logging.getLogger(__name__)
 parser = config.add_commandline_args(
@@ -36,7 +42,7 @@ parser.add_argument(
     "--ml_quick_label_overwrite_labels",
     default=False,
     action=argparse_utils.ActionNoYes,
-    help='Enable overwriting existing labels; default is to not relabel.',
+    help='Enable overwriting of existing labels; default is to not relabel.',
 )
 parser.add_argument(
     '--ml_quick_label_skip_where_model_agrees',
@@ -45,7 +51,7 @@ parser.add_argument(
     help='Do not filter examples where the model disagrees with the current label.',
 )
 parser.add_argument(
-    'ml_quick_label_delete_invalid_examples',
+    '--ml_quick_label_delete_invalid_examples',
     default=False,
     action='store_true',
     help='If set we will delete invalid training examples.',
@@ -69,11 +75,11 @@ class QuickLabelHelper:
 
     @abstractmethod
     def render_example(self, filename: str, features: str, lines: List[str]) -> None:
-        '''Render a raw file with its features for the user.'''
+        '''Render a raw file with its features for the user to judge.'''
         pass
 
     @abstractmethod
-    def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
+    def unrender_example(self, filename: str, features: str) -> None:
         '''Unrender a raw file with its features (if necessary)...'''
         pass
 
@@ -90,7 +96,8 @@ class QuickLabelHelper:
         features: str,
         lines: List[str],
     ) -> Any:
-        '''Ask the current ML model about this example, if necessary.'''
+        '''Ask the current ML model about this example, if necessary/possible.
+        Returns None to indicate no model to consult.'''
         pass
 
     @abstractmethod
@@ -156,18 +163,23 @@ def _maybe_write_skip_list(skip_list) -> None:
 def _filter_images(
     images: List[str], skip_list: Set[str], helper: QuickLabelHelper
 ) -> List[Tuple[str, str]]:
+    '''Discard examples that have particular characteristics.  e.g.
+    those that are already labeled and whose current label agrees with
+    the ML model, etc...'''
+
     filtered_images = []
     label_label = helper.get_label_feature()
     for image in images:
         if image in skip_list:
-            logger.debug('Skipping %s because of the skip list', image)
+            logger.debug('Skipping %s because of the skip list.', image)
             continue
 
         features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
+            logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
             continue
 
+        # Read it.
         label = None
         filtered_lines = []
         with open(features, 'r') as rf:
@@ -188,22 +200,106 @@ def _filter_images(
             continue
 
         if label and not config.config['ml_quick_label_overwrite_labels']:
-            logger.warning('%s/%s: already has label, SKIPPING.', image, features)
+            logger.warning('%s/%s: already has label, skipping.', image, features)
             continue
 
         if config.config['ml_quick_label_skip_where_model_agrees']:
             model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
             if model_says and label:
                 if model_says[0] == int(label):
+                    logger.warning(
+                        '%s/%s: Model agrees with current label (%s), skipping.',
+                        image,
+                        features,
+                        label,
+                    )
                     continue
+
                 print(f'{image}/{features}: The model disagrees with the current label.')
                 print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
                 print(f'    ...the example is currently labeled {label}')
+
         filtered_images.append((image, features))
     return filtered_images
 
 
+def _make_prompt(
+    helper: QuickLabelHelper,
+    cursor: int,
+    num_filtered_images: int,
+    current_image: str,
+    current_features: str,
+    labeled_features: Dict[Tuple[str, str], str],  # Examples already labeled
+) -> None:
+    '''Tell an interactive user where they are in the set of examples that
+    may be labeled and the details of the current example.'''
+
+    label_label = helper.get_label_feature()  # the key: of a label in features
+    filtered_lines = []
+    label = labeled_features.get((current_image, current_features), None)
+    with open(current_features, 'r') as rf:
+        lines = rf.readlines()
+    for line in lines:
+        line = line[:-1]
+        if len(line) == 0:
+            continue
+        if not line.startswith(label_label):
+            filtered_lines.append(line)
+        else:
+            assert not label
+            label = line
+
+    # Prompt...
+    helper.render_example(current_image, current_features, filtered_lines)
+    print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='')
+    print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:')
+    print(f'    ...{len(labeled_features)} currently unwritten label(s) ("W" to write).')
+    if label:
+        if (current_image, current_features) in labeled_features:
+            print(f'    ...This example is labeled but not yet saved: {label}')
+        else:
+            print(f'    ...This example is already labeled on disk: {label}')
+    else:
+        print('    ...This example is currently unlabeled')
+    guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines)
+    if guess:
+        print(f'    ...The ML Model says {guess}')
+    print()
+
+
+def _write_labeled_features(
+    helper: QuickLabelHelper,
+    labeled_features: Dict[Tuple[str, str], str],
+    skip_list: Set[str],
+) -> None:
+    label_label = helper.get_label_feature()
+    for image_features, label in labeled_features.items():
+        image = image_features[0]
+        features = image_features[1]
+        filtered_lines = []
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
+        for line in lines:
+            line = line[:-1]
+            line = line.strip()
+            if line == '':
+                continue
+            if not line.startswith(label_label):
+                filtered_lines.append(line)
+
+        filtered_lines.append(f'{label_label}: {label}')
+        with open(features, 'w') as f:
+            f.writelines(line + '\n' for line in filtered_lines)
+            if config.config['ml_quick_label_use_skip_lists']:
+                skip_list.add(image)
+    print(f'Wrote {len(labeled_features)} labels.')
+
+
 def quick_label(helper: QuickLabelHelper) -> None:
+    '''Pass your QuickLabelHelper implementing class to this function and
+    it will allow users to label examples and persist them to disk.
+
+    '''
     skip_list = _maybe_read_skip_list()
 
     # Ask helper for an initial set of files.
@@ -211,6 +307,7 @@ def quick_label(helper: QuickLabelHelper) -> None:
     if len(images) == 0:
         logger.warning('No images files to operate on.')
         return
+    logger.info('There are %d starting candidate images.', len(images))
 
     # Filter out any that can't be converted to features or already have a
     # label (unless they used --ml_qukck_label_overwrite_labels).
@@ -218,12 +315,11 @@ def quick_label(helper: QuickLabelHelper) -> None:
     if len(filtered_images) == 0:
         logger.warning('No image files to operate on (post filter).')
         return
+    logger.info('There are %d candidate images post filtering.', len(filtered_images))
 
     # Allow the user to label the non-filtered images one by one.
-    import input_utils
-
+    labeled_features: Dict[Tuple[str, str], str] = {}
     cursor = 0
-    label_label = helper.get_label_feature()
     while True:
         assert 0 <= cursor < len(filtered_images)
 
@@ -232,93 +328,76 @@ def quick_label(helper: QuickLabelHelper) -> None:
         features = filtered_images[cursor][1]
         assert features and os.path.exists(features)
 
-        filtered_lines = []
-        label = None
-        with open(features, 'r') as rf:
-            lines = rf.readlines()
-        for line in lines:
-            line = line[:-1]
-            if not line.startswith(label_label):
-                filtered_lines.append(line)
-            else:
-                label = line
-
-        # Render...
-        helper.render_example(image, features, filtered_lines)
-
-        # Prompt...
-        print(
-            f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
-        )
-        if label:
-            print(f'    ...Already labelled: {label}')
-        else:
-            print('    ...Currently unlabeled')
-        guess = helper.ask_current_model_about_example(image, features, filtered_lines)
-        if guess:
-            print(f'    ...Model says {guess}')
-        print()
-
-        # Did they want everything labelled the same?
-        label_everything = helper.get_everything_label()
-        if label_everything:
-            filtered_lines.append(f"{label_label}: {label_everything}\n")
-            with open(features, 'w') as f:
-                f.writelines(line + '\n' for line in filtered_lines)
-            if config.config['ml_quick_label_use_skip_lists']:
-                skip_list.add(image)
-            cursor += 1
-            if cursor >= len(filtered_images):
-                helper.unrender_example(image, features, filtered_lines)
-                break
-
-        # Otherwise ask about each example.
-        else:
-            labelling_keystrokes = helper.get_labelling_keystrokes()
-            valid_keystrokes = ['<', '>', 'Q', '?']
-            valid_keystrokes += labelling_keystrokes.keys()
-            prompt = ','.join(valid_keystrokes)
-            print(f'What should I do ({prompt})? ', end='')
-            sys.stdout.flush()
-            keystroke = input_utils.single_keystroke_response(valid_keystrokes)
-            print()
-            if keystroke == 'Q':
-                logger.info('Ok, stopping for now.  Labeled examples are written to disk')
-                helper.unrender_example(image, features, filtered_lines)
-                break
-            elif keystroke == '?':
-                print(
-                    '''
-    >   =   Don't label, move to the next example.
-    <   =   Don't label, move to the previous example.
-    Q   =   Quit labeling now.
-    ?   =   This message.
-  else  =   These keystrokes assign a label to the example and persist it.'''
-                )
-                time.sleep(3.0)
-
-            elif keystroke == '>':
-                cursor += 1
+        # Render the features, image and prompt.
+        _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features)
+        try:
+            # Did they want everything labelled the same?
+            label_everything = helper.get_everything_label()
+            if label_everything:
+                labeled_features[(image, features)] = label_everything
+                filtered_images.remove((image, features))
+                if len(filtered_images) == 0:
+                    print('Nothing more to label.')
+                    break
                 if cursor >= len(filtered_images):
-                    print('Wrapping around...')
-                    cursor = 0
-            elif keystroke == '<':
-                cursor -= 1
-                if cursor < 0:
-                    print('Wrapping around...')
-                    cursor = len(filtered_images) - 1
-            elif keystroke in labelling_keystrokes:
-                label_value = labelling_keystrokes[keystroke]
-                filtered_lines.append(f"{label_label}: {label_value}\n")
-                with open(features, 'w') as f:
-                    f.writelines(line + '\n' for line in filtered_lines)
-                if config.config['ml_quick_label_use_skip_lists']:
-                    skip_list.add(image)
-                cursor += 1
-                if cursor >= len(filtered_images):
-                    print('Wrapping around...')
-                    cursor = 0
+                    cursor -= 1
+
+            # Otherwise ask about each individual example.
             else:
-                print(f'Unknown keystroke: {keystroke}')
-        helper.unrender_example(image, features, filtered_lines)
+                labelling_keystrokes = helper.get_labelling_keystrokes()
+                valid_keystrokes = ['<', '>', 'W', 'Q', '?']
+                valid_keystrokes += labelling_keystrokes.keys()
+                prompt = ','.join(valid_keystrokes)
+                print(f'What should I do?  [{prompt}]: ', end='')
+                sys.stdout.flush()
+                keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+                print()
+
+                if keystroke == '?':
+                    print(
+                        '''
+        >   =   Don't label, move to the next example.
+        <   =   Don't label, move to the previous example.
+        W   =   Write pending labels to disk now.
+        Q   =   Quit labeling now.
+        ?   =   This message.
+      else  =   These keystrokes assign a label to the example and persist it.'''
+                    )
+                    input_utils.press_any_key()
+                elif keystroke == 'Q':
+                    logger.info('Ok, stopping for now.')
+                    if len(labeled_features):
+                        logger.info('Discarding %d unsaved labels.', len(labeled_features))
+                    break
+                elif keystroke == '>':
+                    cursor += 1
+                    if cursor >= len(filtered_images):
+                        print('Wrapping around...')
+                        cursor = 0
+                elif keystroke == '<':
+                    cursor -= 1
+                    if cursor < 0:
+                        print('Wrapping around...')
+                        cursor = len(filtered_images) - 1
+                elif keystroke == 'W':
+                    _write_labeled_features(helper, labeled_features, skip_list)
+                    labeled_features = {}
+                elif keystroke in labelling_keystrokes:
+                    label_value = labelling_keystrokes[keystroke]
+                    labeled_features[(image, features)] = label_value
+                    filtered_images.remove((image, features))
+                    if len(filtered_images) == 0:
+                        print('Nothing more to label.')
+                        break
+                    if cursor >= len(filtered_images):
+                        cursor -= 1
+                else:
+                    print(f'Unknown keystroke: {keystroke}.  Try again.')
+        finally:
+            helper.unrender_example(image, features)
+
+    if len(labeled_features):
+        yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
+        if yn in ('Y', 'y'):
+            _write_labeled_features(helper, labeled_features, skip_list)
     _maybe_write_skip_list(skip_list)