#!/usr/bin/env python3

# © Copyright 2021-2022, Scott Gasch

"""A helper to facilitate quick manual labeling of ML training data.
To use, implement a subclass that implements the QuickLabelHelper
interface and pass it into the quick_label function.
"""

import logging
import os
import sys
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple

import ansi
import argparse_utils
import config
import input_utils

logger = logging.getLogger(__name__)
parser = config.add_commandline_args(
    f"ML Quick Labeler ({__file__})",
    "Args related to quick labeling of ML training data",
)
parser.add_argument(
    "--ml_quick_label_skip_list_path",
    default="./qlabel_skip_list.txt",
    metavar="FILENAME",
    type=argparse_utils.valid_filename,
    help="Path to file in which to store already labeled data.",
)
parser.add_argument(
    "--ml_quick_label_use_skip_lists",
    default=True,
    action=argparse_utils.ActionNoYes,
    help='Should we use a skip list file to speed up execution?',
)
parser.add_argument(
    "--ml_quick_label_overwrite_labels",
    default=False,
    action=argparse_utils.ActionNoYes,
    help='Enable overwriting of existing labels; default is to not relabel.',
)
parser.add_argument(
    '--ml_quick_label_skip_where_model_agrees',
    default=False,
    action=argparse_utils.ActionNoYes,
    help='Skip already-labeled examples whose current label the model agrees with.',
)
parser.add_argument(
    '--ml_quick_label_delete_invalid_examples',
    default=False,
    action='store_true',
    help='If set we will delete invalid training examples.',
)


class QuickLabelHelper:
    '''To use this quick labeler your code must create a subclass of this
    class and implement the abstract methods below.  See comments for
    detailed semantics.

    NOTE: this class does not derive from abc.ABC, so @abstractmethod is
    not enforced when a subclass is instantiated; a method left
    unimplemented simply returns None.
    '''

    @abstractmethod
    def get_candidate_files(self) -> List[str]:
        '''This must return a list of raw candidate files for labeling.'''
        pass

    @abstractmethod
    def get_features_for_file(self, filename: str) -> Optional[str]:
        '''Given a raw file, return its features file.  Returning None (or
        a path that does not exist) causes the example to be skipped.'''
        pass

    @abstractmethod
    def render_example(self, filename: str, features: str, lines: List[str]) -> None:
        '''Render a raw file with its features for the user to judge.
        lines are the non-empty, non-label lines of the features file.'''
        pass

    @abstractmethod
    def unrender_example(self, filename: str, features: str) -> None:
        '''Unrender a raw file with its features (if necessary); called
        after each prompt, even if the prompt raised.'''
        pass

    @abstractmethod
    def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
        '''Returns true iff the example is valid (all features are valid,
        there are the correct number of features, etc...).  lines excludes
        the label line.'''
        pass

    @abstractmethod
    def ask_current_model_about_example(
        self,
        filename: str,
        features: str,
        lines: List[str],
    ) -> Any:
        '''Ask the current ML model about this example, if necessary/possible.
        Returns None to indicate no model to consult.  When the model is
        consulted to filter examples, the result is indexed as
        result[0] (the predicted label) and result[1] (its probability).'''
        pass

    @abstractmethod
    def get_labelling_keystrokes(self) -> Dict[str, Any]:
        '''What keystrokes should be considered valid label actions and
        what label does each keystroke map into.  e.g. if you want to
        ask the user to hit 'y' for 'yes' and code that as 255 in your
        features or to hit 'n' for 'no' and code that as 0 in your
        features, return:

            { 'y': 255, 'n': 0 }

        The keystrokes '<', '>', 'W', 'Q' and '?' are reserved for
        navigation/commands and take precedence over any label mapping.
        '''
        pass

    @abstractmethod
    def get_everything_label(self) -> Any:
        '''If this returns something other than None it indicates that
        every example selected should be labeled with this result.
        Caveat emptor, we will clobber all your files.
        '''
        pass

    @abstractmethod
    def get_label_feature(self) -> str:
        '''What feature denotes the example's label?  This is used to detect
        when examples already have a label and to assign labels to
        examples (written back as "<feature>: <label>").'''
        pass


def _is_label_line(line: str, label_label: str) -> bool:
    '''Returns True iff a "key: value" features line holds the label feature.

    The whole key (text before the first ':') is compared instead of doing
    a prefix match so that, e.g., a "label_confidence" feature is not
    mistaken for (and overwritten as) the "label" feature.'''
    return line.partition(':')[0].strip() == label_label


def _label_value(line: str) -> str:
    '''Returns the stripped value part of a "key: value" features line.
    Colons inside the value are preserved; a line with no ':' yields "".'''
    return line.partition(':')[2].strip()


def _read_feature_lines(features: str) -> List[str]:
    '''Reads a features file and returns its lines without their newline.
    Unlike slicing off the last character, this does not truncate a final
    line that lacks a trailing newline.'''
    with open(features, 'r') as rf:
        return [line.rstrip('\n') for line in rf]


