Improve the quick labeler again.
authorScott Gasch <[email protected]>
Tue, 19 Apr 2022 20:53:49 +0000 (13:53 -0700)
committerScott Gasch <[email protected]>
Tue, 19 Apr 2022 20:53:49 +0000 (13:53 -0700)
ml/quick_label.py

index 05efbaebcee5217b516a3c3fc37421b1e6a982fd..00acf05f4e389a36096b8cd13e6cbbee56985446 100644 (file)
@@ -7,12 +7,13 @@
 import logging
 import os
 import sys
-import time
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple
 
+import ansi
 import argparse_utils
 import config
+import input_utils
 
 logger = logging.getLogger(__name__)
 parser = config.add_commandline_args(
@@ -36,7 +37,7 @@ parser.add_argument(
     "--ml_quick_label_overwrite_labels",
     default=False,
     action=argparse_utils.ActionNoYes,
-    help='Enable overwriting existing labels; default is to not relabel.',
+    help='Enable overwriting of existing labels; default is to not relabel.',
 )
 parser.add_argument(
     '--ml_quick_label_skip_where_model_agrees',
@@ -45,7 +46,7 @@ parser.add_argument(
     help='Do not filter examples where the model disagrees with the current label.',
 )
 parser.add_argument(
-    'ml_quick_label_delete_invalid_examples',
+    '--ml_quick_label_delete_invalid_examples',
     default=False,
     action='store_true',
     help='If set we will delete invalid training examples.',
@@ -73,7 +74,7 @@ class QuickLabelHelper:
         pass
 
     @abstractmethod
-    def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
+    def unrender_example(self, filename: str, features: str) -> None:
         '''Unrender a raw file with its features (if necessary)...'''
         pass
 
@@ -160,14 +161,15 @@ def _filter_images(
     label_label = helper.get_label_feature()
     for image in images:
         if image in skip_list:
-            logger.debug('Skipping %s because of the skip list', image)
+            logger.debug('Skipping %s because of the skip list.', image)
             continue
 
         features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
+            logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
             continue
 
+        # Read it.
         label = None
         filtered_lines = []
         with open(features, 'r') as rf:
@@ -188,21 +190,98 @@ def _filter_images(
             continue
 
         if label and not config.config['ml_quick_label_overwrite_labels']:
-            logger.warning('%s/%s: already has label, SKIPPING.', image, features)
+            logger.warning('%s/%s: already has label, skipping.', image, features)
             continue
 
         if config.config['ml_quick_label_skip_where_model_agrees']:
             model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
             if model_says and label:
                 if model_says[0] == int(label):
+                    logger.warning(
+                        '%s/%s: Model agrees with current label (%s), skipping.',
+                        image,
+                        features,
+                        label,
+                    )
                     continue
+
                 print(f'{image}/{features}: The model disagrees with the current label.')
                 print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
                 print(f'    ...the example is currently labeled {label}')
+
         filtered_images.append((image, features))
     return filtered_images
 
 
+def _make_prompt(
+    helper: QuickLabelHelper,
+    cursor: int,
+    filtered_images: List[Tuple[str, str]],
+    image: str,
+    features: str,
+    labeled_features: Dict[Tuple[str, str], str],
+) -> None:
+    label_label = helper.get_label_feature()
+    filtered_lines = []
+    label = labeled_features.get((image, features), None)
+    with open(features, 'r') as rf:
+        lines = rf.readlines()
+    for line in lines:
+        line = line[:-1]
+        if len(line) == 0:
+            continue
+        if not line.startswith(label_label):
+            filtered_lines.append(line)
+        else:
+            assert not label
+            label = line
+
+    # Prompt...
+    helper.render_example(image, features, filtered_lines)
+    print(f'{cursor}/{len(filtered_images)} ({cursor/len(filtered_images)*100.0:.1f}%) | ', end='')
+    print(f'{ansi.bold()}{image} / {features}{ansi.reset()}:')
+    print(f'    ...{len(labeled_features)} currently unsaved labels ("W" to save).')
+    if label:
+        if (image, features) in labeled_features:
+            print(f'    ...This example is labeled but not yet saved: {label}')
+        else:
+            print(f'    ...This example is already labeled on disk: {label}')
+    else:
+        print('    ...This example is currently unlabeled')
+    guess = helper.ask_current_model_about_example(image, features, filtered_lines)
+    if guess:
+        print(f'    ...The ML Model says {guess}')
+    print()
+
+
+def _write_labeled_features(
+    helper: QuickLabelHelper,
+    labeled_features: Dict[Tuple[str, str], str],
+    skip_list: Set[str],
+) -> None:
+    label_label = helper.get_label_feature()
+    for image_features, label in labeled_features.items():
+        image = image_features[0]
+        features = image_features[1]
+        filtered_lines = []
+        with open(features, 'r') as rf:
+            lines = rf.readlines()
+        for line in lines:
+            line = line[:-1]
+            line = line.strip()
+            if line == '':
+                continue
+            if not line.startswith(label_label):
+                filtered_lines.append(line)
+
+        filtered_lines.append(f'{label_label}: {label}')
+        with open(features, 'w') as f:
+            f.writelines(line + '\n' for line in filtered_lines)
+            if config.config['ml_quick_label_use_skip_lists']:
+                skip_list.add(image)
+    print(f'Wrote {len(labeled_features)} labels.')
+
+
 def quick_label(helper: QuickLabelHelper) -> None:
     skip_list = _maybe_read_skip_list()
 
@@ -211,6 +290,7 @@ def quick_label(helper: QuickLabelHelper) -> None:
     if len(images) == 0:
         logger.warning('No images files to operate on.')
         return
+    logger.info('There are %d starting candidate images.', len(images))
 
     # Filter out any that can't be converted to features or already have a
     # label (unless they used --ml_qukck_label_overwrite_labels).
@@ -218,12 +298,12 @@ def quick_label(helper: QuickLabelHelper) -> None:
     if len(filtered_images) == 0:
         logger.warning('No image files to operate on (post filter).')
         return
+    logger.info('There are %d candidate images post filtering.', len(filtered_images))
 
     # Allow the user to label the non-filtered images one by one.
-    import input_utils
 
+    labeled_features: Dict[Tuple[str, str], str] = {}
     cursor = 0
-    label_label = helper.get_label_feature()
     while True:
         assert 0 <= cursor < len(filtered_images)
 
@@ -232,93 +312,76 @@ def quick_label(helper: QuickLabelHelper) -> None:
         features = filtered_images[cursor][1]
         assert features and os.path.exists(features)
 
-        filtered_lines = []
-        label = None
-        with open(features, 'r') as rf:
-            lines = rf.readlines()
-        for line in lines:
-            line = line[:-1]
-            if not line.startswith(label_label):
-                filtered_lines.append(line)
-            else:
-                label = line
-
-        # Render...
-        helper.render_example(image, features, filtered_lines)
-
-        # Prompt...
-        print(
-            f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
-        )
-        if label:
-            print(f'    ...Already labelled: {label}')
-        else:
-            print('    ...Currently unlabeled')
-        guess = helper.ask_current_model_about_example(image, features, filtered_lines)
-        if guess:
-            print(f'    ...Model says {guess}')
-        print()
-
-        # Did they want everything labelled the same?
-        label_everything = helper.get_everything_label()
-        if label_everything:
-            filtered_lines.append(f"{label_label}: {label_everything}\n")
-            with open(features, 'w') as f:
-                f.writelines(line + '\n' for line in filtered_lines)
-            if config.config['ml_quick_label_use_skip_lists']:
-                skip_list.add(image)
-            cursor += 1
-            if cursor >= len(filtered_images):
-                helper.unrender_example(image, features, filtered_lines)
-                break
-
-        # Otherwise ask about each example.
-        else:
-            labelling_keystrokes = helper.get_labelling_keystrokes()
-            valid_keystrokes = ['<', '>', 'Q', '?']
-            valid_keystrokes += labelling_keystrokes.keys()
-            prompt = ','.join(valid_keystrokes)
-            print(f'What should I do ({prompt})? ', end='')
-            sys.stdout.flush()
-            keystroke = input_utils.single_keystroke_response(valid_keystrokes)
-            print()
-            if keystroke == 'Q':
-                logger.info('Ok, stopping for now.  Labeled examples are written to disk')
-                helper.unrender_example(image, features, filtered_lines)
-                break
-            elif keystroke == '?':
-                print(
-                    '''
-    >   =   Don't label, move to the next example.
-    <   =   Don't label, move to the previous example.
-    Q   =   Quit labeling now.
-    ?   =   This message.
-  else  =   These keystrokes assign a label to the example and persist it.'''
-                )
-                time.sleep(3.0)
-
-            elif keystroke == '>':
-                cursor += 1
+        # Render the features, image and prompt.
+        _make_prompt(helper, cursor, filtered_images, image, features, labeled_features)
+        try:
+            # Did they want everything labelled the same?
+            label_everything = helper.get_everything_label()
+            if label_everything:
+                labeled_features[(image, features)] = label_everything
+                filtered_images.remove((image, features))
+                if len(filtered_images) == 0:
+                    print('Nothing more to label.')
+                    break
                 if cursor >= len(filtered_images):
-                    print('Wrapping around...')
-                    cursor = 0
-            elif keystroke == '<':
-                cursor -= 1
-                if cursor < 0:
-                    print('Wrapping around...')
-                    cursor = len(filtered_images) - 1
-            elif keystroke in labelling_keystrokes:
-                label_value = labelling_keystrokes[keystroke]
-                filtered_lines.append(f"{label_label}: {label_value}\n")
-                with open(features, 'w') as f:
-                    f.writelines(line + '\n' for line in filtered_lines)
-                if config.config['ml_quick_label_use_skip_lists']:
-                    skip_list.add(image)
-                cursor += 1
-                if cursor >= len(filtered_images):
-                    print('Wrapping around...')
-                    cursor = 0
+                    cursor -= 1
+
+            # Otherwise ask about each individual example.
             else:
-                print(f'Unknown keystroke: {keystroke}')
-        helper.unrender_example(image, features, filtered_lines)
+                labelling_keystrokes = helper.get_labelling_keystrokes()
+                valid_keystrokes = ['<', '>', 'W', 'Q', '?']
+                valid_keystrokes += labelling_keystrokes.keys()
+                prompt = ','.join(valid_keystrokes)
+                print(f'What should I do?  [{prompt}]: ', end='')
+                sys.stdout.flush()
+                keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+                print()
+
+                if keystroke == '?':
+                    print(
+                        '''
+        >   =   Don't label, move to the next example.
+        <   =   Don't label, move to the previous example.
+        W   =   Write pending labels to disk now.
+        Q   =   Quit labeling now.
+        ?   =   This message.
+      else  =   These keystrokes assign a label to the example and persist it.'''
+                    )
+                    input_utils.press_any_key()
+                elif keystroke == 'Q':
+                    logger.info('Ok, stopping for now.')
+                    if len(labeled_features):
+                        logger.info('Discarding %d unsaved labels.', len(labeled_features))
+                    break
+                elif keystroke == '>':
+                    cursor += 1
+                    if cursor >= len(filtered_images):
+                        print('Wrapping around...')
+                        cursor = 0
+                elif keystroke == '<':
+                    cursor -= 1
+                    if cursor < 0:
+                        print('Wrapping around...')
+                        cursor = len(filtered_images) - 1
+                elif keystroke == 'W':
+                    _write_labeled_features(helper, labeled_features, skip_list)
+                    labeled_features = {}
+                elif keystroke in labelling_keystrokes:
+                    label_value = labelling_keystrokes[keystroke]
+                    labeled_features[(image, features)] = label_value
+                    filtered_images.remove((image, features))
+                    if len(filtered_images) == 0:
+                        print('Nothing more to label.')
+                        break
+                    if cursor >= len(filtered_images):
+                        cursor -= 1
+                else:
+                    print(f'Unknown keystroke: {keystroke}.  Try again.')
+        finally:
+            helper.unrender_example(image, features)
+
+    if len(labeled_features):
+        yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
+        if yn in ('Y', 'y'):
+            _write_labeled_features(helper, labeled_features, skip_list)
     _maybe_write_skip_list(skip_list)