Overhaul quick labeler, part 1.
[python_utils.git] / ml / quick_label.py
index 15256a30da2e8847f15bf114dc23eb7ff755c930..7bd43c3c8d67338969a42a267b508863d0ff44f5 100644 (file)
@@ -4,12 +4,13 @@
 
 """A helper to facilitate quick manual labeling of ML training data."""
 
-import glob
 import logging
 import os
+import sys
+import time
 import warnings
-from dataclasses import dataclass
-from typing import Callable, List, Optional, Set
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Set
 
 import argparse_utils
 import config
@@ -40,21 +41,74 @@ parser.add_argument(
 )
 
 
-@dataclass
-class InputSpec:
-    """A wrapper around the input data we need to operate; should be
-    populated by the caller."""
+class QuickLabelHelper:
+    '''To use this quick labeler your code must create a subclass of this
+    class and implement the abstract methods below.  See comments for
+    detailed semantics.'''
 
-    image_file_glob: Optional[str] = None
-    image_file_prepopulated_list: Optional[List[str]] = None
-    image_file_to_features_file: Optional[Callable[[str], str]] = None
-    label: str = ''
-    valid_keystrokes: List[str] = []
-    prompt: str = ''
-    keystroke_to_label: Optional[Callable[[str], str]] = None
+    @abstractmethod
+    def get_candidate_files(self) -> List[str]:
+        '''This must return a list of raw candidate files for labeling.'''
+        pass
 