def _model_agrees(model_label: Any, label: str) -> bool:
    '''Does the model's predicted label match the on-disk label text?

    Labels are compared as integers when the on-disk text parses as one
    (the historical behavior); otherwise fall back to a string comparison
    rather than crashing on non-integer labels.'''
    try:
        return model_label == int(label)
    except ValueError:
        return str(model_label) == label


def _maybe_read_skip_list() -> Set[str]:
    '''Reads the skip list (files to just bypass) into memory if using.'''
    ret: Set[str] = set()
    if config.config['ml_quick_label_use_skip_lists']:
        quick_skip_file = config.config['ml_quick_label_skip_list_path']
        if os.path.exists(quick_skip_file):
            with open(quick_skip_file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line:  # ignore blank lines
                        ret.add(line)
            logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
    return ret


def _maybe_write_skip_list(skip_list: Set[str]) -> None:
    '''Writes the skip list (files to just bypass) to disk if using.
    Entries are sorted so the file is stable across runs.'''
    if config.config['ml_quick_label_use_skip_lists']:
        quick_skip_file = config.config['ml_quick_label_skip_list_path']
        with open(quick_skip_file, 'w') as f:
            for filename in sorted(skip_list):
                filename = filename.strip()
                if len(filename) > 0:
                    f.write(f'{filename}\n')
        logger.debug('Updated %s', quick_skip_file)


def _filter_images(
    images: List[str], skip_list: Set[str], helper: QuickLabelHelper
) -> List[Tuple[str, str]]:
    '''Discard examples that have particular characteristics.  e.g.
    those that are already labeled and whose current label agrees with
    the ML model, etc...

    Returns (image, features) pairs that should be presented to the user.
    Invalid examples may be deleted from disk if
    --ml_quick_label_delete_invalid_examples is set.'''
    filtered_images: List[Tuple[str, str]] = []
    label_label = helper.get_label_feature()
    for image in images:
        if image in skip_list:
            logger.debug('Skipping %s because of the skip list.', image)
            continue
        features = helper.get_features_for_file(image)
        if features is None or not os.path.exists(features):
            logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
            continue

        # Read it, separating the label (if any) from the other features.
        label = None
        filtered_lines = []
        for line in _read_feature_lines(features):
            if _is_label_line(line, label_label):
                label = _label_value(line)
            else:
                filtered_lines.append(line)

        if not helper.is_valid_example(image, features, filtered_lines):
            logger.warning('%s/%s: Invalid example.', image, features)
            if config.config['ml_quick_label_delete_invalid_examples']:
                os.remove(image)
                os.remove(features)
            continue

        # An empty label value counts as unlabeled.
        if label and not config.config['ml_quick_label_overwrite_labels']:
            logger.warning('%s/%s: already has label, skipping.', image, features)
            continue

        if config.config['ml_quick_label_skip_where_model_agrees']:
            model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
            if model_says and label:
                if _model_agrees(model_says[0], label):
                    logger.warning(
                        '%s/%s: Model agrees with current label (%s), skipping.',
                        image,
                        features,
                        label,
                    )
                    continue
                print(f'{image}/{features}: The model disagrees with the current label.')
                print(f'  ...model says {model_says[0]} with probability {model_says[1]}.')
                print(f'  ...the example is currently labeled {label}')
        filtered_images.append((image, features))
    return filtered_images


def _make_prompt(
    helper: QuickLabelHelper,
    cursor: int,
    num_filtered_images: int,
    current_image: str,
    current_features: str,
    labeled_features: Dict[Tuple[str, str], Any],  # Examples already labeled
) -> None:
    '''Tell an interactive user where they are in the set of examples that
    may be labeled and the details of the current example.

    A pending (not yet written) label takes precedence over a label that
    is already on disk.'''
    label_label = helper.get_label_feature()  # the key: of a label in features
    key = (current_image, current_features)
    disk_label: Optional[str] = None
    filtered_lines = []
    for line in _read_feature_lines(current_features):
        if len(line) == 0:
            continue
        if _is_label_line(line, label_label):
            disk_label = _label_value(line)
        else:
            filtered_lines.append(line)

    # Prompt...
    helper.render_example(current_image, current_features, filtered_lines)
    print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='')
    print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:')
    print(f'    ...{len(labeled_features)} currently unwritten label(s) ("W" to write).')
    if key in labeled_features:
        print(f'    ...This example is labeled but not yet saved: {labeled_features[key]}')
    elif disk_label:
        print(f'    ...This example is already labeled on disk: {disk_label}')
    else:
        print('    ...This example is currently unlabeled')
    guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines)
    if guess:
        print(f'    ...The ML Model says {guess}')
    print()


