Preformatted box that doesn't wrap the contents.
[python_utils.git] / ml / quick_label.py
index 7e0a6bf64921533e00d719223d2657fb21ebbccf..05efbaebcee5217b516a3c3fc37421b1e6a982fd 100644 (file)
@@ -1,10 +1,15 @@
 #!/usr/bin/env python3
 
-import glob
+# © Copyright 2021-2022, Scott Gasch
+
+"""A helper to facilitate quick manual labeling of ML training data."""
+
 import logging
 import os
-import warnings
-from typing import Callable, List, NamedTuple, Optional, Set
+import sys
+import time
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import argparse_utils
 import config
@@ -33,19 +38,94 @@ parser.add_argument(
     action=argparse_utils.ActionNoYes,
     help='Enable overwriting existing labels; default is to not relabel.',
 )
+parser.add_argument(
+    '--ml_quick_label_skip_where_model_agrees',  # read by _filter_images via config.config
+    default=False,
+    action=argparse_utils.ActionNoYes,
+    help='Skip labeled examples where the current model agrees with the existing label.',
+)
+parser.add_argument(
+    '--ml_quick_label_delete_invalid_examples',  # needs the leading '--' to be an option flag
+    default=False,
+    action='store_true',
+    help='If set we will delete invalid training examples.',
+)
+
+
+class QuickLabelHelper:
+    '''To use this quick labeler your code must create a subclass of this
+    class and implement the methods below; see each docstring for semantics.
+    NOTE(review): no abc.ABC base here, so @abstractmethod is not enforced.'''
+
+    @abstractmethod
+    def get_candidate_files(self) -> List[str]:
+        '''This must return a list of raw candidate files for labeling.'''
+        pass
+
+    @abstractmethod
+    def get_features_for_file(self, filename: str) -> Optional[str]:
+        '''Given a raw file, return the path of its features file (or None if it has none).'''
+        pass
 
+    @abstractmethod
+    def render_example(self, filename: str, features: str, lines: List[str]) -> None:
+        '''Render a raw file with its features (lines = the non-label feature lines) for the user.'''
+        pass
 
-class InputSpec(NamedTuple):
-    image_file_glob: Optional[str]
-    image_file_prepopulated_list: Optional[List[str]]
-    image_file_to_features_file: Callable[[str], str]
-    label: str
-    valid_keystrokes: List[str]
-    prompt: str
-    keystroke_to_label: Callable[[str], str]
+    @abstractmethod
+    def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
+        '''Undo whatever render_example did for this raw file and its features (if necessary).'''
+        pass
 
+    @abstractmethod
+    def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
+        '''Returns true iff the example is valid (all features are valid, there
+        are the correct number of features, etc...).'''
+        pass
+
+    @abstractmethod
+    def ask_current_model_about_example(
+        self,
+        filename: str,
+        features: str,
+        lines: List[str],
+    ) -> Any:
+        '''Ask the current ML model about this example, if necessary; a falsy result means no answer.'''
+        pass  # NOTE(review): _filter_images indexes a truthy result as [0]=label, [1]=probability
+
+    @abstractmethod
+    def get_labelling_keystrokes(self) -> Dict[str, Any]:
+        '''What keystrokes should be considered valid label actions and what
+        label does each keystroke map into.  e.g. if you want to ask
+        the user to hit 'y' for 'yes' and code that as 255 in your
+        features or to hit 'n' for 'no' and code that as 0 in your
+        features, return:
+
+            { 'y': 255, 'n': 0 }
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_everything_label(self) -> Any:
+        '''If this returns something other than None it indicates that every
+        example selected should be labeled with this result.  Caveat
+        emptor, we will clobber all your files.
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_label_feature(self) -> str:
+        '''What feature denotes the example's label?  This is used to detect
+        when examples already have a label (lines starting with this string)
+        and to assign labels to examples.'''
+        pass
+
+
+def _maybe_read_skip_list() -> Set[str]:
+    '''Reads the skip list (files to just bypass) into memory if using.'''
 
