Reduce the doctest lease duration...
[python_utils.git] / ml / quick_label.py
index da9a1d2e00d15e32cce697e3782a5cfce87ea584..feab67b2e88484132ac80506f5add82b395e35bf 100644 (file)
@@ -1,16 +1,24 @@
 #!/usr/bin/env python3
 
-"""A helper to facilitate quick manual labeling of ML training data."""
+# © Copyright 2021-2022, Scott Gasch
+
+"""A helper to facilitate quick manual labeling of ML training data.
+
+To use, write a subclass that implements the QuickLabelHelper
+interface and pass an instance of it to the quick_label function.
+
+"""
 
-import glob
 import logging
 import os
-import warnings
-from dataclasses import dataclass
-from typing import Callable, List, Optional, Set
+import sys
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Set, Tuple
 
+import ansi
 import argparse_utils
 import config
+import input_utils
 
 logger = logging.getLogger(__name__)
 parser = config.add_commandline_args(
@@ -34,25 +42,97 @@ parser.add_argument(
     "--ml_quick_label_overwrite_labels",
     default=False,
     action=argparse_utils.ActionNoYes,
-    help='Enable overwriting existing labels; default is to not relabel.',
+    help='Enable overwriting of existing labels; default is to not relabel.',
+)
+parser.add_argument(
+    '--ml_quick_label_skip_where_model_agrees',
+    default=False,
+    action=argparse_utils.ActionNoYes,
+    help='Do not filter examples where the model disagrees with the current label.',
 )
+parser.add_argument(
+    '--ml_quick_label_delete_invalid_examples',
+    default=False,
+    action='store_true',
+    help='If set we will delete invalid training examples.',
+)
+
+
+class QuickLabelHelper:
+    '''To use this quick labeler your code must create a subclass of this
+    class and implement the abstract methods below.  See comments for
+    detailed semantics.'''
+
+    @abstractmethod
+    def get_candidate_files(self) -> List[str]:
+        '''This must return a list of raw candidate files for labeling.'''
+        pass
+
+    @abstractmethod
+    def get_features_for_file(self, filename: str) -> Optional[str]:
+        '''Given a raw file, return its features file.'''
+        pass
+
+    @abstractmethod
+    def render_example(self, filename: str, features: str, lines: List[str]) -> None:
+        '''Render a raw file with its features for the user to judge.'''
+        pass
 
+    @abstractmethod
+    def unrender_example(self, filename: str, features: str) -> None:
+        '''Unrender a raw file with its features (if necessary)...'''
+        pass
 
-@dataclass
-class InputSpec:
-    """A wrapper around the input data we need to operate; should be
-    populated by the caller."""
+    @abstractmethod
+    def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
+        '''Returns True iff the example is valid (all features are valid, there
+        are the correct number of features, etc...).'''
+        pass
 
-    image_file_glob: Optional[str] = None
-    image_file_prepopulated_list: Optional[List[str]] = None
-    image_file_to_features_file: Optional[Callable[[str], str]] = None
-    label: str = ''
-    valid_keystrokes: List[str] = []
-    prompt: str = ''
-    keystroke_to_label: Optional[Callable[[str], str]] = None
+    @abstractmethod
+    def ask_current_model_about_example(
+        self,
+        filename: str,
+        features: str,
+        lines: List[str],
+    ) -> Any:
+        '''Ask the current ML model about this example, if necessary/possible.
+        Returns None to indicate no model to consult.'''
+        pass
 
+    @abstractmethod
+    def get_labelling_keystrokes(self) -> Dict[str, Any]:
+        '''What keystrokes should be considered valid label actions and what
+        label does each keystroke map into.  e.g. if you want to ask
+        the user to hit 'y' for 'yes' and code that as 255 in your
+        features or to hit 'n' for 'no' and code that as 0 in your
+        features, return:
+
+            { 'y': 255, 'n': 0 }
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_everything_label(self) -> Any:
+        '''If this returns something other than None it indicates that every
+        example selected should be labeled with this result.  Caveat
+        emptor, we will clobber all your files.
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_label_feature(self) -> str:
+        '''What feature denotes the example's label?  This is used to detect
+        when examples already have a label and to assign labels to
+        examples.'''
+        pass
+
+
+def _maybe_read_skip_list() -> Set[str]:
+    '''Reads the skip list (files to just bypass) into memory if using.'''
 
-def read_skip_list() -> Set[str]:
     ret: Set[str] = set()
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
@@ -67,7 +147,9 @@ def read_skip_list() -> Set[str]:
     return ret
 
 
-def write_skip_list(skip_list) -> None:
+def _maybe_write_skip_list(skip_list) -> None:
+    '''Writes the skip list (files to just bypass) to disk if using.'''
+
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
         with open(quick_skip_file, 'w') as f:
@@ -78,54 +160,244 @@ def write_skip_list(skip_list) -> None:
         logger.debug('Updated %s', quick_skip_file)
 
 
-def label(in_spec: InputSpec) -> None:
-    import input_utils
-
-    images = []
-    if in_spec.image_file_glob is not None:
-        images += glob.glob(in_spec.image_file_glob)
-    elif in_spec.image_file_prepopulated_list is not None:
-        images += in_spec.image_file_prepopulated_list
-    else:
-        raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')
+def _filter_images(
+    images: List[str], skip_list: Set[str], helper: QuickLabelHelper
+) -> List[Tuple[str, str]]:
+    '''Discard examples that have particular characteristics.  e.g.
+    those that are already labeled and whose current label agrees with
+    the ML model, etc...'''
 
-    skip_list = read_skip_list()
+    filtered_images = []
+    label_label = helper.get_label_feature()
     for image in images:
         if image in skip_list:
-            logger.debug('Skipping %s because of the skip list', image)
+            logger.debug('Skipping %s because of the skip list.', image)
             continue
-        assert in_spec.image_file_to_features_file
-        features = in_spec.image_file_to_features_file(image)
+
+        features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
-            logger.warning(msg)
-            warnings.warn(msg)
+            logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
+            continue
+
+        # Read it.
+        label = None
+        filtered_lines = []
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
+        for line in lines:
+            line = line[:-1]
+            if line.startswith(label_label):
+                label = ''.join(line.split(':')[1:])
+                label = label.strip()
+            else:
+                filtered_lines.append(line)
+
+        if not helper.is_valid_example(image, features, filtered_lines):
+            logger.warning('%s/%s: Invalid example.', image, features)
+            if config.config['ml_quick_label_delete_invalid_examples']:
+                os.remove(image)
+                os.remove(features)
+            continue
+
+        if label and not config.config['ml_quick_label_overwrite_labels']:
+            logger.warning('%s/%s: already has label, skipping.', image, features)
             continue
 
-        # Render features and image.
+        if config.config['ml_quick_label_skip_where_model_agrees']:
+            model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
+            if model_says and label:
+                if model_says[0] == int(label):
+                    logger.warning(
+                        '%s/%s: Model agrees with current label (%s), skipping.',
+                        image,
+                        features,
+                        label,
+                    )
+                    continue
+
+                print(f'{image}/{features}: The model disagrees with the current label.')
+                print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
+                print(f'    ...the example is currently labeled {label}')
+
+        filtered_images.append((image, features))
+    return filtered_images
+
+
+def _make_prompt(
+    helper: QuickLabelHelper,
+    cursor: int,
+    num_filtered_images: int,
+    current_image: str,
+    current_features: str,
+    labeled_features: Dict[Tuple[str, str], str],  # Examples already labeled
+) -> None:
+    '''Tell an interactive user where they are in the set of examples that
+    may be labeled and the details of the current example.'''
+
+    label_label = helper.get_label_feature()  # the key: of a label in features
+    filtered_lines = []
+    label = labeled_features.get((current_image, current_features), None)
+    with open(current_features, 'r') as rf:
+        lines = rf.readlines()
+    for line in lines:
+        line = line[:-1]
+        if len(line) == 0:
+            continue
+        if not line.startswith(label_label):
+            filtered_lines.append(line)
+        else:
+            assert not label
+            label = line
+
+    # Prompt...
+    helper.render_example(current_image, current_features, filtered_lines)
+    print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='')
+    print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:')
+    print(f'    ...{len(labeled_features)} currently unwritten label(s) ("W" to write).')
+    if label:
+        if (current_image, current_features) in labeled_features:
+            print(f'    ...This example is labeled but not yet saved: {label}')
+        else:
+            print(f'    ...This example is already labeled on disk: {label}')
+    else:
+        print('    ...This example is currently unlabeled')
+    guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines)
+    if guess:
+        print(f'    ...The ML Model says {guess}')
+    print()
+
+
+def _write_labeled_features(
+    helper: QuickLabelHelper,
+    labeled_features: Dict[Tuple[str, str], str],
+    skip_list: Set[str],
+) -> None:
+    label_label = helper.get_label_feature()
+    for image_features, label in labeled_features.items():
+        image = image_features[0]
+        features = image_features[1]
         filtered_lines = []