def _write_labeled_features(
    helper: QuickLabelHelper,
    labeled_features: Dict[Tuple[str, str], Any],
    skip_list: Set[str],
) -> None:
    '''Rewrites each labeled example's features file with its new label.

    Blank lines and any existing label line are dropped, the remaining
    lines are stripped, and "<label feature>: <label>" is appended.  When
    skip lists are enabled the image is also added to skip_list.'''
    label_label = helper.get_label_feature()
    for (image, features), label in labeled_features.items():
        with open(features, 'r') as rf:
            stripped = [line.strip() for line in rf]
        kept = [line for line in stripped if line and not _is_label_line(line, label_label)]
        kept.append(f'{label_label}: {label}')
        with open(features, 'w') as f:
            f.writelines(line + '\n' for line in kept)
        if config.config['ml_quick_label_use_skip_lists']:
            skip_list.add(image)
    print(f'Wrote {len(labeled_features)} labels.')


def _record_label(
    labeled_features: Dict[Tuple[str, str], Any],
    filtered_images: List[Tuple[str, str]],
    cursor: int,
    label: Any,
) -> Optional[int]:
    '''Records a pending label for the example at cursor, drops it from the
    work list and returns the adjusted cursor, or None if nothing is left.'''
    labeled_features[filtered_images[cursor]] = label
    # Delete by position (not by value) so a duplicate elsewhere in the
    # list is not removed instead of the example being labeled.
    del filtered_images[cursor]
    if not filtered_images:
        print('Nothing more to label.')
        return None
    return min(cursor, len(filtered_images) - 1)


def quick_label(helper: QuickLabelHelper) -> None:
    '''Pass your QuickLabelHelper implementing class to this function and
    it will allow users to label examples and persist them to disk.
    '''
    skip_list = _maybe_read_skip_list()

    # Ask helper for an initial set of files.
    images = helper.get_candidate_files()
    if not images:
        logger.warning('No images files to operate on.')
        return
    logger.info('There are %d starting candidate images.', len(images))

    # Filter out any that can't be converted to features or already have a
    # label (unless they used --ml_quick_label_overwrite_labels).
    filtered_images = _filter_images(images, skip_list, helper)
    if not filtered_images:
        logger.warning('No image files to operate on (post filter).')
        return
    logger.info('There are %d candidate images post filtering.', len(filtered_images))

    # Allow the user to label the non-filtered images one by one.
    labeled_features: Dict[Tuple[str, str], Any] = {}
    cursor = 0
    while True:
        assert 0 <= cursor < len(filtered_images)

        image, features = filtered_images[cursor]
        assert os.path.exists(image)
        assert features and os.path.exists(features)

        # Render the features, image and prompt.
        _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features)
        try:
            # Did they want everything labelled the same?  Per the helper
            # contract any non-None value (including falsy ones like 0) counts.
            label_everything = helper.get_everything_label()
            if label_everything is not None:
                new_cursor = _record_label(
                    labeled_features, filtered_images, cursor, label_everything
                )
                if new_cursor is None:
                    break
                cursor = new_cursor

            # Otherwise ask about each individual example.
            else:
                labelling_keystrokes = helper.get_labelling_keystrokes()
                valid_keystrokes = ['<', '>', 'W', 'Q', '?', *labelling_keystrokes]
                prompt = ','.join(valid_keystrokes)
                print(f'What should I do?  [{prompt}]: ', end='')
                sys.stdout.flush()
                keystroke = input_utils.single_keystroke_response(valid_keystrokes)
                print()

                if keystroke == '?':
                    print(
                        '''
    >    = Don't label, move to the next example.
    <    = Don't label, move to the previous example.
    W    = Write pending labels to disk now.
    Q    = Quit labeling now.
    ?    = This message.
  else   = These keystrokes assign a label to the example and persist it.'''
                    )
                    input_utils.press_any_key()

                elif keystroke == 'Q':
                    logger.info('Ok, stopping for now.')
                    if labeled_features:
                        # Not discarded yet: the user is asked below whether to save them.
                        logger.info('There are %d unsaved labels.', len(labeled_features))
                    break

                elif keystroke == '>':
                    cursor += 1
                    if cursor >= len(filtered_images):
                        print('Wrapping around...')
                        cursor = 0

                elif keystroke == '<':
                    cursor -= 1
                    if cursor < 0:
                        print('Wrapping around...')
                        cursor = len(filtered_images) - 1

                elif keystroke == 'W':
                    _write_labeled_features(helper, labeled_features, skip_list)
                    labeled_features = {}

                elif keystroke in labelling_keystrokes:
                    new_cursor = _record_label(
                        labeled_features,
                        filtered_images,
                        cursor,
                        labelling_keystrokes[keystroke],
                    )
                    if new_cursor is None:
                        break
                    cursor = new_cursor

                else:
                    print(f'Unknown keystroke: {keystroke}.  Try again.')
        finally:
            helper.unrender_example(image, features)

    if labeled_features:
        yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
        if yn in ('Y', 'y'):
            _write_labeled_features(helper, labeled_features, skip_list)
    _maybe_write_skip_list(skip_list)