-def read_skip_list() -> Set[str]:
     ret: Set[str] = set()
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
@@ -56,11 +136,13 @@ def read_skip_list() -> Set[str]:
                 line = line[:-1]
                 line.strip()
                 ret.add(line)
-        logger.debug(f'Read {quick_skip_file} and found {len(ret)} entries.')
+        logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
     return ret
 
 
-def write_skip_list(skip_list) -> None:
+def _maybe_write_skip_list(skip_list) -> None:
+    '''Writes the skip list (files to just bypass) to disk if using.'''
+
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
         with open(quick_skip_file, 'w') as f:
@@ -68,55 +150,175 @@ def write_skip_list(skip_list) -> None:
                 filename = filename.strip()
                 if len(filename) > 0:
                     f.write(f'{filename}\n')
-        logger.debug(f'Updated {quick_skip_file}')
-
-
-def label(in_spec: InputSpec) -> None:
-    import input_utils
+        logger.debug('Updated %s', quick_skip_file)
 
-    images = []
-    if in_spec.image_file_glob is not None:
-        images += glob.glob(in_spec.image_file_glob)
-    elif in_spec.image_file_prepopulated_list is not None:
-        images += in_spec.image_file_prepopulated_list
-    else:
-        raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')
 
-    skip_list = read_skip_list()
+def _filter_images(
+    images: List[str], skip_list: Set[str], helper: QuickLabelHelper
+) -> List[Tuple[str, str]]:
+    '''Filter candidates down to the (file, features file) pairs the user should be asked to label.'''
+    filtered_images = []
+    label_label = helper.get_label_feature()  # prefix identifying the label line of a features file
     for image in images:
         if image in skip_list:
-            logger.debug(f'Skipping {image} because of the skip list')
+            logger.debug('Skipping %s because of the skip list', image)
             continue
-        features = in_spec.image_file_to_features_file(image)
+        features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
-            logger.warning(msg)
-            warnings.warn(msg)
+            logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
             continue
 
-        # Render features and image.
+        label = None
+        filtered_lines = []
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
+        for line in lines:
+            line = line.rstrip('\n')  # only drop the newline; a final unterminated line stays intact
+            if line.startswith(label_label):
+                label = line.partition(':')[2]  # everything after the first ':', inner colons kept
+                label = label.strip()
+            else:
+                filtered_lines.append(line)
+
+        if not helper.is_valid_example(image, features, filtered_lines):
+            logger.warning('%s/%s: Invalid example.', image, features)
+            if config.config['ml_quick_label_delete_invalid_examples']:
+                os.remove(image)
+                os.remove(features)
+            continue
+
+        if label and not config.config['ml_quick_label_overwrite_labels']:
+            logger.warning('%s/%s: already has label, SKIPPING.', image, features)
+            continue
+
+        if config.config['ml_quick_label_skip_where_model_agrees']:
+            model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
+            if model_says and label:
+                if model_says[0] == int(label):
+                    continue
+                print(f'{image}/{features}: The model disagrees with the current label.')
+                print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
+                print(f'    ...the example is currently labeled {label}')
+        filtered_images.append((image, features))
+    return filtered_images
+
+
+def quick_label(helper: QuickLabelHelper) -> None:
+    '''Interactively walk the helper's candidate files, persisting each chosen label to its features file.'''
+    skip_list = _maybe_read_skip_list()
+    # Ask helper for an initial set of files.
+    images = helper.get_candidate_files()
+    if len(images) == 0:
+        logger.warning('No images files to operate on.')
+        return
+
+    # Filter out any that can't be converted to features or already have a
+    # label (unless they used --ml_quick_label_overwrite_labels).
+    filtered_images = _filter_images(images, skip_list, helper)
+    if len(filtered_images) == 0:
+        logger.warning('No image files to operate on (post filter).')
+        return
+
+    # Allow the user to label the non-filtered images one by one.
+    import input_utils
+
+    cursor = 0
+    label_label = helper.get_label_feature()
+    while True:
+        assert 0 <= cursor < len(filtered_images)
+
+        image = filtered_images[cursor][0]
+        assert os.path.exists(image)
+        features = filtered_images[cursor][1]
+        assert features and os.path.exists(features)
+
         filtered_lines = []
