00acf05f4e389a36096b8cd13e6cbbee56985446
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A helper to facilitate quick manual labeling of ML training data."""
6
7 import logging
8 import os
9 import sys
10 from abc import abstractmethod
11 from typing import Any, Dict, List, Optional, Set, Tuple
12
13 import ansi
14 import argparse_utils
15 import config
16 import input_utils
17
18 logger = logging.getLogger(__name__)
19 parser = config.add_commandline_args(
20     f"ML Quick Labeler ({__file__})",
21     "Args related to quick labeling of ML training data",
22 )
23 parser.add_argument(
24     "--ml_quick_label_skip_list_path",
25     default="./qlabel_skip_list.txt",
26     metavar="FILENAME",
27     type=argparse_utils.valid_filename,
28     help="Path to file in which to store already labeled data.",
29 )
30 parser.add_argument(
31     "--ml_quick_label_use_skip_lists",
32     default=True,
33     action=argparse_utils.ActionNoYes,
34     help='Should we use a skip list file to speed up execution?',
35 )
36 parser.add_argument(
37     "--ml_quick_label_overwrite_labels",
38     default=False,
39     action=argparse_utils.ActionNoYes,
40     help='Enable overwriting of existing labels; default is to not relabel.',
41 )
42 parser.add_argument(
43     '--ml_quick_label_skip_where_model_agrees',
44     default=False,
45     action=argparse_utils.ActionNoYes,
46     help='Do not filter examples where the model disagrees with the current label.',
47 )
48 parser.add_argument(
49     '--ml_quick_label_delete_invalid_examples',
50     default=False,
51     action='store_true',
52     help='If set we will delete invalid training examples.',
53 )
54
55
56 class QuickLabelHelper:
57     '''To use this quick labeler your code must create a subclass of this
58     class and implement the abstract methods below.  See comments for
59     detailed semantics.'''
60
61     @abstractmethod
62     def get_candidate_files(self) -> List[str]:
63         '''This must return a list of raw candidate files for labeling.'''
64         pass
65
66     @abstractmethod
67     def get_features_for_file(self, filename: str) -> Optional[str]:
68         '''Given a raw file, return its features file.'''
69         pass
70
71     @abstractmethod
72     def render_example(self, filename: str, features: str, lines: List[str]) -> None:
73         '''Render a raw file with its features for the user.'''
74         pass
75
76     @abstractmethod
77     def unrender_example(self, filename: str, features: str) -> None:
78         '''Unrender a raw file with its features (if necessary)...'''
79         pass
80
81     @abstractmethod
82     def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
83         '''Returns true iff the example is valid (all features are valid, there
84         are the correct number of features, etc...'''
85         pass
86
87     @abstractmethod
88     def ask_current_model_about_example(
89         self,
90         filename: str,
91         features: str,
92         lines: List[str],
93     ) -> Any:
94         '''Ask the current ML model about this example, if necessary.'''
95         pass
96
97     @abstractmethod
98     def get_labelling_keystrokes(self) -> Dict[str, Any]:
99         '''What keystrokes should be considered valid label actions and what
100         label does each keystroke map into.  e.g. if you want to ask
101         the user to hit 'y' for 'yes' and code that as 255 in your
102         features or to hit 'n' for 'no' and code that as 0 in your
103         features, return:
104
105             { 'y': 255, 'n': 0 }
106
107         '''
108         pass
109
110     @abstractmethod
111     def get_everything_label(self) -> Any:
112         '''If this returns something other than None it indicates that every
113         example selected should be labeled with this result.  Caveat
114         emptor, we will klobber all your files.
115
116         '''
117         pass
118
119     @abstractmethod
120     def get_label_feature(self) -> str:
121         '''What feature denotes the example's label?  This is used to detect
122         when examples already have a label and to assign labels to
123         examples.'''
124         pass
125
126
127 def _maybe_read_skip_list() -> Set[str]:
128     '''Reads the skip list (files to just bypass) into memory if using.'''
129
130     ret: Set[str] = set()
131     if config.config['ml_quick_label_use_skip_lists']:
132         quick_skip_file = config.config['ml_quick_label_skip_list_path']
133         if os.path.exists(quick_skip_file):
134             with open(quick_skip_file, 'r') as f:
135                 lines = f.readlines()
136             for line in lines:
137                 line = line[:-1]
138                 line.strip()
139                 ret.add(line)
140         logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
141     return ret
142
143
144 def _maybe_write_skip_list(skip_list) -> None:
145     '''Writes the skip list (files to just bypass) to disk if using.'''
146
147     if config.config['ml_quick_label_use_skip_lists']:
148         quick_skip_file = config.config['ml_quick_label_skip_list_path']
149         with open(quick_skip_file, 'w') as f:
150             for filename in skip_list:
151                 filename = filename.strip()
152                 if len(filename) > 0:
153                     f.write(f'{filename}\n')
154         logger.debug('Updated %s', quick_skip_file)
155
156
157 def _filter_images(
158     images: List[str], skip_list: Set[str], helper: QuickLabelHelper
159 ) -> List[Tuple[str, str]]:
160     filtered_images = []
161     label_label = helper.get_label_feature()
162     for image in images:
163         if image in skip_list:
164             logger.debug('Skipping %s because of the skip list.', image)
165             continue
166
167         features = helper.get_features_for_file(image)
168         if features is None or not os.path.exists(features):
169             logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
170             continue
171
172         # Read it.
173         label = None
174         filtered_lines = []
175         with open(features, 'r') as rf:
176             lines = rf.readlines()
177         for line in lines:
178             line = line[:-1]
179             if line.startswith(label_label):
180                 label = ''.join(line.split(':')[1:])
181                 label = label.strip()
182             else:
183                 filtered_lines.append(line)
184
185         if not helper.is_valid_example(image, features, filtered_lines):
186             logger.warning('%s/%s: Invalid example.', image, features)
187             if config.config['ml_quick_label_delete_invalid_examples']:
188                 os.remove(image)
189                 os.remove(features)
190             continue
191
192         if label and not config.config['ml_quick_label_overwrite_labels']:
193             logger.warning('%s/%s: already has label, skipping.', image, features)
194             continue
195
196         if config.config['ml_quick_label_skip_where_model_agrees']:
197             model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
198             if model_says and label:
199                 if model_says[0] == int(label):
200                     logger.warning(
201                         '%s/%s: Model agrees with current label (%s), skipping.',
202                         image,
203                         features,
204                         label,
205                     )
206                     continue
207
208                 print(f'{image}/{features}: The model disagrees with the current label.')
209                 print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
210                 print(f'    ...the example is currently labeled {label}')
211
212         filtered_images.append((image, features))
213     return filtered_images
214
215
216 def _make_prompt(
217     helper: QuickLabelHelper,
218     cursor: int,
219     filtered_images: List[Tuple[str, str]],
220     image: str,
221     features: str,
222     labeled_features: Dict[Tuple[str, str], str],
223 ) -> None:
224     label_label = helper.get_label_feature()
225     filtered_lines = []
226     label = labeled_features.get((image, features), None)
227     with open(features, 'r') as rf:
228         lines = rf.readlines()
229     for line in lines:
230         line = line[:-1]
231         if len(line) == 0:
232             continue
233         if not line.startswith(label_label):
234             filtered_lines.append(line)
235         else:
236             assert not label
237             label = line
238
239     # Prompt...
240     helper.render_example(image, features, filtered_lines)
241     print(f'{cursor}/{len(filtered_images)} ({cursor/len(filtered_images)*100.0:.1f}%) | ', end='')
242     print(f'{ansi.bold()}{image} / {features}{ansi.reset()}:')
243     print(f'    ...{len(labeled_features)} currently unsaved labels ("W" to save).')
244     if label:
245         if (image, features) in labeled_features:
246             print(f'    ...This example is labeled but not yet saved: {label}')
247         else:
248             print(f'    ...This example is already labeled on disk: {label}')
249     else:
250         print('    ...This example is currently unlabeled')
251     guess = helper.ask_current_model_about_example(image, features, filtered_lines)
252     if guess:
253         print(f'    ...The ML Model says {guess}')
254     print()
255
256
257 def _write_labeled_features(
258     helper: QuickLabelHelper,
259     labeled_features: Dict[Tuple[str, str], str],
260     skip_list: Set[str],
261 ) -> None:
262     label_label = helper.get_label_feature()
263     for image_features, label in labeled_features.items():
264         image = image_features[0]
265         features = image_features[1]
266         filtered_lines = []
267         with open(features, 'r') as rf:
268             lines = rf.readlines()
269         for line in lines:
270             line = line[:-1]
271             line = line.strip()
272             if line == '':
273                 continue
274             if not line.startswith(label_label):
275                 filtered_lines.append(line)
276
277         filtered_lines.append(f'{label_label}: {label}')
278         with open(features, 'w') as f:
279             f.writelines(line + '\n' for line in filtered_lines)
280             if config.config['ml_quick_label_use_skip_lists']:
281                 skip_list.add(image)
282     print(f'Wrote {len(labeled_features)} labels.')
283
284
285 def quick_label(helper: QuickLabelHelper) -> None:
286     skip_list = _maybe_read_skip_list()
287
288     # Ask helper for an initial set of files.
289     images = helper.get_candidate_files()
290     if len(images) == 0:
291         logger.warning('No images files to operate on.')
292         return
293     logger.info('There are %d starting candidate images.', len(images))
294
295     # Filter out any that can't be converted to features or already have a
296     # label (unless they used --ml_qukck_label_overwrite_labels).
297     filtered_images = _filter_images(images, skip_list, helper)
298     if len(filtered_images) == 0:
299         logger.warning('No image files to operate on (post filter).')
300         return
301     logger.info('There are %d candidate images post filtering.', len(filtered_images))
302
303     # Allow the user to label the non-filtered images one by one.
304
305     labeled_features: Dict[Tuple[str, str], str] = {}
306     cursor = 0
307     while True:
308         assert 0 <= cursor < len(filtered_images)
309
310         image = filtered_images[cursor][0]
311         assert os.path.exists(image)
312         features = filtered_images[cursor][1]
313         assert features and os.path.exists(features)
314
315         # Render the features, image and prompt.
316         _make_prompt(helper, cursor, filtered_images, image, features, labeled_features)
317         try:
318             # Did they want everything labelled the same?
319             label_everything = helper.get_everything_label()
320             if label_everything:
321                 labeled_features[(image, features)] = label_everything
322                 filtered_images.remove((image, features))
323                 if len(filtered_images) == 0:
324                     print('Nothing more to label.')
325                     break
326                 if cursor >= len(filtered_images):
327                     cursor -= 1
328
329             # Otherwise ask about each individual example.
330             else:
331                 labelling_keystrokes = helper.get_labelling_keystrokes()
332                 valid_keystrokes = ['<', '>', 'W', 'Q', '?']
333                 valid_keystrokes += labelling_keystrokes.keys()
334                 prompt = ','.join(valid_keystrokes)
335                 print(f'What should I do?  [{prompt}]: ', end='')
336                 sys.stdout.flush()
337                 keystroke = input_utils.single_keystroke_response(valid_keystrokes)
338                 print()
339
340                 if keystroke == '?':
341                     print(
342                         '''
343         >   =   Don't label, move to the next example.
344         <   =   Don't label, move to the previous example.
345         W   =   Write pending labels to disk now.
346         Q   =   Quit labeling now.
347         ?   =   This message.
348       else  =   These keystrokes assign a label to the example and persist it.'''
349                     )
350                     input_utils.press_any_key()
351                 elif keystroke == 'Q':
352                     logger.info('Ok, stopping for now.')
353                     if len(labeled_features):
354                         logger.info('Discarding %d unsaved labels.', len(labeled_features))
355                     break
356                 elif keystroke == '>':
357                     cursor += 1
358                     if cursor >= len(filtered_images):
359                         print('Wrapping around...')
360                         cursor = 0
361                 elif keystroke == '<':
362                     cursor -= 1
363                     if cursor < 0:
364                         print('Wrapping around...')
365                         cursor = len(filtered_images) - 1
366                 elif keystroke == 'W':
367                     _write_labeled_features(helper, labeled_features, skip_list)
368                     labeled_features = {}
369                 elif keystroke in labelling_keystrokes:
370                     label_value = labelling_keystrokes[keystroke]
371                     labeled_features[(image, features)] = label_value
372                     filtered_images.remove((image, features))
373                     if len(filtered_images) == 0:
374                         print('Nothing more to label.')
375                         break
376                     if cursor >= len(filtered_images):
377                         cursor -= 1
378                 else:
379                     print(f'Unknown keystroke: {keystroke}.  Try again.')
380         finally:
381             helper.unrender_example(image, features)
382
383     if len(labeled_features):
384         yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
385         if yn in ('Y', 'y'):
386             _write_labeled_features(helper, labeled_features, skip_list)
387     _maybe_write_skip_list(skip_list)