Overhaul quick_labeler part 2: understand what the current model says.
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A helper to facilitate quick manual labeling of ML training data."""
6
7 import logging
8 import os
9 import sys
10 import time
11 from abc import abstractmethod
12 from typing import Any, Dict, List, Optional, Set, Tuple
13
14 import argparse_utils
15 import config
16
17 logger = logging.getLogger(__name__)
18 parser = config.add_commandline_args(
19     f"ML Quick Labeler ({__file__})",
20     "Args related to quick labeling of ML training data",
21 )
22 parser.add_argument(
23     "--ml_quick_label_skip_list_path",
24     default="./qlabel_skip_list.txt",
25     metavar="FILENAME",
26     type=argparse_utils.valid_filename,
27     help="Path to file in which to store already labeled data.",
28 )
29 parser.add_argument(
30     "--ml_quick_label_use_skip_lists",
31     default=True,
32     action=argparse_utils.ActionNoYes,
33     help='Should we use a skip list file to speed up execution?',
34 )
35 parser.add_argument(
36     "--ml_quick_label_overwrite_labels",
37     default=False,
38     action=argparse_utils.ActionNoYes,
39     help='Enable overwriting existing labels; default is to not relabel.',
40 )
41 parser.add_argument(
42     '--ml_quick_label_skip_where_model_agrees',
43     default=False,
44     action=argparse_utils.ActionNoYes,
45     help='Do not filter examples where the model disagrees with the current label.',
46 )
47 parser.add_argument(
48     'ml_quick_label_delete_invalid_examples',
49     default=False,
50     action='store_true',
51     help='If set we will delete invalid training examples.',
52 )
53
54
55 class QuickLabelHelper:
56     '''To use this quick labeler your code must create a subclass of this
57     class and implement the abstract methods below.  See comments for
58     detailed semantics.'''
59
60     @abstractmethod
61     def get_candidate_files(self) -> List[str]:
62         '''This must return a list of raw candidate files for labeling.'''
63         pass
64
65     @abstractmethod
66     def get_features_for_file(self, filename: str) -> Optional[str]:
67         '''Given a raw file, return its features file.'''
68         pass
69
70     @abstractmethod
71     def render_example(self, filename: str, features: str, lines: List[str]) -> None:
72         '''Render a raw file with its features for the user.'''
73         pass
74
75     @abstractmethod
76     def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
77         '''Unrender a raw file with its features (if necessary)...'''
78         pass
79
80     @abstractmethod
81     def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
82         '''Returns true iff the example is valid (all features are valid, there
83         are the correct number of features, etc...'''
84         pass
85
86     @abstractmethod
87     def ask_current_model_about_example(
88         self,
89         filename: str,
90         features: str,
91         lines: List[str],
92     ) -> Any:
93         '''Ask the current ML model about this example, if necessary.'''
94         pass
95
96     @abstractmethod
97     def get_labelling_keystrokes(self) -> Dict[str, Any]:
98         '''What keystrokes should be considered valid label actions and what
99         label does each keystroke map into.  e.g. if you want to ask
100         the user to hit 'y' for 'yes' and code that as 255 in your
101         features or to hit 'n' for 'no' and code that as 0 in your
102         features, return:
103
104             { 'y': 255, 'n': 0 }
105
106         '''
107         pass
108
109     @abstractmethod
110     def get_everything_label(self) -> Any:
111         '''If this returns something other than None it indicates that every
112         example selected should be labeled with this result.  Caveat
113         emptor, we will klobber all your files.
114
115         '''
116         pass
117
118     @abstractmethod
119     def get_label_feature(self) -> str:
120         '''What feature denotes the example's label?  This is used to detect
121         when examples already have a label and to assign labels to
122         examples.'''
123         pass
124
125
126 def _maybe_read_skip_list() -> Set[str]:
127     '''Reads the skip list (files to just bypass) into memory if using.'''
128
129     ret: Set[str] = set()
130     if config.config['ml_quick_label_use_skip_lists']:
131         quick_skip_file = config.config['ml_quick_label_skip_list_path']
132         if os.path.exists(quick_skip_file):
133             with open(quick_skip_file, 'r') as f:
134                 lines = f.readlines()
135             for line in lines:
136                 line = line[:-1]
137                 line.strip()
138                 ret.add(line)
139         logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
140     return ret
141
142
143 def _maybe_write_skip_list(skip_list) -> None:
144     '''Writes the skip list (files to just bypass) to disk if using.'''
145
146     if config.config['ml_quick_label_use_skip_lists']:
147         quick_skip_file = config.config['ml_quick_label_skip_list_path']
148         with open(quick_skip_file, 'w') as f:
149             for filename in skip_list:
150                 filename = filename.strip()
151                 if len(filename) > 0:
152                     f.write(f'{filename}\n')
153         logger.debug('Updated %s', quick_skip_file)
154
155
156 def _filter_images(
157     images: List[str], skip_list: Set[str], helper: QuickLabelHelper
158 ) -> List[Tuple[str, str]]:
159     filtered_images = []
160     label_label = helper.get_label_feature()
161     for image in images:
162         if image in skip_list:
163             logger.debug('Skipping %s because of the skip list', image)
164             continue
165
166         features = helper.get_features_for_file(image)
167         if features is None or not os.path.exists(features):
168             logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
169             continue
170
171         label = None
172         filtered_lines = []
173         with open(features, 'r') as rf:
174             lines = rf.readlines()
175         for line in lines:
176             line = line[:-1]
177             if line.startswith(label_label):
178                 label = ''.join(line.split(':')[1:])
179                 label = label.strip()
180             else:
181                 filtered_lines.append(line)
182
183         if not helper.is_valid_example(image, features, filtered_lines):
184             logger.warning('%s/%s: Invalid example.', image, features)
185             if config.config['ml_quick_label_delete_invalid_examples']:
186                 os.remove(image)
187                 os.remove(features)
188             continue
189
190         if label and not config.config['ml_quick_label_overwrite_labels']:
191             logger.warning('%s/%s: already has label, SKIPPING.', image, features)
192             continue
193
194         if config.config['ml_quick_label_skip_where_model_agrees']:
195             model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
196             if model_says and label:
197                 if model_says[0] == int(label):
198                     continue
199                 print(f'{image}/{features}: The model disagrees with the current label.')
200                 print(f'    ...model says {model_says[0]} with probability {model_says[1]}.')
201                 print(f'    ...the example is currently labeled {label}')
202         filtered_images.append((image, features))
203     return filtered_images
204
205
206 def quick_label(helper: QuickLabelHelper) -> None:
207     skip_list = _maybe_read_skip_list()
208
209     # Ask helper for an initial set of files.
210     images = helper.get_candidate_files()
211     if len(images) == 0:
212         logger.warning('No images files to operate on.')
213         return
214
215     # Filter out any that can't be converted to features or already have a
216     # label (unless they used --ml_qukck_label_overwrite_labels).
217     filtered_images = _filter_images(images, skip_list, helper)
218     if len(filtered_images) == 0:
219         logger.warning('No image files to operate on (post filter).')
220         return
221
222     # Allow the user to label the non-filtered images one by one.
223     import input_utils
224
225     cursor = 0
226     label_label = helper.get_label_feature()
227     while True:
228         assert 0 <= cursor < len(filtered_images)
229
230         image = filtered_images[cursor][0]
231         assert os.path.exists(image)
232         features = filtered_images[cursor][1]
233         assert features and os.path.exists(features)
234
235         filtered_lines = []
236         label = None
237         with open(features, 'r') as rf:
238             lines = rf.readlines()
239         for line in lines:
240             line = line[:-1]
241             if not line.startswith(label_label):
242                 filtered_lines.append(line)
243             else:
244                 label = line
245
246         # Render...
247         helper.render_example(image, features, filtered_lines)
248
249         # Prompt...
250         print(
251             f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
252         )
253         if label:
254             print(f'    ...Already labelled: {label}')
255         else:
256             print('    ...Currently unlabeled')
257         guess = helper.ask_current_model_about_example(image, features, filtered_lines)
258         if guess:
259             print(f'    ...Model says {guess}')
260         print()
261
262         # Did they want everything labelled the same?
263         label_everything = helper.get_everything_label()
264         if label_everything:
265             filtered_lines.append(f"{label_label}: {label_everything}\n")
266             with open(features, 'w') as f:
267                 f.writelines(line + '\n' for line in filtered_lines)
268             if config.config['ml_quick_label_use_skip_lists']:
269                 skip_list.add(image)
270             cursor += 1
271             if cursor >= len(filtered_images):
272                 helper.unrender_example(image, features, filtered_lines)
273                 break
274
275         # Otherwise ask about each example.
276         else:
277             labelling_keystrokes = helper.get_labelling_keystrokes()
278             valid_keystrokes = ['<', '>', 'Q', '?']
279             valid_keystrokes += labelling_keystrokes.keys()
280             prompt = ','.join(valid_keystrokes)
281             print(f'What should I do ({prompt})? ', end='')
282             sys.stdout.flush()
283             keystroke = input_utils.single_keystroke_response(valid_keystrokes)
284             print()
285             if keystroke == 'Q':
286                 logger.info('Ok, stopping for now.  Labeled examples are written to disk')
287                 helper.unrender_example(image, features, filtered_lines)
288                 break
289             elif keystroke == '?':
290                 print(
291                     '''
292     >   =   Don't label, move to the next example.
293     <   =   Don't label, move to the previous example.
294     Q   =   Quit labeling now.
295     ?   =   This message.
296   else  =   These keystrokes assign a label to the example and persist it.'''
297                 )
298                 time.sleep(3.0)
299
300             elif keystroke == '>':
301                 cursor += 1
302                 if cursor >= len(filtered_images):
303                     print('Wrapping around...')
304                     cursor = 0
305             elif keystroke == '<':
306                 cursor -= 1
307                 if cursor < 0:
308                     print('Wrapping around...')
309                     cursor = len(filtered_images) - 1
310             elif keystroke in labelling_keystrokes:
311                 label_value = labelling_keystrokes[keystroke]
312                 filtered_lines.append(f"{label_label}: {label_value}\n")
313                 with open(features, 'w') as f:
314                     f.writelines(line + '\n' for line in filtered_lines)
315                 if config.config['ml_quick_label_use_skip_lists']:
316                     skip_list.add(image)
317                 cursor += 1
318                 if cursor >= len(filtered_images):
319                     print('Wrapping around...')
320                     cursor = 0
321             else:
322                 print(f'Unknown keystroke: {keystroke}')
323         helper.unrender_example(image, features, filtered_lines)
324     _maybe_write_skip_list(skip_list)