#!/usr/bin/env python3
-import glob
+# © Copyright 2021-2022, Scott Gasch
+
+"""A helper to facilitate quick manual labeling of ML training data.
+
+To use, implement a subclass that implements the QuickLabelHelper
+interface and pass it into the quick_label function.
+
+"""
+
import logging
import os
-from typing import Callable, List, NamedTuple, Optional, Set
-import warnings
+import sys
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Set, Tuple
+import ansi
import argparse_utils
import config
+import input_utils
logger = logging.getLogger(__name__)
parser = config.add_commandline_args(
"--ml_quick_label_overwrite_labels",
default=False,
action=argparse_utils.ActionNoYes,
- help='Enable overwriting existing labels; default is to not relabel.',
+ help='Enable overwriting of existing labels; default is to not relabel.',
+)
+parser.add_argument(
+ '--ml_quick_label_skip_where_model_agrees',
+ default=False,
+ action=argparse_utils.ActionNoYes,
+ help='Do not filter examples where the model disagrees with the current label.',
)
+parser.add_argument(
+ '--ml_quick_label_delete_invalid_examples',
+ default=False,
+ action='store_true',
+ help='If set we will delete invalid training examples.',
+)
+
+
+class QuickLabelHelper:
+ '''To use this quick labeler your code must create a subclass of this
+ class and implement the abstract methods below. See comments for
+ detailed semantics.'''
+
+ @abstractmethod
+ def get_candidate_files(self) -> List[str]:
+ '''This must return a list of raw candidate files for labeling.'''
+ pass
+
+ @abstractmethod
+ def get_features_for_file(self, filename: str) -> Optional[str]:
+ '''Given a raw file, return its features file.'''
+ pass
+ @abstractmethod
+ def render_example(self, filename: str, features: str, lines: List[str]) -> None:
+ '''Render a raw file with its features for the user to judge.'''
+ pass
-class InputSpec(NamedTuple):
- image_file_glob: Optional[str]
- image_file_prepopulated_list: Optional[List[str]]
- image_file_to_features_file: Callable[[str], str]
- label: str
- valid_keystrokes: List[str]
- prompt: str
- keystroke_to_label: Callable[[str], str]
+ @abstractmethod
+ def unrender_example(self, filename: str, features: str) -> None:
+ '''Unrender a raw file with its features (if necessary)...'''
+ pass
+ @abstractmethod
+ def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
+ '''Returns true iff the example is valid (all features are valid, there
+ are the correct number of features, etc...'''
+ pass
+
+ @abstractmethod
+ def ask_current_model_about_example(
+ self,
+ filename: str,
+ features: str,
+ lines: List[str],
+ ) -> Any:
+ '''Ask the current ML model about this example, if necessary/possible.
+ Returns None to indicate no model to consult.'''
+ pass
+
+ @abstractmethod
+ def get_labelling_keystrokes(self) -> Dict[str, Any]:
+ '''What keystrokes should be considered valid label actions and what
+ label does each keystroke map into. e.g. if you want to ask
+ the user to hit 'y' for 'yes' and code that as 255 in your
+ features or to hit 'n' for 'no' and code that as 0 in your
+ features, return:
+
+ { 'y': 255, 'n': 0 }
+
+ '''
+ pass
+
+ @abstractmethod
+ def get_everything_label(self) -> Any:
+ '''If this returns something other than None it indicates that every
+ example selected should be labeled with this result. Caveat
+ emptor, we will klobber all your files.
+
+ '''
+ pass
+
+ @abstractmethod
+ def get_label_feature(self) -> str:
+ '''What feature denotes the example's label? This is used to detect
+ when examples already have a label and to assign labels to
+ examples.'''
+ pass
+
+
+def _maybe_read_skip_list() -> Set[str]:
+ '''Reads the skip list (files to just bypass) into memory if using.'''
-def read_skip_list() -> Set[str]:
ret: Set[str] = set()
if config.config['ml_quick_label_use_skip_lists']:
quick_skip_file = config.config['ml_quick_label_skip_list_path']
line = line[:-1]
line.strip()
ret.add(line)
- logger.debug(f'Read {quick_skip_file} and found {len(ret)} entries.')
+ logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
return ret
-def write_skip_list(skip_list) -> None:
+def _maybe_write_skip_list(skip_list) -> None:
+ '''Writes the skip list (files to just bypass) to disk if using.'''
+
if config.config['ml_quick_label_use_skip_lists']:
quick_skip_file = config.config['ml_quick_label_skip_list_path']
with open(quick_skip_file, 'w') as f:
filename = filename.strip()
if len(filename) > 0:
f.write(f'{filename}\n')
- logger.debug(f'Updated {quick_skip_file}')
+ logger.debug('Updated %s', quick_skip_file)
-def label(in_spec: InputSpec) -> None:
- import input_utils
-
- images = []
- if in_spec.image_file_glob is not None:
- images += glob.glob(in_spec.image_file_glob)
- elif in_spec.image_file_prepopulated_list is not None:
- images += in_spec.image_file_prepopulated_list
- else:
- raise ValueError(
- 'One of image_file_glob or image_file_prepopulated_list is required'
- )
+def _filter_images(
+ images: List[str], skip_list: Set[str], helper: QuickLabelHelper
+) -> List[Tuple[str, str]]:
+ '''Discard examples that have particular characteristics. e.g.
+ those that are already labeled and whose current label agrees with
+ the ML model, etc...'''
- skip_list = read_skip_list()
+ filtered_images = []
+ label_label = helper.get_label_feature()
for image in images:
if image in skip_list:
- logger.debug(f'Skipping {image} because of the skip list')
+ logger.debug('Skipping %s because of the skip list.', image)
continue
- features = in_spec.image_file_to_features_file(image)
+
+ features = helper.get_features_for_file(image)
if features is None or not os.path.exists(features):
- msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
- logger.warning(msg)
- warnings.warn(msg)
+ logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
+ continue
+
+ # Read it.
+ label = None
+ filtered_lines = []
+ with open(features, 'r') as rf:
+ lines = rf.readlines()
+ for line in lines:
+ line = line[:-1]
+ if line.startswith(label_label):
+ label = ''.join(line.split(':')[1:])
+ label = label.strip()
+ else:
+ filtered_lines.append(line)
+
+ if not helper.is_valid_example(image, features, filtered_lines):
+ logger.warning('%s/%s: Invalid example.', image, features)
+ if config.config['ml_quick_label_delete_invalid_examples']:
+ os.remove(image)
+ os.remove(features)
+ continue
+
+ if label and not config.config['ml_quick_label_overwrite_labels']:
+ logger.warning('%s/%s: already has label, skipping.', image, features)
+ continue
+
+ if config.config['ml_quick_label_skip_where_model_agrees']:
+ model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
+ if model_says and label:
+ if model_says[0] == int(label):
+ logger.warning(
+ '%s/%s: Model agrees with current label (%s), skipping.',
+ image,
+ features,
+ label,
+ )
+ continue
+
+ print(f'{image}/{features}: The model disagrees with the current label.')
+ print(f' ...model says {model_says[0]} with probability {model_says[1]}.')
+ print(f' ...the example is currently labeled {label}')
+
+ filtered_images.append((image, features))
+ return filtered_images
+
+
+def _make_prompt(
+ helper: QuickLabelHelper,
+ cursor: int,
+ num_filtered_images: int,
+ current_image: str,
+ current_features: str,
+ labeled_features: Dict[Tuple[str, str], str], # Examples already labeled
+) -> None:
+ '''Tell an interactive user where they are in the set of examples that
+ may be labeled and the details of the current example.'''
+
+ label_label = helper.get_label_feature() # the key: of a label in features
+ filtered_lines = []
+ label = labeled_features.get((current_image, current_features), None)
+ with open(current_features, 'r') as rf:
+ lines = rf.readlines()
+ for line in lines:
+ line = line[:-1]
+ if len(line) == 0:
continue
+ if not line.startswith(label_label):
+ filtered_lines.append(line)
+ else:
+ assert not label
+ label = line
- # Render features and image.
+ # Prompt...
+ helper.render_example(current_image, current_features, filtered_lines)
+ print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='')
+ print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:')
+ print(f' ...{len(labeled_features)} currently unwritten label(s) ("W" to write).')
+ if label:
+ if (current_image, current_features) in labeled_features:
+ print(f' ...This example is labeled but not yet saved: {label}')
+ else:
+ print(f' ...This example is already labeled on disk: {label}')
+ else:
+ print(' ...This example is currently unlabeled')
+ guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines)
+ if guess:
+ print(f' ...The ML Model says {guess}')
+ print()
+
+
+def _write_labeled_features(
+ helper: QuickLabelHelper,
+ labeled_features: Dict[Tuple[str, str], str],
+ skip_list: Set[str],
+) -> None:
+ label_label = helper.get_label_feature()
+ for image_features, label in labeled_features.items():
+ image = image_features[0]
+ features = image_features[1]
filtered_lines = []
- with open(features, "r") as f:
- lines = f.readlines()
- saw_label = False
+ with open(features, 'r') as rf:
+ lines = rf.readlines()
for line in lines:
line = line[:-1]
- if in_spec.label not in line:
+ line = line.strip()
+ if line == '':
+ continue
+ if not line.startswith(label_label):
filtered_lines.append(line)
+
+ filtered_lines.append(f'{label_label}: {label}')
+ with open(features, 'w') as f:
+ f.writelines(line + '\n' for line in filtered_lines)
+ if config.config['ml_quick_label_use_skip_lists']:
+ skip_list.add(image)
+ print(f'Wrote {len(labeled_features)} labels.')
+
+
+def quick_label(helper: QuickLabelHelper) -> None:
+ '''Pass your QuickLabelHelper implementing class to this function and
+ it will allow users to label examples and persist them to disk.
+
+ '''
+ skip_list = _maybe_read_skip_list()
+
+ # Ask helper for an initial set of files.
+ images = helper.get_candidate_files()
+ if len(images) == 0:
+ logger.warning('No images files to operate on.')
+ return
+ logger.info('There are %d starting candidate images.', len(images))
+
+ # Filter out any that can't be converted to features or already have a
+ # label (unless they used --ml_qukck_label_overwrite_labels).
+ filtered_images = _filter_images(images, skip_list, helper)
+ if len(filtered_images) == 0:
+ logger.warning('No image files to operate on (post filter).')
+ return
+ logger.info('There are %d candidate images post filtering.', len(filtered_images))
+
+ # Allow the user to label the non-filtered images one by one.
+ labeled_features: Dict[Tuple[str, str], str] = {}
+ cursor = 0
+ while True:
+ assert 0 <= cursor < len(filtered_images)
+
+ image = filtered_images[cursor][0]
+ assert os.path.exists(image)
+ features = filtered_images[cursor][1]
+ assert features and os.path.exists(features)
+
+ # Render the features, image and prompt.
+ _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features)
+ try:
+ # Did they want everything labelled the same?
+ label_everything = helper.get_everything_label()
+ if label_everything:
+ labeled_features[(image, features)] = label_everything
+ filtered_images.remove((image, features))
+ if len(filtered_images) == 0:
+ print('Nothing more to label.')
+ break
+ if cursor >= len(filtered_images):
+ cursor -= 1
+
+ # Otherwise ask about each individual example.
else:
- saw_label = True
-
- if not saw_label or config.config['ml_quick_label_overwrite_labels']:
- logger.info(features)
- os.system(f'xv {image} &')
- keystroke = input_utils.single_keystroke_response(
- in_spec.valid_keystrokes,
- prompt=in_spec.prompt,
- )
- os.system('killall xv')
- label_value = in_spec.keystroke_to_label(keystroke)
- filtered_lines.append(f"{in_spec.label}: {label_value}\n")
- with open(features, 'w') as f:
- f.writelines("%s\n" % line for line in filtered_lines)
- skip_list.add(image)
- write_skip_list(skip_list)
+ labelling_keystrokes = helper.get_labelling_keystrokes()
+ valid_keystrokes = ['<', '>', 'W', 'Q', '?']
+ valid_keystrokes += labelling_keystrokes.keys()
+ prompt = ','.join(valid_keystrokes)
+ print(f'What should I do? [{prompt}]: ', end='')
+ sys.stdout.flush()
+ keystroke = input_utils.single_keystroke_response(valid_keystrokes)
+ print()
+
+ if keystroke == '?':
+ print(
+ '''
+ > = Don't label, move to the next example.
+ < = Don't label, move to the previous example.
+ W = Write pending labels to disk now.
+ Q = Quit labeling now.
+ ? = This message.
+ else = These keystrokes assign a label to the example and persist it.'''
+ )
+ input_utils.press_any_key()
+ elif keystroke == 'Q':
+ logger.info('Ok, stopping for now.')
+ if len(labeled_features):
+ logger.info('Discarding %d unsaved labels.', len(labeled_features))
+ break
+ elif keystroke == '>':
+ cursor += 1
+ if cursor >= len(filtered_images):
+ print('Wrapping around...')
+ cursor = 0
+ elif keystroke == '<':
+ cursor -= 1
+ if cursor < 0:
+ print('Wrapping around...')
+ cursor = len(filtered_images) - 1
+ elif keystroke == 'W':
+ _write_labeled_features(helper, labeled_features, skip_list)
+ labeled_features = {}
+ elif keystroke in labelling_keystrokes:
+ label_value = labelling_keystrokes[keystroke]
+ labeled_features[(image, features)] = label_value
+ filtered_images.remove((image, features))
+ if len(filtered_images) == 0:
+ print('Nothing more to label.')
+ break
+ if cursor >= len(filtered_images):
+ cursor -= 1
+ else:
+ print(f'Unknown keystroke: {keystroke}. Try again.')
+ finally:
+ helper.unrender_example(image, features)
+
+ if len(labeled_features):
+ yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk? [Y/N]: ')
+ if yn in ('Y', 'y'):
+ _write_labeled_features(helper, labeled_features, skip_list)
+ _maybe_write_skip_list(skip_list)