-        with open(features, "r") as f:
-            lines = f.readlines()
-        saw_label = False
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
         for line in lines:
             line = line[:-1]
-            if in_spec.label not in line:
+            line = line.strip()
+            if line == '':
+                continue
+            if not line.startswith(label_label):
                 filtered_lines.append(line)
+
+        filtered_lines.append(f'{label_label}: {label}')
+        with open(features, 'w') as f:
+            f.writelines(line + '\n' for line in filtered_lines)
+            if config.config['ml_quick_label_use_skip_lists']:
+                skip_list.add(image)
+    print(f'Wrote {len(labeled_features)} labels.')
+
+
+def quick_label(helper: QuickLabelHelper) -> None:
+    '''Pass an instance of your QuickLabelHelper subclass to this function
+    and it will allow users to label examples and persist them to disk.
+
+    '''
+    skip_list = _maybe_read_skip_list()
+
+    # Ask helper for an initial set of files.
+    images = helper.get_candidate_files()
+    if len(images) == 0:
+        logger.warning('No images files to operate on.')
+        return
+    logger.info('There are %d starting candidate images.', len(images))
+
+    # Filter out any that can't be converted to features or already have a
+    # label (unless they used --ml_quick_label_overwrite_labels).
+    filtered_images = _filter_images(images, skip_list, helper)
+    if len(filtered_images) == 0:
+        logger.warning('No image files to operate on (post filter).')
+        return
+    logger.info('There are %d candidate images post filtering.', len(filtered_images))
+
+    # Allow the user to label the non-filtered images one by one.
+    labeled_features: Dict[Tuple[str, str], str] = {}
+    cursor = 0
+    while True:
+        assert 0 <= cursor < len(filtered_images)
+
+        image = filtered_images[cursor][0]
+        assert os.path.exists(image)
+        features = filtered_images[cursor][1]
+        assert features and os.path.exists(features)
+
+        # Render the features, image and prompt.
+        _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features)
+        try:
+            # Did they want everything labelled the same?
+            label_everything = helper.get_everything_label()
+            if label_everything:
+                labeled_features[(image, features)] = label_everything
+                filtered_images.remove((image, features))
+                if len(filtered_images) == 0:
+                    print('Nothing more to label.')
+                    break
+                if cursor >= len(filtered_images):
+                    cursor -= 1
+
+            # Otherwise ask about each individual example.
             else:
-                saw_label = True
-
-        if not saw_label or config.config['ml_quick_label_overwrite_labels']:
-            logger.info(features)
-            os.system(f'xv {image} &')
-            keystroke = input_utils.single_keystroke_response(
-                in_spec.valid_keystrokes,
-                prompt=in_spec.prompt,
-            )
-            os.system('killall xv')
-            assert in_spec.keystroke_to_label
-            label_value = in_spec.keystroke_to_label(keystroke)
-            filtered_lines.append(f"{in_spec.label}: {label_value}\n")
-            with open(features, 'w') as f:
-                f.writelines(line + '\n' for line in filtered_lines)
-            skip_list.add(image)
-    write_skip_list(skip_list)
+                labelling_keystrokes = helper.get_labelling_keystrokes()
+                valid_keystrokes = ['<', '>', 'W', 'Q', '?']
+                valid_keystrokes += labelling_keystrokes.keys()
+                prompt = ','.join(valid_keystrokes)
+                print(f'What should I do?  [{prompt}]: ', end='')
+                sys.stdout.flush()
+                keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+                print()
+
+                if keystroke == '?':
+                    print(
+                        '''
+        >   =   Don't label, move to the next example.
+        <   =   Don't label, move to the previous example.
+        W   =   Write pending labels to disk now.
+        Q   =   Quit labeling now.
+        ?   =   This message.
+      else  =   These keystrokes assign a label to the example and persist it.'''
+                    )
+                    input_utils.press_any_key()
+                elif keystroke == 'Q':
+                    logger.info('Ok, stopping for now.')
+                    if len(labeled_features):
+                        logger.info('Discarding %d unsaved labels.', len(labeled_features))
+                    break
+                elif keystroke == '>':
+                    cursor += 1
+                    if cursor >= len(filtered_images):
+                        print('Wrapping around...')
+                        cursor = 0
+                elif keystroke == '<':
+                    cursor -= 1
+                    if cursor < 0:
+                        print('Wrapping around...')
+                        cursor = len(filtered_images) - 1
+                elif keystroke == 'W':
+                    _write_labeled_features(helper, labeled_features, skip_list)
+                    labeled_features = {}
+                elif keystroke in labelling_keystrokes:
+                    label_value = labelling_keystrokes[keystroke]
+                    labeled_features[(image, features)] = label_value
+                    filtered_images.remove((image, features))
+                    if len(filtered_images) == 0:
+                        print('Nothing more to label.')
+                        break
+                    if cursor >= len(filtered_images):
+                        cursor -= 1
+                else:
+                    print(f'Unknown keystroke: {keystroke}.  Try again.')
+        finally:
+            helper.unrender_example(image, features)
+
+    if len(labeled_features):
+        yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
+        if yn in ('Y', 'y'):
+            _write_labeled_features(helper, labeled_features, skip_list)
+    _maybe_write_skip_list(skip_list)