-        with open(features, "r") as f:
-            lines = f.readlines()
-        saw_label = False
+        label = None
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
         for line in lines:
-            line = line[:-1]
-            if in_spec.label not in line:
+            line = line.rstrip('\n')  # only drop the newline; these lines are written back below
+            if not line.startswith(label_label):
                 filtered_lines.append(line)
             else:
-                saw_label = True
-
-        if not saw_label or config.config['ml_quick_label_overwrite_labels']:
-            logger.info(features)
-            os.system(f'xv {image} &')
-            keystroke = input_utils.single_keystroke_response(
-                in_spec.valid_keystrokes,
-                prompt=in_spec.prompt,
-            )
-            os.system('killall xv')
-            label_value = in_spec.keystroke_to_label(keystroke)
-            filtered_lines.append(f"{in_spec.label}: {label_value}\n")
+                label = line
+
+        # Render...
+        helper.render_example(image, features, filtered_lines)
+
+        # Prompt...
+        print(
+            f'{cursor} of {len(filtered_images)} ({cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
+        )
+        if label:
+            print(f'    ...Already labelled: {label}')
+        else:
+            print('    ...Currently unlabeled')
+        guess = helper.ask_current_model_about_example(image, features, filtered_lines)
+        if guess:
+            print(f'    ...Model says {guess}')
+        print()
+
+        # Did they want everything labelled the same?  (None means no; 0 is a legal label.)
+        label_everything = helper.get_everything_label()
+        if label_everything is not None:
+            filtered_lines.append(f"{label_label}: {label_everything}")  # newline added on write
             with open(features, 'w') as f:
-                f.writelines("%s\n" % line for line in filtered_lines)
-            skip_list.add(image)
-    write_skip_list(skip_list)
+                f.writelines(line + '\n' for line in filtered_lines)
+            if config.config['ml_quick_label_use_skip_lists']:
+                skip_list.add(image)
+            cursor += 1
+            if cursor >= len(filtered_images):
+                helper.unrender_example(image, features, filtered_lines)
+                break
+
+        # Otherwise ask about each example.
+        else:
+            labelling_keystrokes = helper.get_labelling_keystrokes()
+            valid_keystrokes = ['<', '>', 'Q', '?']
+            valid_keystrokes += labelling_keystrokes.keys()
+            prompt = ','.join(valid_keystrokes)
+            print(f'What should I do ({prompt})? ', end='')
+            sys.stdout.flush()
+            keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+            print()
+            if keystroke == 'Q':
+                logger.info('Ok, stopping for now.  Labeled examples are written to disk')
+                helper.unrender_example(image, features, filtered_lines)
+                break
+            elif keystroke == '?':
+                print(
+                    '''
+    >   =   Don't label, move to the next example.
+    <   =   Don't label, move to the previous example.
+    Q   =   Quit labeling now.
+    ?   =   This message.
+  else  =   These keystrokes assign a label to the example and persist it.'''
+                )
+                time.sleep(3.0)
+
+            elif keystroke == '>':
+                cursor += 1
+                if cursor >= len(filtered_images):
+                    print('Wrapping around...')
+                    cursor = 0
+            elif keystroke == '<':
+                cursor -= 1
+                if cursor < 0:
+                    print('Wrapping around...')
+                    cursor = len(filtered_images) - 1
+            elif keystroke in labelling_keystrokes:
+                label_value = labelling_keystrokes[keystroke]
+                filtered_lines.append(f"{label_label}: {label_value}")  # newline added on write
+                with open(features, 'w') as f:
+                    f.writelines(line + '\n' for line in filtered_lines)
+                if config.config['ml_quick_label_use_skip_lists']:
+                    skip_list.add(image)
+                cursor += 1
+                if cursor >= len(filtered_images):
+                    print('Wrapping around...')
+                    cursor = 0
+            else:
+                print(f'Unknown keystroke: {keystroke}')
+        helper.unrender_example(image, features, filtered_lines)  # break paths above unrender themselves
+    _maybe_write_skip_list(skip_list)