From c478e593e27c5487b979556ddc45841b52c8d052 Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Tue, 12 Apr 2022 14:19:26 -0700 Subject: [PATCH] Overhaul quick labeler, part 1. --- ml/quick_label.py | 248 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 198 insertions(+), 50 deletions(-) diff --git a/ml/quick_label.py b/ml/quick_label.py index 15256a3..7bd43c3 100644 --- a/ml/quick_label.py +++ b/ml/quick_label.py @@ -4,12 +4,13 @@ """A helper to facilitate quick manual labeling of ML training data.""" -import glob import logging import os +import sys +import time import warnings -from dataclasses import dataclass -from typing import Callable, List, Optional, Set +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Set import argparse_utils import config @@ -40,21 +41,74 @@ parser.add_argument( ) -@dataclass -class InputSpec: - """A wrapper around the input data we need to operate; should be - populated by the caller.""" +class QuickLabelHelper: + '''To use this quick labeler your code must create a subclass of this + class and implement the abstract methods below. 
See comments for + detailed semantics.''' - image_file_glob: Optional[str] = None - image_file_prepopulated_list: Optional[List[str]] = None - image_file_to_features_file: Optional[Callable[[str], str]] = None - label: str = '' - valid_keystrokes: List[str] = [] - prompt: str = '' - keystroke_to_label: Optional[Callable[[str], str]] = None + @abstractmethod + def get_candidate_files(self) -> List[str]: + '''This must return a list of raw candidate files for labeling.''' + pass + @abstractmethod + def get_features_for_file(self, filename: str) -> Optional[str]: + '''Given a raw file, return its features file.''' + pass + + @abstractmethod + def render_example(self, filename: str, features: str, lines: List[str]) -> None: + '''Render a raw file with its features for the user.''' + pass + + @abstractmethod + def unrender_example(self, filename: str, features: str, lines: List[str]) -> None: + '''Unrender a raw file with its features (if necessary)...''' + pass + + @abstractmethod + def ask_current_model_about_example( + self, + filename: str, + features: str, + lines: List[str], + ) -> Any: + '''Ask the current ML model about this example, if necessary.''' + pass + + @abstractmethod + def get_labelling_keystrokes(self) -> Dict[str, Any]: + '''What keystrokes should be considered valid label actions and what + label does each keystroke map into. e.g. if you want to ask + the user to hit 'y' for 'yes' and code that as 255 in your + features or to hit 'n' for 'no' and code that as 0 in your + features, return: + + { 'y': 255, 'n': 0 } + + ''' + pass + + @abstractmethod + def get_everything_label(self) -> Any: + '''If this returns something other than None it indicates that every + example selected should be labeled with this result. Caveat + emptor, we will clobber all your files. + + ''' + pass + + @abstractmethod + def get_label_feature(self) -> str: + '''What feature denotes the example's label? 
This is used to detect + when examples already have a label and to assign labels to + examples.''' + pass + + +def _maybe_read_skip_list() -> Set[str]: + '''Reads the skip list (files to just bypass) into memory if using.''' -def read_skip_list() -> Set[str]: ret: Set[str] = set() if config.config['ml_quick_label_use_skip_lists']: quick_skip_file = config.config['ml_quick_label_skip_list_path'] @@ -69,7 +123,9 @@ def read_skip_list() -> Set[str]: return ret -def write_skip_list(skip_list) -> None: +def _maybe_write_skip_list(skip_list) -> None: + '''Writes the skip list (files to just bypass) to disk if using.''' + if config.config['ml_quick_label_use_skip_lists']: quick_skip_file = config.config['ml_quick_label_skip_list_path'] with open(quick_skip_file, 'w') as f: @@ -80,54 +136,146 @@ def write_skip_list(skip_list) -> None: logger.debug('Updated %s', quick_skip_file) -def label(in_spec: InputSpec) -> None: - import input_utils - - images = [] - if in_spec.image_file_glob is not None: - images += glob.glob(in_spec.image_file_glob) - elif in_spec.image_file_prepopulated_list is not None: - images += in_spec.image_file_prepopulated_list - else: - raise ValueError('One of image_file_glob or image_file_prepopulated_list is required') +def quick_label(helper: QuickLabelHelper) -> None: + # Ask helper for an initial set of files. + images = helper.get_candidate_files() + if len(images) == 0: + logger.warning('No images files to operate on.') + return - skip_list = read_skip_list() + # Filter out any that can't be converted to features or already have a + # label (unless they used --ml_quick_label_overwrite_labels). 
+ filtered_images = [] + skip_list = _maybe_read_skip_list() for image in images: if image in skip_list: logger.debug('Skipping %s because of the skip list', image) continue - assert in_spec.image_file_to_features_file - features = in_spec.image_file_to_features_file(image) + + features = helper.get_features_for_file(image) if features is None or not os.path.exists(features): - msg = f'File {image} yielded file {features} which does not exist, SKIPPING.' + msg = f'{image}/{features}: {features} doesn\'t exist, SKIPPING.' + logger.warning(msg) + warnings.warn(msg) + continue + + label_label = helper.get_label_feature() + label = None + with open(features, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + if line.startswith(label_label): + label = line + if label and not config.config['ml_quick_label_overwrite_labels']: + msg = f'{image}/{features}: already has label, SKIPPING' logger.warning(msg) warnings.warn(msg) continue + filtered_images.append((image, features)) + + if len(filtered_images) == 0: + logger.warning('No image files to operate on (post filter).') + return + + cursor = 0 + import input_utils + + while True: + assert 0 <= cursor < len(filtered_images) + + image = filtered_images[cursor][0] + assert os.path.exists(image) + features = filtered_images[cursor][1] + assert features and os.path.exists(features) - # Render features and image. 
filtered_lines = [] - with open(features, "r") as f: - lines = f.readlines() - saw_label = False + label = None + with open(features, 'r') as rf: + lines = rf.readlines() for line in lines: line = line[:-1] - if in_spec.label not in line: + if not line.startswith(label_label): filtered_lines.append(line) else: - saw_label = True - - if not saw_label or config.config['ml_quick_label_overwrite_labels']: - logger.info(features) - os.system(f'xv {image} &') - keystroke = input_utils.single_keystroke_response( - in_spec.valid_keystrokes, - prompt=in_spec.prompt, - ) - os.system('killall xv') - assert in_spec.keystroke_to_label - label_value = in_spec.keystroke_to_label(keystroke) - filtered_lines.append(f"{in_spec.label}: {label_value}\n") + label = line + + # Render... + helper.render_example(image, features, filtered_lines) + + # Prompt... + print( + f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}' + ) + if label: + print(f' ...Already labelled: {label}') + else: + print(' ...Currently unlabeled') + guess = helper.ask_current_model_about_example(image, features, filtered_lines) + if guess: + print(f' ...Model says {guess}') + print() + + # Did they want everything labelled the same? + label_everything = helper.get_everything_label() + if label_everything: + filtered_lines.append(f"{label_label}: {label_everything}\n") with open(features, 'w') as f: f.writelines(line + '\n' for line in filtered_lines) - skip_list.add(image) - write_skip_list(skip_list) + if config.config['ml_quick_label_use_skip_lists']: + skip_list.add(image) + cursor += 1 + if cursor >= len(filtered_images): + helper.unrender_example(image, features, filtered_lines) + break + + # Otherwise ask about each example. + else: + labelling_keystrokes = helper.get_labelling_keystrokes() + valid_keystrokes = ['<', '>', 'Q', '?'] + valid_keystrokes += labelling_keystrokes.keys() + prompt = ','.join(valid_keystrokes) + print(f'What should I do ({prompt})? 
', end='') + sys.stdout.flush() + keystroke = input_utils.single_keystroke_response(valid_keystrokes) + print() + if keystroke == 'Q': + logger.info('Ok, stopping for now. Labeled examples are written to disk') + helper.unrender_example(image, features, filtered_lines) + break + elif keystroke == '?': + print( + ''' + > = Don't label, move to the next example. + < = Don't label, move to the previous example. + Q = Quit labeling now. + ? = This message. + else = These keystrokes assign a label to the example and persist it.''' + ) + time.sleep(3.0) + + elif keystroke == '>': + cursor += 1 + if cursor >= len(filtered_images): + print('Wrapping around...') + cursor = 0 + elif keystroke == '<': + cursor -= 1 + if cursor < 0: + print('Wrapping around...') + cursor = len(filtered_images) - 1 + elif keystroke in labelling_keystrokes: + label_value = labelling_keystrokes[keystroke] + filtered_lines.append(f"{label_label}: {label_value}\n") + with open(features, 'w') as f: + f.writelines(line + '\n' for line in filtered_lines) + if config.config['ml_quick_label_use_skip_lists']: + skip_list.add(image) + cursor += 1 + if cursor >= len(filtered_images): + print('Wrapping around...') + cursor = 0 + else: + print(f'Unknown keystroke: {keystroke}') + helper.unrender_example(image, features, filtered_lines) + _maybe_write_skip_list(skip_list) -- 2.46.0