+    @abstractmethod
+    def get_features_for_file(self, filename: str) -> Optional[str]:
+        '''Given a raw file, return its features file.'''
+        pass
+
+    @abstractmethod
+    def render_example(self, filename: str, features: str, lines: List[str]) -> None:
+        '''Render a raw file with its features for the user.'''
+        pass
+
+    @abstractmethod
+    def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
+        '''Unrender a raw file with its features (if necessary)...'''
+        pass
+
+    @abstractmethod
+    def ask_current_model_about_example(
+        self,
+        filename: str,
+        features: str,
+        lines: List[str],
+    ) -> Any:
+        '''Ask the current ML model about this example, if necessary.'''
+        pass
+
+    @abstractmethod
+    def get_labelling_keystrokes(self) -> Dict[str, Any]:
+        '''What keystrokes should be considered valid label actions and what
+        label does each keystroke map into.  e.g. if you want to ask
+        the user to hit 'y' for 'yes' and code that as 255 in your
+        features or to hit 'n' for 'no' and code that as 0 in your
+        features, return:
+
+            { 'y': 255, 'n': 0 }
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_everything_label(self) -> Any:
+        '''If this returns something other than None it indicates that every
+        example selected should be labeled with this result.  Caveat
+        emptor, we will klobber all your files.
+
+        '''
+        pass
+
+    @abstractmethod
+    def get_label_feature(self) -> str:
+        '''What feature denotes the example's label?  This is used to detect
+        when examples already have a label and to assign labels to
+        examples.'''
+        pass
+
+
+def _maybe_read_skip_list() -> Set[str]:
+    '''Reads the skip list (files to just bypass) into memory if using.'''
 
-def read_skip_list() -> Set[str]:
     ret: Set[str] = set()
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
@@ -69,7 +123,9 @@ def read_skip_list() -> Set[str]:
     return ret
 
 
-def write_skip_list(skip_list) -> None:
+def _maybe_write_skip_list(skip_list) -> None:
+    '''Writes the skip list (files to just bypass) to disk if using.'''
+
     if config.config['ml_quick_label_use_skip_lists']:
         quick_skip_file = config.config['ml_quick_label_skip_list_path']
         with open(quick_skip_file, 'w') as f:
@@ -80,54 +136,146 @@ def write_skip_list(skip_list) -> None:
         logger.debug('Updated %s', quick_skip_file)
 
 
-def label(in_spec: InputSpec) -> None:
-    import input_utils
-
-    images = []
-    if in_spec.image_file_glob is not None:
-        images += glob.glob(in_spec.image_file_glob)
-    elif in_spec.image_file_prepopulated_list is not None:
-        images += in_spec.image_file_prepopulated_list
-    else:
-        raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')
+def quick_label(helper: QuickLabelHelper) -> None:
+    # Ask helper for an initial set of files.
+    images = helper.get_candidate_files()
+    if len(images) == 0:
+        logger.warning('No images files to operate on.')
+        return
 
-    skip_list = read_skip_list()
+    # Filter out any that can't be converted to features or already have a
+    # label (unless they used --ml_qukck_label_overwrite_labels).
+    filtered_images = []
+    skip_list = _maybe_read_skip_list()
     for image in images:
         if image in skip_list:
             logger.debug('Skipping %s because of the skip list', image)
             continue
-        assert in_spec.image_file_to_features_file
-        features = in_spec.image_file_to_features_file(image)
+
+        features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
+            msg = f'{image}/{features}: {features} doesn\'t exist, SKIPPING.'
+            logger.warning(msg)
+            warnings.warn(msg)
+            continue
+
+        label_label = helper.get_label_feature()
+        label = None
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
+        for line in lines:
+            line = line[:-1]
+            if line.startswith(label_label):
+                label = line
+        if label and not config.config['ml_quick_label_overwrite_labels']:
+            msg = f'{image}/{features}: already has label, SKIPPING'
             logger.warning(msg)
             warnings.warn(msg)
             continue
+        filtered_images.append((image, features))
+
+    if len(filtered_images) == 0:
+        logger.warning('No image files to operate on (post filter).')
+        return
+
+    cursor = 0
+    import input_utils
+
+    while True:
+        assert 0 <= cursor < len(filtered_images)
+
+        image = filtered_images[cursor][0]
+        assert os.path.exists(image)
+        features = filtered_images[cursor][1]
+        assert features and os.path.exists(features)
 
-        # Render features and image.
         filtered_lines = []
-        with open(features, "r") as f:
-            lines = f.readlines()
-        saw_label = False
+        label = None
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
         for line in lines:
             line = line[:-1]
-            if in_spec.label not in line:
+            if not line.startswith(label_label):
                 filtered_lines.append(line)
             else:
-                saw_label = True
-
-        if not saw_label or config.config['ml_quick_label_overwrite_labels']:
-            logger.info(features)
-            os.system(f'xv {image} &')
-            keystroke = input_utils.single_keystroke_response(
-                in_spec.valid_keystrokes,
-                prompt=in_spec.prompt,
-            )
-            os.system('killall xv')
-            assert in_spec.keystroke_to_label
-            label_value = in_spec.keystroke_to_label(keystroke)
-            filtered_lines.append(f"{in_spec.label}: {label_value}\n")
+                label = line
+
+        # Render...
+        helper.render_example(image, features, filtered_lines)
+
+        # Prompt...
+        print(
+            f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
+        )
+        if label:
+            print(f'    ...Already labelled: {label}')
+        else:
+            print('    ...Currently unlabeled')
+        guess = helper.ask_current_model_about_example(image, features, filtered_lines)
+        if guess:
+            print(f'    ...Model says {guess}')
+        print()
+
+        # Did they want everything labelled the same?
+        label_everything = helper.get_everything_label()
+        if label_everything:
+            filtered_lines.append(f"{label_label}: {label_everything}\n")
             with open(features, 'w') as f:
                 f.writelines(line + '\n' for line in filtered_lines)
-            skip_list.add(image)
-    write_skip_list(skip_list)
+            if config.config['ml_quick_label_use_skip_lists']:
+                skip_list.add(image)
+            cursor += 1
+            if cursor >= len(filtered_images):
+                helper.unrender_example(image, features, filtered_lines)
+                break
+
+        # Otherwise ask about each example.
+        else:
+            labelling_keystrokes = helper.get_labelling_keystrokes()
+            valid_keystrokes = ['<', '>', 'Q', '?']
+            valid_keystrokes += labelling_keystrokes.keys()
+            prompt = ','.join(valid_keystrokes)
+            print(f'What should I do ({prompt})? ', end='')
+            sys.stdout.flush()
+            keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+            print()
+            if keystroke == 'Q':
+                logger.info('Ok, stopping for now.  Labeled examples are written to disk')
+                helper.unrender_example(image, features, filtered_lines)
+                break
+            elif keystroke == '?':
+                print(
+                    '''
+    >   =   Don't label, move to the next example.
+    <   =   Don't label, move to the previous example.
+    Q   =   Quit labeling now.
+    ?   =   This message.
+  else  =   These keystrokes assign a label to the example and persist it.'''
+                )
+                time.sleep(3.0)
+
+            elif keystroke == '>':
+                cursor += 1
+                if cursor >= len(filtered_images):
+                    print('Wrapping around...')
+                    cursor = 0
+            elif keystroke == '<':
+                cursor -= 1
+                if cursor < 0:
+                    print('Wrapping around...')
+                    cursor = len(filtered_images) - 1
+            elif keystroke in labelling_keystrokes:
+                label_value = labelling_keystrokes[keystroke]
+                filtered_lines.append(f"{label_label}: {label_value}\n")
+                with open(features, 'w') as f:
+                    f.writelines(line + '\n' for line in filtered_lines)
+                if config.config['ml_quick_label_use_skip_lists']:
+                    skip_list.add(image)
+                cursor += 1
+                if cursor >= len(filtered_images):
+                    print('Wrapping around...')
+                    cursor = 0
+            else:
+                print(f'Unknown keystroke: {keystroke}')
+        helper.unrender_example(image, features, filtered_lines)
+    _maybe_write_skip_list(skip_list)