Overhaul quick_labeler part 2: understand what the current model says.
authorScott Gasch <[email protected]>
Tue, 12 Apr 2022 22:23:54 +0000 (15:23 -0700)
committerScott Gasch <[email protected]>
Tue, 12 Apr 2022 22:23:54 +0000 (15:23 -0700)
ml/quick_label.py

index 7bd43c3c8d67338969a42a267b508863d0ff44f5..05efbaebcee5217b516a3c3fc37421b1e6a982fd 100644 (file)
@@ -8,9 +8,8 @@ import logging
 import os
 import sys
 import time
-import warnings
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import argparse_utils
 import config
@@ -39,6 +38,18 @@ parser.add_argument(
     action=argparse_utils.ActionNoYes,
     help='Enable overwriting existing labels; default is to not relabel.',
 )
+parser.add_argument(
+    '--ml_quick_label_skip_where_model_agrees',
+    default=False,
+    action=argparse_utils.ActionNoYes,
+    help='Do not filter examples where the model disagrees with the current label.',
+)
+parser.add_argument(
+    'ml_quick_label_delete_invalid_examples',
+    default=False,
+    action='store_true',
+    help='If set we will delete invalid training examples.',
+)
 
 
 class QuickLabelHelper:
@@ -66,6 +77,12 @@ class QuickLabelHelper:
         '''Unrender a raw file with its features (if necessary)...'''
         pass
 
+    @abstractmethod
+    def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
+        '''Returns true iff the example is valid (all features are valid, there
+        are the correct number of features, etc...'''
+        pass
+
     @abstractmethod
     def ask_current_model_about_example(
         self,
@@ -136,17 +153,11 @@ def _maybe_write_skip_list(skip_list) -> None:
         logger.debug('Updated %s', quick_skip_file)
 
 
-def quick_label(helper: QuickLabelHelper) -> None:
-    # Ask helper for an initial set of files.
-    images = helper.get_candidate_files()
-    if len(images) == 0:
-        logger.warning('No images files to operate on.')
-        return
-
-    # Filter out any that can't be converted to features or already have a
-    # label (unless they used --ml_qukck_label_overwrite_labels).
+def _filter_images(
+    images: List[str], skip_list: Set[str], helper: QuickLabelHelper
+) -> List[Tuple[str, str]]:
     filtered_images = []
-    skip_list = _maybe_read_skip_list()
+    label_label = helper.get_label_feature()
     for image in images:
         if image in skip_list:
             logger.debug('Skipping %s because of the skip list', image)
@@ -154,33 +165,65 @@ def quick_label(helper: QuickLabelHelper) -> None:
 
         features = helper.get_features_for_file(image)
         if features is None or not os.path.exists(features):
-            msg = f'{image}/{features}: {features} doesn\'t exist, SKIPPING.'
-            logger.warning(msg)
-            warnings.warn(msg)
+            logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
             continue
 
-        label_label = helper.get_label_feature()
         label = None
+        filtered_lines = []
         with open(features, 'r') as rf:
             lines = rf.readlines()
         for line in lines:
             line = line[:-1]
             if line.startswith(label_label):
-                label = line
+                label = ''.join(line.split(':')[1:])
+                label = label.strip()
+            else:
+                filtered_lines.append(line)
+
+        if not helper.is_valid_example(image, features, filtered_lines):
+            logger.warning('%s/%s: Invalid example.', image, features)
+            if config.config['ml_quick_label_delete_invalid_examples']:
+                os.remove(image)
+                os.remove(features)
+            continue
+
         if label and not config.config['ml_quick_label_overwrite_labels']:
-            msg = f'{image}/{features}: already has label, SKIPPING'
-            logger.warning(msg)
-            warnings.warn(msg)
+            logger.warning('%s/%s: already has label, SKIPPING.', image, features)
             continue
+
+        if config.config['ml_quick_label_skip_where_model_agrees']:
+            model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
+            if model_says and label:
+                if model_says[0] == int(label):
+                    continue
+                print(f'{image}/{features}: The model disagrees with the current label.')
+                print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
+                print(f'    ...the example is currently labeled {label}')
         filtered_images.append((image, features))
+    return filtered_images
+
+
+def quick_label(helper: QuickLabelHelper) -> None:
+    skip_list = _maybe_read_skip_list()
 
+    # Ask helper for an initial set of files.
+    images = helper.get_candidate_files()
+    if len(images) == 0:
+        logger.warning('No images files to operate on.')
+        return
+
+    # Filter out any that can't be converted to features or already have a
+    # label (unless they used --ml_qukck_label_overwrite_labels).
+    filtered_images = _filter_images(images, skip_list, helper)
     if len(filtered_images) == 0:
         logger.warning('No image files to operate on (post filter).')
         return
 
-    cursor = 0
+    # Allow the user to label the non-filtered images one by one.
     import input_utils
 
+    cursor = 0
+    label_label = helper.get_label_feature()
     while True:
         assert 0 <= cursor < len(filtered_images)