Overhaul quick labeler, part 1.
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A helper to facilitate quick manual labeling of ML training data."""
6
7 import logging
8 import os
9 import sys
10 import time
11 import warnings
12 from abc import abstractmethod
13 from typing import Any, Dict, List, Optional, Set
14
15 import argparse_utils
16 import config
17
18 logger = logging.getLogger(__name__)
19 parser = config.add_commandline_args(
20     f"ML Quick Labeler ({__file__})",
21     "Args related to quick labeling of ML training data",
22 )
23 parser.add_argument(
24     "--ml_quick_label_skip_list_path",
25     default="./qlabel_skip_list.txt",
26     metavar="FILENAME",
27     type=argparse_utils.valid_filename,
28     help="Path to file in which to store already labeled data.",
29 )
30 parser.add_argument(
31     "--ml_quick_label_use_skip_lists",
32     default=True,
33     action=argparse_utils.ActionNoYes,
34     help='Should we use a skip list file to speed up execution?',
35 )
36 parser.add_argument(
37     "--ml_quick_label_overwrite_labels",
38     default=False,
39     action=argparse_utils.ActionNoYes,
40     help='Enable overwriting existing labels; default is to not relabel.',
41 )
42
43
44 class QuickLabelHelper:
45     '''To use this quick labeler your code must create a subclass of this
46     class and implement the abstract methods below.  See comments for
47     detailed semantics.'''
48
49     @abstractmethod
50     def get_candidate_files(self) -> List[str]:
51         '''This must return a list of raw candidate files for labeling.'''
52         pass
53
54     @abstractmethod
55     def get_features_for_file(self, filename: str) -> Optional[str]:
56         '''Given a raw file, return its features file.'''
57         pass
58
59     @abstractmethod
60     def render_example(self, filename: str, features: str, lines: List[str]) -> None:
61         '''Render a raw file with its features for the user.'''
62         pass
63
64     @abstractmethod
65     def unrender_example(self, filename: str, features: str, lines: List[str]) -> None:
66         '''Unrender a raw file with its features (if necessary)...'''
67         pass
68
69     @abstractmethod
70     def ask_current_model_about_example(
71         self,
72         filename: str,
73         features: str,
74         lines: List[str],
75     ) -> Any:
76         '''Ask the current ML model about this example, if necessary.'''
77         pass
78
79     @abstractmethod
80     def get_labelling_keystrokes(self) -> Dict[str, Any]:
81         '''What keystrokes should be considered valid label actions and what
82         label does each keystroke map into.  e.g. if you want to ask
83         the user to hit 'y' for 'yes' and code that as 255 in your
84         features or to hit 'n' for 'no' and code that as 0 in your
85         features, return:
86
87             { 'y': 255, 'n': 0 }
88
89         '''
90         pass
91
92     @abstractmethod
93     def get_everything_label(self) -> Any:
94         '''If this returns something other than None it indicates that every
95         example selected should be labeled with this result.  Caveat
96         emptor, we will klobber all your files.
97
98         '''
99         pass
100
101     @abstractmethod
102     def get_label_feature(self) -> str:
103         '''What feature denotes the example's label?  This is used to detect
104         when examples already have a label and to assign labels to
105         examples.'''
106         pass
107
108
109 def _maybe_read_skip_list() -> Set[str]:
110     '''Reads the skip list (files to just bypass) into memory if using.'''
111
112     ret: Set[str] = set()
113     if config.config['ml_quick_label_use_skip_lists']:
114         quick_skip_file = config.config['ml_quick_label_skip_list_path']
115         if os.path.exists(quick_skip_file):
116             with open(quick_skip_file, 'r') as f:
117                 lines = f.readlines()
118             for line in lines:
119                 line = line[:-1]
120                 line.strip()
121                 ret.add(line)
122         logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
123     return ret
124
125
126 def _maybe_write_skip_list(skip_list) -> None:
127     '''Writes the skip list (files to just bypass) to disk if using.'''
128
129     if config.config['ml_quick_label_use_skip_lists']:
130         quick_skip_file = config.config['ml_quick_label_skip_list_path']
131         with open(quick_skip_file, 'w') as f:
132             for filename in skip_list:
133                 filename = filename.strip()
134                 if len(filename) > 0:
135                     f.write(f'{filename}\n')
136         logger.debug('Updated %s', quick_skip_file)
137
138
139 def quick_label(helper: QuickLabelHelper) -> None:
140     # Ask helper for an initial set of files.
141     images = helper.get_candidate_files()
142     if len(images) == 0:
143         logger.warning('No images files to operate on.')
144         return
145
146     # Filter out any that can't be converted to features or already have a
147     # label (unless they used --ml_qukck_label_overwrite_labels).
148     filtered_images = []
149     skip_list = _maybe_read_skip_list()
150     for image in images:
151         if image in skip_list:
152             logger.debug('Skipping %s because of the skip list', image)
153             continue
154
155         features = helper.get_features_for_file(image)
156         if features is None or not os.path.exists(features):
157             msg = f'{image}/{features}: {features} doesn\'t exist, SKIPPING.'
158             logger.warning(msg)
159             warnings.warn(msg)
160             continue
161
162         label_label = helper.get_label_feature()
163         label = None
164         with open(features, 'r') as rf:
165             lines = rf.readlines()
166         for line in lines:
167             line = line[:-1]
168             if line.startswith(label_label):
169                 label = line
170         if label and not config.config['ml_quick_label_overwrite_labels']:
171             msg = f'{image}/{features}: already has label, SKIPPING'
172             logger.warning(msg)
173             warnings.warn(msg)
174             continue
175         filtered_images.append((image, features))
176
177     if len(filtered_images) == 0:
178         logger.warning('No image files to operate on (post filter).')
179         return
180
181     cursor = 0
182     import input_utils
183
184     while True:
185         assert 0 <= cursor < len(filtered_images)
186
187         image = filtered_images[cursor][0]
188         assert os.path.exists(image)
189         features = filtered_images[cursor][1]
190         assert features and os.path.exists(features)
191
192         filtered_lines = []
193         label = None
194         with open(features, 'r') as rf:
195             lines = rf.readlines()
196         for line in lines:
197             line = line[:-1]
198             if not line.startswith(label_label):
199                 filtered_lines.append(line)
200             else:
201                 label = line
202
203         # Render...
204         helper.render_example(image, features, filtered_lines)
205
206         # Prompt...
207         print(
208             f'{cursor} of {len(filtered_images)} {cursor/len(filtered_images)*100.0:.1f}%): {image}, {features}'
209         )
210         if label:
211             print(f'    ...Already labelled: {label}')
212         else:
213             print('    ...Currently unlabeled')
214         guess = helper.ask_current_model_about_example(image, features, filtered_lines)
215         if guess:
216             print(f'    ...Model says {guess}')
217         print()
218
219         # Did they want everything labelled the same?
220         label_everything = helper.get_everything_label()
221         if label_everything:
222             filtered_lines.append(f"{label_label}: {label_everything}\n")
223             with open(features, 'w') as f:
224                 f.writelines(line + '\n' for line in filtered_lines)
225             if config.config['ml_quick_label_use_skip_lists']:
226                 skip_list.add(image)
227             cursor += 1
228             if cursor >= len(filtered_images):
229                 helper.unrender_example(image, features, filtered_lines)
230                 break
231
232         # Otherwise ask about each example.
233         else:
234             labelling_keystrokes = helper.get_labelling_keystrokes()
235             valid_keystrokes = ['<', '>', 'Q', '?']
236             valid_keystrokes += labelling_keystrokes.keys()
237             prompt = ','.join(valid_keystrokes)
238             print(f'What should I do ({prompt})? ', end='')
239             sys.stdout.flush()
240             keystroke = input_utils.single_keystroke_response(valid_keystrokes)
241             print()
242             if keystroke == 'Q':
243                 logger.info('Ok, stopping for now.  Labeled examples are written to disk')
244                 helper.unrender_example(image, features, filtered_lines)
245                 break
246             elif keystroke == '?':
247                 print(
248                     '''
249     >   =   Don't label, move to the next example.
250     <   =   Don't label, move to the previous example.
251     Q   =   Quit labeling now.
252     ?   =   This message.
253   else  =   These keystrokes assign a label to the example and persist it.'''
254                 )
255                 time.sleep(3.0)
256
257             elif keystroke == '>':
258                 cursor += 1
259                 if cursor >= len(filtered_images):
260                     print('Wrapping around...')
261                     cursor = 0
262             elif keystroke == '<':
263                 cursor -= 1
264                 if cursor < 0:
265                     print('Wrapping around...')
266                     cursor = len(filtered_images) - 1
267             elif keystroke in labelling_keystrokes:
268                 label_value = labelling_keystrokes[keystroke]
269                 filtered_lines.append(f"{label_label}: {label_value}\n")
270                 with open(features, 'w') as f:
271                     f.writelines(line + '\n' for line in filtered_lines)
272                 if config.config['ml_quick_label_use_skip_lists']:
273                     skip_list.add(image)
274                 cursor += 1
275                 if cursor >= len(filtered_images):
276                     print('Wrapping around...')
277                     cursor = 0
278             else:
279                 print(f'Unknown keystroke: {keystroke}')
280         helper.unrender_example(image, features, filtered_lines)
281     _maybe_write_skip_list(skip_list)