X-Git-Url: https://wannabe.guru.org/gitweb/?a=blobdiff_plain;f=ml%2Fquick_label.py;h=feab67b2e88484132ac80506f5add82b395e35bf;hb=e46158e49121b8a955bb07b73f5bcf9928b79c90;hp=7e0a6bf64921533e00d719223d2657fb21ebbccf;hpb=f2600f30801c849fc1d139386e3ddc3c9eb43e30;p=python_utils.git diff --git a/ml/quick_label.py b/ml/quick_label.py index 7e0a6bf..feab67b 100644 --- a/ml/quick_label.py +++ b/ml/quick_label.py @@ -1,13 +1,24 @@ #!/usr/bin/env python3 -import glob +# © Copyright 2021-2022, Scott Gasch + +"""A helper to facilitate quick manual labeling of ML training data. + +To use, implement a subclass that implements the QuickLabelHelper +interface and pass it into the quick_label function. + +""" + import logging import os -import warnings -from typing import Callable, List, NamedTuple, Optional, Set +import sys +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Set, Tuple +import ansi import argparse_utils import config +import input_utils logger = logging.getLogger(__name__) parser = config.add_commandline_args( @@ -31,21 +42,97 @@ parser.add_argument( "--ml_quick_label_overwrite_labels", default=False, action=argparse_utils.ActionNoYes, - help='Enable overwriting existing labels; default is to not relabel.', + help='Enable overwriting of existing labels; default is to not relabel.', +) +parser.add_argument( + '--ml_quick_label_skip_where_model_agrees', + default=False, + action=argparse_utils.ActionNoYes, + help='Do not filter examples where the model disagrees with the current label.', ) +parser.add_argument( + '--ml_quick_label_delete_invalid_examples', + default=False, + action='store_true', + help='If set we will delete invalid training examples.', +) + + +class QuickLabelHelper: + '''To use this quick labeler your code must create a subclass of this + class and implement the abstract methods below. See comments for + detailed semantics.''' + + @abstractmethod + def get_candidate_files(self) -> List[str]: + '''This must return a list of raw candidate files for labeling.''' + pass + + @abstractmethod + def get_features_for_file(self, filename: str) -> Optional[str]: + '''Given a raw file, return its features file.''' + pass + @abstractmethod + def render_example(self, filename: str, features: str, lines: List[str]) -> None: + '''Render a raw file with its features for the user to judge.''' + pass -class InputSpec(NamedTuple): - image_file_glob: Optional[str] - image_file_prepopulated_list: Optional[List[str]] - image_file_to_features_file: Callable[[str], str] - label: str - valid_keystrokes: List[str] - prompt: str - keystroke_to_label: Callable[[str], str] + @abstractmethod + def unrender_example(self, filename: str, features: str) -> None: + '''Unrender a raw file with its features (if necessary)...''' + pass + @abstractmethod + def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool: + '''Returns true iff the example is valid (all features are valid, there + are the correct number of features, etc...''' + pass + + @abstractmethod + def ask_current_model_about_example( + self, + filename: str, + features: str, + lines: List[str], + ) -> Any: + '''Ask the current ML model about this example, if necessary/possible. + Returns None to indicate no model to consult.''' + pass + + @abstractmethod + def get_labelling_keystrokes(self) -> Dict[str, Any]: + '''What keystrokes should be considered valid label actions and what + label does each keystroke map into. e.g. if you want to ask + the user to hit 'y' for 'yes' and code that as 255 in your + features or to hit 'n' for 'no' and code that as 0 in your + features, return: + + { 'y': 255, 'n': 0 } + + ''' + pass + + @abstractmethod + def get_everything_label(self) -> Any: + '''If this returns something other than None it indicates that every + example selected should be labeled with this result. Caveat + emptor, we will klobber all your files. + + ''' + pass + + @abstractmethod + def get_label_feature(self) -> str: + '''What feature denotes the example's label? This is used to detect + when examples already have a label and to assign labels to + examples.''' + pass + + +def _maybe_read_skip_list() -> Set[str]: + '''Reads the skip list (files to just bypass) into memory if using.''' -def read_skip_list() -> Set[str]: ret: Set[str] = set() if config.config['ml_quick_label_use_skip_lists']: quick_skip_file = config.config['ml_quick_label_skip_list_path'] @@ -56,11 +143,13 @@ def read_skip_list() -> Set[str]: line = line[:-1] line.strip() ret.add(line) - logger.debug(f'Read {quick_skip_file} and found {len(ret)} entries.') + logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret)) return ret -def write_skip_list(skip_list) -> None: +def _maybe_write_skip_list(skip_list) -> None: + '''Writes the skip list (files to just bypass) to disk if using.''' + if config.config['ml_quick_label_use_skip_lists']: quick_skip_file = config.config['ml_quick_label_skip_list_path'] with open(quick_skip_file, 'w') as f: @@ -68,55 +157,247 @@ def write_skip_list(skip_list) -> None: filename = filename.strip() if len(filename) > 0: f.write(f'{filename}\n') - logger.debug(f'Updated {quick_skip_file}') + logger.debug('Updated %s', quick_skip_file) -def label(in_spec: InputSpec) -> None: - import input_utils - - images = [] - if in_spec.image_file_glob is not None: - images += glob.glob(in_spec.image_file_glob) - elif in_spec.image_file_prepopulated_list is not None: - images += in_spec.image_file_prepopulated_list - else: - raise ValueError('One of image_file_glob or image_file_prepopulated_list is required') +def _filter_images( + images: List[str], skip_list: Set[str], helper: QuickLabelHelper +) -> List[Tuple[str, str]]: + '''Discard examples that have particular characteristics. e.g. + those that are already labeled and whose current label agrees with + the ML model, etc...''' - skip_list = read_skip_list() + filtered_images = [] + label_label = helper.get_label_feature() for image in images: if image in skip_list: - logger.debug(f'Skipping {image} because of the skip list') + logger.debug('Skipping %s because of the skip list.', image) continue - features = in_spec.image_file_to_features_file(image) + + features = helper.get_features_for_file(image) if features is None or not os.path.exists(features): - msg = f'File {image} yielded file {features} which does not exist, SKIPPING.' - logger.warning(msg) - warnings.warn(msg) + logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features) + continue + + # Read it. + label = None + filtered_lines = [] + with open(features, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + if line.startswith(label_label): + label = ''.join(line.split(':')[1:]) + label = label.strip() + else: + filtered_lines.append(line) + + if not helper.is_valid_example(image, features, filtered_lines): + logger.warning('%s/%s: Invalid example.', image, features) + if config.config['ml_quick_label_delete_invalid_examples']: + os.remove(image) + os.remove(features) + continue + + if label and not config.config['ml_quick_label_overwrite_labels']: + logger.warning('%s/%s: already has label, skipping.', image, features) + continue + + if config.config['ml_quick_label_skip_where_model_agrees']: + model_says = helper.ask_current_model_about_example(image, features, filtered_lines) + if model_says and label: + if model_says[0] == int(label): + logger.warning( + '%s/%s: Model agrees with current label (%s), skipping.', + image, + features, + label, + ) + continue + + print(f'{image}/{features}: The model disagrees with the current label.') + print(f' ...model says {model_says[0]} with probability {model_says[1]}.') + print(f' ...the example is currently labeled {label}') + + filtered_images.append((image, features)) + return filtered_images + + +def _make_prompt( + helper: QuickLabelHelper, + cursor: int, + num_filtered_images: int, + current_image: str, + current_features: str, + labeled_features: Dict[Tuple[str, str], str], # Examples already labeled +) -> None: + '''Tell an interactive user where they are in the set of examples that + may be labeled and the details of the current example.''' + + label_label = helper.get_label_feature() # the key: of a label in features + filtered_lines = [] + label = labeled_features.get((current_image, current_features), None) + with open(current_features, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + if len(line) == 0: continue + if not line.startswith(label_label): + filtered_lines.append(line) + else: + assert not label + label = line - # Render features and image. + # Prompt... + helper.render_example(current_image, current_features, filtered_lines) + print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='') + print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:') + print(f' ...{len(labeled_features)} currently unwritten label(s) ("W" to write).') + if label: + if (current_image, current_features) in labeled_features: + print(f' ...This example is labeled but not yet saved: {label}') + else: + print(f' ...This example is already labeled on disk: {label}') + else: + print(' ...This example is currently unlabeled') + guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines) + if guess: + print(f' ...The ML Model says {guess}') + print() + + +def _write_labeled_features( + helper: QuickLabelHelper, + labeled_features: Dict[Tuple[str, str], str], + skip_list: Set[str], +) -> None: + label_label = helper.get_label_feature() + for image_features, label in labeled_features.items(): + image = image_features[0] + features = image_features[1] filtered_lines = [] - with open(features, "r") as f: - lines = f.readlines() - saw_label = False + with open(features, 'r') as rf: + lines = rf.readlines() for line in lines: line = line[:-1] - if in_spec.label not in line: + line = line.strip() + if line == '': + continue + if not line.startswith(label_label): filtered_lines.append(line) + + filtered_lines.append(f'{label_label}: {label}') + with open(features, 'w') as f: + f.writelines(line + '\n' for line in filtered_lines) + if config.config['ml_quick_label_use_skip_lists']: + skip_list.add(image) + print(f'Wrote {len(labeled_features)} labels.') + + +def quick_label(helper: QuickLabelHelper) -> None: + '''Pass your QuickLabelHelper implementing class to this function and + it will allow users to label examples and persist them to disk. + + ''' + skip_list = _maybe_read_skip_list() + + # Ask helper for an initial set of files. + images = helper.get_candidate_files() + if len(images) == 0: + logger.warning('No images files to operate on.') + return + logger.info('There are %d starting candidate images.', len(images)) + + # Filter out any that can't be converted to features or already have a + # label (unless they used --ml_qukck_label_overwrite_labels). + filtered_images = _filter_images(images, skip_list, helper) + if len(filtered_images) == 0: + logger.warning('No image files to operate on (post filter).') + return + logger.info('There are %d candidate images post filtering.', len(filtered_images)) + + # Allow the user to label the non-filtered images one by one. + labeled_features: Dict[Tuple[str, str], str] = {} + cursor = 0 + while True: + assert 0 <= cursor < len(filtered_images) + + image = filtered_images[cursor][0] + assert os.path.exists(image) + features = filtered_images[cursor][1] + assert features and os.path.exists(features) + + # Render the features, image and prompt. + _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features) + try: + # Did they want everything labelled the same? + label_everything = helper.get_everything_label() + if label_everything: + labeled_features[(image, features)] = label_everything + filtered_images.remove((image, features)) + if len(filtered_images) == 0: + print('Nothing more to label.') + break + if cursor >= len(filtered_images): + cursor -= 1 + + # Otherwise ask about each individual example. else: - saw_label = True - - if not saw_label or config.config['ml_quick_label_overwrite_labels']: - logger.info(features) - os.system(f'xv {image} &') - keystroke = input_utils.single_keystroke_response( - in_spec.valid_keystrokes, - prompt=in_spec.prompt, - ) - os.system('killall xv') - label_value = in_spec.keystroke_to_label(keystroke) - filtered_lines.append(f"{in_spec.label}: {label_value}\n") - with open(features, 'w') as f: - f.writelines("%s\n" % line for line in filtered_lines) - skip_list.add(image) - write_skip_list(skip_list) + labelling_keystrokes = helper.get_labelling_keystrokes() + valid_keystrokes = ['<', '>', 'W', 'Q', '?'] + valid_keystrokes += labelling_keystrokes.keys() + prompt = ','.join(valid_keystrokes) + print(f'What should I do? [{prompt}]: ', end='') + sys.stdout.flush() + keystroke = input_utils.single_keystroke_response(valid_keystrokes) + print() + + if keystroke == '?': + print( + ''' + > = Don't label, move to the next example. + < = Don't label, move to the previous example. + W = Write pending labels to disk now. + Q = Quit labeling now. + ? = This message. + else = These keystrokes assign a label to the example and persist it.''' + ) + input_utils.press_any_key() + elif keystroke == 'Q': + logger.info('Ok, stopping for now.') + if len(labeled_features): + logger.info('Discarding %d unsaved labels.', len(labeled_features)) + break + elif keystroke == '>': + cursor += 1 + if cursor >= len(filtered_images): + print('Wrapping around...') + cursor = 0 + elif keystroke == '<': + cursor -= 1 + if cursor < 0: + print('Wrapping around...') + cursor = len(filtered_images) - 1 + elif keystroke == 'W': + _write_labeled_features(helper, labeled_features, skip_list) + labeled_features = {} + elif keystroke in labelling_keystrokes: + label_value = labelling_keystrokes[keystroke] + labeled_features[(image, features)] = label_value + filtered_images.remove((image, features)) + if len(filtered_images) == 0: + print('Nothing more to label.') + break + if cursor >= len(filtered_images): + cursor -= 1 + else: + print(f'Unknown keystroke: {keystroke}. Try again.') + finally: + helper.unrender_example(image, features) + + if len(labeled_features): + yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk? [Y/N]: ') + if yn in ('Y', 'y'): + _write_labeled_features(helper, labeled_features, skip_list) + _maybe_write_skip_list(skip_list)