Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A helper to facilitate quick manual labeling of ML training data.
6
7 To use, implement a subclass that implements the QuickLabelHelper
8 interface and pass it into the quick_label function.
9
10 """
11
12 import logging
13 import os
14 import sys
15 from abc import abstractmethod
16 from typing import Any, Dict, List, Optional, Set, Tuple
17
18 import ansi
19 import argparse_utils
20 import config
21 import input_utils
22
23 logger = logging.getLogger(__name__)
24 parser = config.add_commandline_args(
25     f"ML Quick Labeler ({__file__})",
26     "Args related to quick labeling of ML training data",
27 )
28 parser.add_argument(
29     "--ml_quick_label_skip_list_path",
30     default="./qlabel_skip_list.txt",
31     metavar="FILENAME",
32     type=argparse_utils.valid_filename,
33     help="Path to file in which to store already labeled data.",
34 )
35 parser.add_argument(
36     "--ml_quick_label_use_skip_lists",
37     default=True,
38     action=argparse_utils.ActionNoYes,
39     help='Should we use a skip list file to speed up execution?',
40 )
41 parser.add_argument(
42     "--ml_quick_label_overwrite_labels",
43     default=False,
44     action=argparse_utils.ActionNoYes,
45     help='Enable overwriting of existing labels; default is to not relabel.',
46 )
47 parser.add_argument(
48     '--ml_quick_label_skip_where_model_agrees',
49     default=False,
50     action=argparse_utils.ActionNoYes,
51     help='Do not filter examples where the model disagrees with the current label.',
52 )
53 parser.add_argument(
54     '--ml_quick_label_delete_invalid_examples',
55     default=False,
56     action='store_true',
57     help='If set we will delete invalid training examples.',
58 )
59
60
61 class QuickLabelHelper:
62     '''To use this quick labeler your code must create a subclass of this
63     class and implement the abstract methods below.  See comments for
64     detailed semantics.'''
65
66     @abstractmethod
67     def get_candidate_files(self) -> List[str]:
68         '''This must return a list of raw candidate files for labeling.'''
69         pass
70
71     @abstractmethod
72     def get_features_for_file(self, filename: str) -> Optional[str]:
73         '''Given a raw file, return its features file.'''
74         pass
75
76     @abstractmethod
77     def render_example(self, filename: str, features: str, lines: List[str]) -> None:
78         '''Render a raw file with its features for the user to judge.'''
79         pass
80
81     @abstractmethod
82     def unrender_example(self, filename: str, features: str) -> None:
83         '''Unrender a raw file with its features (if necessary)...'''
84         pass
85
86     @abstractmethod
87     def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
88         '''Returns true iff the example is valid (all features are valid, there
89         are the correct number of features, etc...'''
90         pass
91
92     @abstractmethod
93     def ask_current_model_about_example(
94         self,
95         filename: str,
96         features: str,
97         lines: List[str],
98     ) -> Any:
99         '''Ask the current ML model about this example, if necessary/possible.
100         Returns None to indicate no model to consult.'''
101         pass
102
103     @abstractmethod
104     def get_labelling_keystrokes(self) -> Dict[str, Any]:
105         '''What keystrokes should be considered valid label actions and what
106         label does each keystroke map into.  e.g. if you want to ask
107         the user to hit 'y' for 'yes' and code that as 255 in your
108         features or to hit 'n' for 'no' and code that as 0 in your
109         features, return:
110
111             { 'y': 255, 'n': 0 }
112
113         '''
114         pass
115
116     @abstractmethod
117     def get_everything_label(self) -> Any:
118         '''If this returns something other than None it indicates that every
119         example selected should be labeled with this result.  Caveat
120         emptor, we will klobber all your files.
121
122         '''
123         pass
124
125     @abstractmethod
126     def get_label_feature(self) -> str:
127         '''What feature denotes the example's label?  This is used to detect
128         when examples already have a label and to assign labels to
129         examples.'''
130         pass
131
132
133 def _maybe_read_skip_list() -> Set[str]:
134     '''Reads the skip list (files to just bypass) into memory if using.'''
135
136     ret: Set[str] = set()
137     if config.config['ml_quick_label_use_skip_lists']:
138         quick_skip_file = config.config['ml_quick_label_skip_list_path']
139         if os.path.exists(quick_skip_file):
140             with open(quick_skip_file, 'r') as f:
141                 lines = f.readlines()
142             for line in lines:
143                 line = line[:-1]
144                 line.strip()
145                 ret.add(line)
146         logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
147     return ret
148
149
150 def _maybe_write_skip_list(skip_list) -> None:
151     '''Writes the skip list (files to just bypass) to disk if using.'''
152
153     if config.config['ml_quick_label_use_skip_lists']:
154         quick_skip_file = config.config['ml_quick_label_skip_list_path']
155         with open(quick_skip_file, 'w') as f:
156             for filename in skip_list:
157                 filename = filename.strip()
158                 if len(filename) > 0:
159                     f.write(f'{filename}\n')
160         logger.debug('Updated %s', quick_skip_file)
161
162
163 def _filter_images(
164     images: List[str], skip_list: Set[str], helper: QuickLabelHelper
165 ) -> List[Tuple[str, str]]:
166     '''Discard examples that have particular characteristics.  e.g.
167     those that are already labeled and whose current label agrees with
168     the ML model, etc...'''
169
170     filtered_images = []
171     label_label = helper.get_label_feature()
172     for image in images:
173         if image in skip_list:
174             logger.debug('Skipping %s because of the skip list.', image)
175             continue
176
177         features = helper.get_features_for_file(image)
178         if features is None or not os.path.exists(features):
179             logger.warning('%s/%s: features file doesn\'t exist, skipping.', image, features)
180             continue
181
182         # Read it.
183         label = None
184         filtered_lines = []
185         with open(features, 'r') as rf:
186             lines = rf.readlines()
187         for line in lines:
188             line = line[:-1]
189             if line.startswith(label_label):
190                 label = ''.join(line.split(':')[1:])
191                 label = label.strip()
192             else:
193                 filtered_lines.append(line)
194
195         if not helper.is_valid_example(image, features, filtered_lines):
196             logger.warning('%s/%s: Invalid example.', image, features)
197             if config.config['ml_quick_label_delete_invalid_examples']:
198                 os.remove(image)
199                 os.remove(features)
200             continue
201
202         if label and not config.config['ml_quick_label_overwrite_labels']:
203             logger.warning('%s/%s: already has label, skipping.', image, features)
204             continue
205
206         if config.config['ml_quick_label_skip_where_model_agrees']:
207             model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
208             if model_says and label:
209                 if model_says[0] == int(label):
210                     logger.warning(
211                         '%s/%s: Model agrees with current label (%s), skipping.',
212                         image,
213                         features,
214                         label,
215                     )
216                     continue
217
218                 print(f'{image}/{features}: The model disagrees with the current label.')
219                 print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
220                 print(f'    ...the example is currently labeled {label}')
221
222         filtered_images.append((image, features))
223     return filtered_images
224
225
226 def _make_prompt(
227     helper: QuickLabelHelper,
228     cursor: int,
229     num_filtered_images: int,
230     current_image: str,
231     current_features: str,
232     labeled_features: Dict[Tuple[str, str], str],  # Examples already labeled
233 ) -> None:
234     '''Tell an interactive user where they are in the set of examples that
235     may be labeled and the details of the current example.'''
236
237     label_label = helper.get_label_feature()  # the key: of a label in features
238     filtered_lines = []
239     label = labeled_features.get((current_image, current_features), None)
240     with open(current_features, 'r') as rf:
241         lines = rf.readlines()
242     for line in lines:
243         line = line[:-1]
244         if len(line) == 0:
245             continue
246         if not line.startswith(label_label):
247             filtered_lines.append(line)
248         else:
249             assert not label
250             label = line
251
252     # Prompt...
253     helper.render_example(current_image, current_features, filtered_lines)
254     print(f'{cursor}/{num_filtered_images} ({cursor/num_filtered_images*100.0:.1f}%) | ', end='')
255     print(f'{ansi.bold()}{current_image} / {current_features}{ansi.reset()}:')
256     print(f'    ...{len(labeled_features)} currently unwritten label(s) ("W" to write).')
257     if label:
258         if (current_image, current_features) in labeled_features:
259             print(f'    ...This example is labeled but not yet saved: {label}')
260         else:
261             print(f'    ...This example is already labeled on disk: {label}')
262     else:
263         print('    ...This example is currently unlabeled')
264     guess = helper.ask_current_model_about_example(current_image, current_features, filtered_lines)
265     if guess:
266         print(f'    ...The ML Model says {guess}')
267     print()
268
269
270 def _write_labeled_features(
271     helper: QuickLabelHelper,
272     labeled_features: Dict[Tuple[str, str], str],
273     skip_list: Set[str],
274 ) -> None:
275     label_label = helper.get_label_feature()
276     for image_features, label in labeled_features.items():
277         image = image_features[0]
278         features = image_features[1]
279         filtered_lines = []
280         with open(features, 'r') as rf:
281             lines = rf.readlines()
282         for line in lines:
283             line = line[:-1]
284             line = line.strip()
285             if line == '':
286                 continue
287             if not line.startswith(label_label):
288                 filtered_lines.append(line)
289
290         filtered_lines.append(f'{label_label}: {label}')
291         with open(features, 'w') as f:
292             f.writelines(line + '\n' for line in filtered_lines)
293             if config.config['ml_quick_label_use_skip_lists']:
294                 skip_list.add(image)
295     print(f'Wrote {len(labeled_features)} labels.')
296
297
298 def quick_label(helper: QuickLabelHelper) -> None:
299     '''Pass your QuickLabelHelper implementing class to this function and
300     it will allow users to label examples and persist them to disk.
301
302     '''
303     skip_list = _maybe_read_skip_list()
304
305     # Ask helper for an initial set of files.
306     images = helper.get_candidate_files()
307     if len(images) == 0:
308         logger.warning('No images files to operate on.')
309         return
310     logger.info('There are %d starting candidate images.', len(images))
311
312     # Filter out any that can't be converted to features or already have a
313     # label (unless they used --ml_qukck_label_overwrite_labels).
314     filtered_images = _filter_images(images, skip_list, helper)
315     if len(filtered_images) == 0:
316         logger.warning('No image files to operate on (post filter).')
317         return
318     logger.info('There are %d candidate images post filtering.', len(filtered_images))
319
320     # Allow the user to label the non-filtered images one by one.
321     labeled_features: Dict[Tuple[str, str], str] = {}
322     cursor = 0
323     while True:
324         assert 0 <= cursor < len(filtered_images)
325
326         image = filtered_images[cursor][0]
327         assert os.path.exists(image)
328         features = filtered_images[cursor][1]
329         assert features and os.path.exists(features)
330
331         # Render the features, image and prompt.
332         _make_prompt(helper, cursor, len(filtered_images), image, features, labeled_features)
333         try:
334             # Did they want everything labelled the same?
335             label_everything = helper.get_everything_label()
336             if label_everything:
337                 labeled_features[(image, features)] = label_everything
338                 filtered_images.remove((image, features))
339                 if len(filtered_images) == 0:
340                     print('Nothing more to label.')
341                     break
342                 if cursor >= len(filtered_images):
343                     cursor -= 1
344
345             # Otherwise ask about each individual example.
346             else:
347                 labelling_keystrokes = helper.get_labelling_keystrokes()
348                 valid_keystrokes = ['<', '>', 'W', 'Q', '?']
349                 valid_keystrokes += labelling_keystrokes.keys()
350                 prompt = ','.join(valid_keystrokes)
351                 print(f'What should I do?  [{prompt}]: ', end='')
352                 sys.stdout.flush()
353                 keystroke = input_utils.single_keystroke_response(valid_keystrokes)
354                 print()
355
356                 if keystroke == '?':
357                     print(
358                         '''
359         >   =   Don't label, move to the next example.
360         <   =   Don't label, move to the previous example.
361         W   =   Write pending labels to disk now.
362         Q   =   Quit labeling now.
363         ?   =   This message.
364       else  =   These keystrokes assign a label to the example and persist it.'''
365                     )
366                     input_utils.press_any_key()
367                 elif keystroke == 'Q':
368                     logger.info('Ok, stopping for now.')
369                     if len(labeled_features):
370                         logger.info('Discarding %d unsaved labels.', len(labeled_features))
371                     break
372                 elif keystroke == '>':
373                     cursor += 1
374                     if cursor >= len(filtered_images):
375                         print('Wrapping around...')
376                         cursor = 0
377                 elif keystroke == '<':
378                     cursor -= 1
379                     if cursor < 0:
380                         print('Wrapping around...')
381                         cursor = len(filtered_images) - 1
382                 elif keystroke == 'W':
383                     _write_labeled_features(helper, labeled_features, skip_list)
384                     labeled_features = {}
385                 elif keystroke in labelling_keystrokes:
386                     label_value = labelling_keystrokes[keystroke]
387                     labeled_features[(image, features)] = label_value
388                     filtered_images.remove((image, features))
389                     if len(filtered_images) == 0:
390                         print('Nothing more to label.')
391                         break
392                     if cursor >= len(filtered_images):
393                         cursor -= 1
394                 else:
395                     print(f'Unknown keystroke: {keystroke}.  Try again.')
396         finally:
397             helper.unrender_example(image, features)
398
399     if len(labeled_features):
400         yn = input_utils.yn_response(f'Save the {len(labeled_features)} labels to disk?  [Y/N]: ')
401         if yn in ('Y', 'y'):
402             _write_labeled_features(helper, labeled_features, skip_list)
403     _maybe_write_skip_list(skip_list)