From 723f5fcb660eef79cc455cfb3f2eebfa667c90fa Mon Sep 17 00:00:00 2001 From: Scott Gasch Date: Tue, 19 Apr 2022 13:53:49 -0700 Subject: [PATCH] Improve the quick labeler again. --- ml/quick_label.py | 255 +++++++++++++++++++++++++++++----------------- 1 file changed, 159 insertions(+), 96 deletions(-) diff --git a/ml/quick_label.py b/ml/quick_label.py index 05efbae..00acf05 100644 --- a/ml/quick_label.py +++ b/ml/quick_label.py @@ -7,12 +7,13 @@ import logging import os import sys -import time from abc import abstractmethod from typing import Any, Dict, List, Optional, Set, Tuple +import ansi import argparse_utils import config +import input_utils logger = logging.getLogger(__name__) parser = config.add_commandline_args( @@ -36,7 +37,7 @@ parser.add_argument( "--ml_quick_label_overwrite_labels", default=False, action=argparse_utils.ActionNoYes, - help='Enable overwriting existing labels; default is to not relabel.', + help='Enable overwriting of existing labels; default is to not relabel.', ) parser.add_argument( '--ml_quick_label_skip_where_model_agrees', @@ -45,7 +46,7 @@ parser.add_argument( help='Do not filter examples where the model disagrees with the current label.', ) parser.add_argument( - 'ml_quick_label_delete_invalid_examples', + '--ml_quick_label_delete_invalid_examples', default=False, action='store_true', help='If set we will delete invalid training examples.', @@ -73,7 +74,7 @@ class QuickLabelHelper: pass @abstractmethod - def unrender_example(self, filename: str, features: str, lines: List[str]) -> None: + def unrender_example(self, filename: str, features: str) -> None: '''Unrender a raw file with its features (if necessary)...''' pass @@ -160,14 +161,15 @@ def _filter_images( label_label = helper.get_label_feature() for image in images: if image in skip_list: - logger.debug('Skipping %s because of the skip list', image) + logger.debug('Skipping %s because of the skip list.', image) continue features = helper.get_features_for_file(image) if features is None or not os.path.exists(features): - logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features) + logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features) continue + # Read it. label = None filtered_lines = [] with open(features, 'r') as rf: @@ -188,21 +190,98 @@ def _filter_images( continue if label and not config.config['ml_quick_label_overwrite_labels']: - logger.warning('%s/%s: already has label, SKIPPING.', image, features) + logger.warning('%s/%s: already has label, skipping.', image, features) continue if config.config['ml_quick_label_skip_where_model_agrees']: model_says = helper.ask_current_model_about_example(image, features, filtered_lines) if model_says and label: if model_says[0] == int(label): + logger.warning( + '%s/%s: Model agrees with current label (%s), skipping.', + image, + features, + label, + ) continue + print(f'{image}/{features}: The model disagrees with the current label.') print(f' ...model says {model_says[0]} with probability {model_says[1]}.') print(f' ...the example is currently labeled {label}') + filtered_images.append((image, features)) return filtered_images +def _make_prompt( + helper: QuickLabelHelper, + cursor: int, + filtered_images: List[Tuple[str, str]], + image: str, + features: str, + labeled_features: Dict[Tuple[str, str], str], +) -> None: + label_label = helper.get_label_feature() + filtered_lines = [] + label = labeled_features.get((image, features), None) + with open(features, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + if len(line) == 0: + continue + if not line.startswith(label_label): + filtered_lines.append(line) + else: + assert not label + label = line + + # Prompt... + helper.render_example(image, features, filtered_lines) + print(f'{cursor}/{len(filtered_images)} ({cursor/len(filtered_images)*100.0:.1f}%) | ', end='') + print(f'{ansi.bold()}{image} / {features}{ansi.reset()}:') + print(f' ...{len(labeled_features)} currently unsaved labels ("W" to save).') + if label: + if (image, features) in labeled_features: + print(f' ...This example is labeled but not yet saved: {label}') + else: + print(f' ...This example is already labeled on disk: {label}') + else: + print(' ...This example is currently unlabeled') + guess = helper.ask_current_model_about_example(image, features, filtered_lines) + if guess: + print(f' ...The ML Model says {guess}') + print() + + +def _write_labeled_features( + helper: QuickLabelHelper, + labeled_features: Dict[Tuple[str, str], str], + skip_list: Set[str], +) -> None: + label_label = helper.get_label_feature() + for image_features, label in labeled_features.items(): + image = image_features[0] + features = image_features[1] + filtered_lines = [] + with open(features, 'r') as rf: + lines = rf.readlines() + for line in lines: + line = line[:-1] + line = line.strip() + if line == '': + continue + if not line.startswith(label_label): + filtered_lines.append(line) + + filtered_lines.append(f'{label_label}: {label}') + with open(features, 'w') as f: + f.writelines(line + '\n' for line in filtered_lines) + if config.config['ml_quick_label_use_skip_lists']: + skip_list.add(image) + print(f'Wrote {len(labeled_features)} labels.') + + def quick_label(helper: QuickLabelHelper) -> None: skip_list = _maybe_read_skip_list() @@ -211,6 +290,7 @@ def quick_label(helper: QuickLabelHelper) -> None: if len(images) == 0: logger.warning('No images files to operate on.') return + logger.info('There are %d starting candidate images.', len(images)) # Filter out any that can't be converted to features or already have a # label (unless they used --ml_qukck_label_overwrite_labels). @@ -218,12 +298,12 @@ def quick_label(helper: QuickLabelHelper) -> None: if len(filtered_images) == 0: logger.warning('No image files to operate on (post filter).') return + logger.info('There are %d candidate images post filtering.', len(filtered_images)) # Allow the user to label the non-filtered images one by one. - import input_utils + labeled_features: Dict[Tuple[str, str], str] = {} cursor = 0 - label_label = helper.get_label_feature() while True: assert 0 <= cursor < len(filtered_images) @@ -232,93 +312,76 @@ def quick_label(helper: QuickLabelHelper) -> None: features = filtered_images[cursor][1] assert features and os.path.exists(features) - filtered_lines = [] - label = None - with open(features, 'r') as rf: - lines = rf.readlines() - for line in lines: - line = line[:-1] - if not line.startswith(label_label): - filtered_lines.append(line) - else: - label = line - - # Render... - helper.render_example(image, features, filtered_lines) - - # Prompt... - print( - f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}' - ) - if label: - print(f' ...Already labelled: {label}') - else: - print(' ...Currently unlabeled') - guess = helper.ask_current_model_about_example(image, features, filtered_lines) - if guess: - print(f' ...Model says {guess}') - print() - - # Did they want everything labelled the same? - label_everything = helper.get_everything_label() - if label_everything: - filtered_lines.append(f"{label_label}: {label_everything}\n") - with open(features, 'w') as f: - f.writelines(line + '\n' for line in filtered_lines) - if config.config['ml_quick_label_use_skip_lists']: - skip_list.add(image) - cursor += 1 - if cursor >= len(filtered_images): - helper.unrender_example(image, features, filtered_lines) - break - - # Otherwise ask about each example. - else: - labelling_keystrokes = helper.get_labelling_keystrokes() - valid_keystrokes = ['<', '>', 'Q', '?'] - valid_keystrokes += labelling_keystrokes.keys() - prompt = ','.join(valid_keystrokes) - print(f'What should I do ({prompt})? ', end='') - sys.stdout.flush() - keystroke = input_utils.single_keystroke_response(valid_keystrokes) - print() - if keystroke == 'Q': - logger.info('Ok, stopping for now. Labeled examples are written to disk') - helper.unrender_example(image, features, filtered_lines) - break - elif keystroke == '?': - print( - ''' - > = Don't label, move to the next example. - < = Don't label, move to the previous example. - Q = Quit labeling now. - ? = This message. - else = These keystrokes assign a label to the example and persist it.''' - ) - time.sleep(3.0) - - elif keystroke == '>': - cursor += 1 + # Render the features, image and prompt. + _make_prompt(helper, cursor, filtered_images, image, features, labeled_features) + try: + # Did they want everything labelled the same? + label_everything = helper.get_everything_label() + if label_everything: + labeled_features[(image, features)] = label_everything + filtered_images.remove((image, features)) + if len(filtered_images) == 0: + print('Nothing more to label.') + break if cursor >= len(filtered_images): - print('Wrapping around...') - cursor = 0 - elif keystroke == '<': - cursor -= 1 - if cursor < 0: - print('Wrapping around...') - cursor = len(filtered_images) - 1 - elif keystroke in labelling_keystrokes: - label_value = labelling_keystrokes[keystroke] - filtered_lines.append(f"{label_label}: {label_value}\n") - with open(features, 'w') as f: - f.writelines(line + '\n' for line in filtered_lines) - if config.config['ml_quick_label_use_skip_lists']: - skip_list.add(image) - cursor += 1 - if cursor >= len(filtered_images): - print('Wrapping around...') - cursor = 0 + cursor -= 1 + + # Otherwise ask about each individual example. else: - print(f'Unknown keystroke: {keystroke}') - helper.unrender_example(image, features, filtered_lines) + labelling_keystrokes = helper.get_labelling_keystrokes() + valid_keystrokes = ['<', '>', 'W', 'Q', '?'] + valid_keystrokes += labelling_keystrokes.keys() + prompt = ','.join(valid_keystrokes) + print(f'What should I do? [{prompt}]: ', end='') + sys.stdout.flush() + keystroke = input_utils.single_keystroke_response(valid_keystrokes) + print() + + if keystroke == '?': + print( + ''' + > = Don't label, move to the next example. + < = Don't label, move to the previous example. + W = Write pending labels to disk now. + Q = Quit labeling now. + ? = This message. + else = These keystrokes assign a label to the example and persist it.''' + ) + input_utils.press_any_key() + elif keystroke == 'Q': + logger.info('Ok, stopping for now.') + if len(labeled_features): + logger.info('Discarding %d unsaved labels.', len(labeled_features)) + break + elif keystroke == '>': + cursor += 1 + if cursor >= len(filtered_images): + print('Wrapping around...') + cursor = 0 + elif keystroke == '<': + cursor -= 1 + if cursor < 0: + print('Wrapping around...') + cursor = len(filtered_images) - 1 + elif keystroke == 'W': + _write_labeled_features(helper, labeled_features, skip_list) + labeled_features = {} + elif keystroke in labelling_keystrokes: + label_value = labelling_keystrokes[keystroke] + labeled_features[(image, features)] = label_value + filtered_images.remove((image, features)) + if len(filtered_images) == 0: + print('Nothing more to label.') + break + if cursor >= len(filtered_images): + cursor -= 1 + else: + print(f'Unknown keystroke: {keystroke}. Try again.') + finally: + helper.unrender_example(image, features) + + if len(labeled_features): + yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk? [Y/N]: ') + if yn in ('Y', 'y'): + _write_labeled_features(helper, labeled_features, skip_list) _maybe_write_skip_list(skip_list) -- 2.45.2