da9a1d2e00d15e32cce697e3782a5cfce87ea584
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 """A helper to facilitate quick manual labeling of ML training data."""
4
import glob
import logging
import os
import warnings
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Set

import argparse_utils
import config
14
logger = logging.getLogger(__name__)

# Command-line flags owned by this module.  They are registered, in this
# order, on the object returned by config.add_commandline_args.
parser = config.add_commandline_args(
    f"ML Quick Labeler ({__file__})",
    "Args related to quick labeling of ML training data",
)
for _flag, _flag_options in (
    (
        "--ml_quick_label_skip_list_path",
        {
            "default": "./qlabel_skip_list.txt",
            "metavar": "FILENAME",
            "type": argparse_utils.valid_filename,
            "help": "Path to file in which to store already labeled data.",
        },
    ),
    (
        "--ml_quick_label_use_skip_lists",
        {
            "default": True,
            "action": argparse_utils.ActionNoYes,
            "help": "Should we use a skip list file to speed up execution?",
        },
    ),
    (
        "--ml_quick_label_overwrite_labels",
        {
            "default": False,
            "action": argparse_utils.ActionNoYes,
            "help": "Enable overwriting existing labels; default is to not relabel.",
        },
    ),
):
    parser.add_argument(_flag, **_flag_options)
# Do not leave the loop variables behind as module attributes.
del _flag, _flag_options
39
40
@dataclass
class InputSpec:
    """A wrapper around the input data we need to operate; should be
    populated by the caller.

    Exactly which images are labeled is decided by image_file_glob when it
    is set, otherwise by image_file_prepopulated_list; label() raises
    ValueError when both are None.
    """

    # Glob pattern expanded with glob.glob() to find the images to label.
    image_file_glob: Optional[str] = None
    # Explicit list of image paths; only consulted when image_file_glob is None.
    image_file_prepopulated_list: Optional[List[str]] = None
    # Maps an image path to the path of the features file that holds its label.
    image_file_to_features_file: Optional[Callable[[str], str]] = None
    # Text identifying the label line in a features file; any line containing
    # it is treated as an existing label, and new labels are written as
    # "<label>: <value>".
    label: str = ''
    # Keystrokes accepted from the user, passed to
    # input_utils.single_keystroke_response.  A default_factory is required
    # here: dataclasses reject a bare [] default with ValueError because it
    # would be shared by every instance.
    valid_keystrokes: List[str] = field(default_factory=list)
    # Prompt shown to the user while waiting for a keystroke.
    prompt: str = ''
    # Converts the keystroke the user pressed into the label value to store.
    keystroke_to_label: Optional[Callable[[str], str]] = None
53
54
def read_skip_list() -> Set[str]:
    """Load the skip list: the entries that have already been labeled.

    Returns:
        The distinct, whitespace-stripped, non-empty lines of the file named
        by config.config['ml_quick_label_skip_list_path'].  The set is empty
        when config.config['ml_quick_label_use_skip_lists'] is false or the
        file does not exist.
    """
    ret: Set[str] = set()
    if not config.config['ml_quick_label_use_skip_lists']:
        return ret

    quick_skip_file = config.config['ml_quick_label_skip_list_path']
    if os.path.exists(quick_skip_file):
        with open(quick_skip_file, 'r') as f:
            for line in f:
                # strip() rather than slicing off the last character: the
                # final line may have no trailing newline, and blank lines
                # must not become an empty entry (write_skip_list drops
                # empty entries as well).
                entry = line.strip()
                if entry:
                    ret.add(entry)
    logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
    return ret
68
69
def write_skip_list(skip_list) -> None:
    """Persist skip_list, one entry per line, to the skip list file.

    Does nothing when config.config['ml_quick_label_use_skip_lists'] is
    false.  Otherwise the file named by
    config.config['ml_quick_label_skip_list_path'] is overwritten; entries
    are stripped of surrounding whitespace and empty ones are left out.
    """
    if not config.config['ml_quick_label_use_skip_lists']:
        return

    skip_path = config.config['ml_quick_label_skip_list_path']
    with open(skip_path, 'w') as out:
        stripped_entries = (entry.strip() for entry in skip_list)
        out.writelines(f'{entry}\n' for entry in stripped_entries if entry)
    logger.debug('Updated %s', skip_path)
79
80
def label(in_spec: InputSpec) -> None:
    """Interactively label every image described by in_spec.

    Each image that is not on the skip list is mapped to its features file.
    If that file has no line containing in_spec.label (or overwriting is
    enabled via config.config['ml_quick_label_overwrite_labels']), the image
    is shown with xv, the user is asked for a keystroke, and the features
    file is rewritten with any old label lines removed and a single new
    "<label>: <value>" line appended.  Images labeled this way are added to
    the skip list, which is written back when the function exits.

    Args:
        in_spec: describes which images to label and how to turn a keystroke
            into a label value.

    Raises:
        ValueError: if neither in_spec.image_file_glob nor
            in_spec.image_file_prepopulated_list is set.
        AssertionError: if in_spec.image_file_to_features_file or
            in_spec.keystroke_to_label is missing when it is needed.
    """
    import shlex

    import input_utils

    images = []
    if in_spec.image_file_glob is not None:
        images += glob.glob(in_spec.image_file_glob)
    elif in_spec.image_file_prepopulated_list is not None:
        images += in_spec.image_file_prepopulated_list
    else:
        raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')

    skip_list = read_skip_list()
    try:
        for image in images:
            if image in skip_list:
                logger.debug('Skipping %s because of the skip list', image)
                continue
            assert in_spec.image_file_to_features_file
            features = in_spec.image_file_to_features_file(image)
            if features is None or not os.path.exists(features):
                msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
                logger.warning(msg)
                warnings.warn(msg)
                continue

            # Split the features file into the lines to keep and any
            # existing label lines.  Only the newline is removed, so a final
            # line without a trailing newline keeps its last character.
            with open(features, "r") as f:
                lines = [raw.rstrip('\n') for raw in f]
            filtered_lines = [line for line in lines if in_spec.label not in line]
            saw_label = len(filtered_lines) != len(lines)

            if not saw_label or config.config['ml_quick_label_overwrite_labels']:
                logger.info(features)
                # Quote the path so spaces or shell metacharacters in a
                # filename cannot break (or inject into) the command line.
                os.system(f'xv {shlex.quote(image)} &')
                keystroke = input_utils.single_keystroke_response(
                    in_spec.valid_keystrokes,
                    prompt=in_spec.prompt,
                )
                os.system('killall xv')
                assert in_spec.keystroke_to_label
                label_value = in_spec.keystroke_to_label(keystroke)
                # No newline here: one is added per line when writing below,
                # so including it would leave a blank line after the label.
                filtered_lines.append(f"{in_spec.label}: {label_value}")
                with open(features, 'w') as f:
                    f.writelines(f'{line}\n' for line in filtered_lines)
                skip_list.add(image)
    finally:
        # Labeling is interactive and often interrupted; save the progress
        # made so far even if the loop exits early.
        write_skip_list(skip_list)