Random cleanups and type safety. Created ml subdir.
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 import glob
4 import logging
5 import os
6 from typing import Callable, List, NamedTuple, Optional, Set
7
8 import argparse_utils
9 import config
10
logger = logging.getLogger(__name__)

# Register this module's command-line flags with the shared config module.
parser = config.add_commandline_args(
    f"ML Quick Labeler ({__file__})",
    "Args related to quick labeling of ML training data",
)

# Location of the file that remembers which images were already labeled.
parser.add_argument(
    "--ml_quick_label_skip_list_path",
    type=argparse_utils.valid_filename,
    metavar="FILENAME",
    default="./qlabel_skip_list.txt",
    help="Path to file in which to store already labeled data.",
)

# Whether the skip list file is read and written at all.
parser.add_argument(
    "--ml_quick_label_use_skip_lists",
    action=argparse_utils.ActionNoYes,
    default=True,
    help="Should we use a skip list file to speed up execution?",
)

# Whether images whose features file already contains a label get re-prompted.
parser.add_argument(
    "--ml_quick_label_overwrite_labels",
    action=argparse_utils.ActionNoYes,
    default=False,
    help="Enable overwriting existing labels; default is to not relabel.",
)
35
36
class InputSpec(NamedTuple):
    """Describes one quick-labeling job consumed by label().

    Exactly how the images are found is controlled by the first two fields:
    label() uses image_file_glob when it is not None, otherwise it falls back
    to image_file_prepopulated_list, and raises ValueError if both are None.
    """

    # Glob pattern expanded with glob.glob() to find images; None to disable.
    image_file_glob: Optional[str]
    # Explicit list of image paths, consulted only when image_file_glob is None.
    image_file_prepopulated_list: Optional[List[str]]
    # Maps an image path to the path of its features file; images whose
    # features file is missing are skipped with a warning.
    image_file_to_features_file: Callable[[str], str]
    # Label name: lines of the features file containing this text are treated
    # as an existing label, and new labels are written as "<label>: <value>".
    label: str
    # Keystrokes accepted from the user when prompting for a label.
    valid_keystrokes: List[str]
    # Prompt text shown while waiting for the keystroke.
    prompt: str
    # Converts the keystroke the user pressed into the label value to store.
    keystroke_to_label: Callable[[str], str]
45
46
def read_skip_list() -> Set[str]:
    """Load the set of already-labeled image paths from the skip list file.

    The file named by --ml_quick_label_skip_list_path is read only when
    --ml_quick_label_use_skip_lists is enabled.  Each line is one entry;
    surrounding whitespace is removed (mirroring what write_skip_list()
    writes) and blank lines are ignored.

    Returns:
        The set of entries found; empty when skip lists are disabled or the
        file does not exist yet.
    """
    ret: Set[str] = set()
    if not config.config['ml_quick_label_use_skip_lists']:
        return ret

    quick_skip_file = config.config['ml_quick_label_skip_list_path']
    try:
        with open(quick_skip_file, 'r') as f:
            for line in f:
                # str.strip() returns a new string; it must be captured.  It
                # also copes with a final line lacking a trailing newline,
                # which blind slicing with [:-1] would truncate.
                entry = line.strip()
                if entry:
                    ret.add(entry)
    except FileNotFoundError:
        # First run: no skip list has been written yet.
        pass
    logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
    return ret
60
61
def write_skip_list(skip_list) -> None:
    """Persist the skip list, one entry per line.

    Does nothing unless --ml_quick_label_use_skip_lists is enabled.  The
    file named by --ml_quick_label_skip_list_path is overwritten; entries
    are whitespace-stripped and empty ones are dropped.
    """
    if not config.config['ml_quick_label_use_skip_lists']:
        return

    quick_skip_file = config.config['ml_quick_label_skip_list_path']
    with open(quick_skip_file, 'w') as out:
        cleaned = (entry.strip() for entry in skip_list)
        out.writelines(f'{entry}\n' for entry in cleaned if entry)
    logger.debug(f'Updated {quick_skip_file}')
71
72
def _without_label_lines(lines: List[str], label: str) -> List[str]:
    """Return lines without their newline, dropping any that contain label."""
    return [
        text
        for text in (line.rstrip('\n') for line in lines)
        if label not in text
    ]


def label(in_spec: InputSpec) -> None:
    """Interactively label the images described by in_spec.

    For every image not in the skip list, the matching features file is
    read.  If it has no line containing in_spec.label (or
    --ml_quick_label_overwrite_labels is set), the image is shown with xv,
    the user is prompted for one keystroke, and the features file is
    rewritten with any old label lines removed and a new
    "<label>: <value>" line appended.

    Images that are labeled, or found to be labeled already, are added to
    the skip list, which is written back even if the session is
    interrupted.

    Args:
        in_spec: what to label and how to prompt; see InputSpec.

    Raises:
        ValueError: if in_spec supplies neither a glob nor a list of images.
    """
    import subprocess

    import input_utils

    if in_spec.image_file_glob is not None:
        images = glob.glob(in_spec.image_file_glob)
    elif in_spec.image_file_prepopulated_list is not None:
        images = list(in_spec.image_file_prepopulated_list)
    else:
        raise ValueError(
            'One of image_file_glob or image_file_prepopulated_list is required'
        )

    overwrite = config.config['ml_quick_label_overwrite_labels']
    skip_list = read_skip_list()
    try:
        for image in images:
            if image in skip_list:
                logger.debug(f'Skipping {image} because of the skip list')
                continue
            features = in_spec.image_file_to_features_file(image)
            if features is None or not os.path.exists(features):
                logger.warning(
                    f'File {image} yielded file {features} which does not exist, SKIPPING.'
                )
                continue

            with open(features, 'r') as f:
                original_lines = f.readlines()
            filtered_lines = _without_label_lines(original_lines, in_spec.label)
            saw_label = len(filtered_lines) != len(original_lines)

            if saw_label and not overwrite:
                # Already labeled: remember it so later runs need not open
                # the features file again.
                skip_list.add(image)
                continue

            # Render features and image.
            logger.info(features)
            try:
                # Argument-list form keeps paths with spaces or shell
                # metacharacters intact.  stdin is detached so the viewer
                # cannot compete with the keystroke prompt for the terminal.
                viewer = subprocess.Popen(
                    ['xv', image], stdin=subprocess.DEVNULL
                )
            except OSError as e:
                logger.warning('Unable to launch xv for %s: %s', image, e)
                viewer = None
            try:
                keystroke = input_utils.single_keystroke_response(
                    in_spec.valid_keystrokes,
                    prompt=in_spec.prompt,
                )
            finally:
                # Close only the viewer started here, not every running xv.
                if viewer is not None:
                    viewer.terminate()
                    viewer.wait()

            label_value = in_spec.keystroke_to_label(keystroke)
            # No trailing newline here: one is added per line when writing,
            # so including it would leave a blank line in the file.
            filtered_lines.append(f'{in_spec.label}: {label_value}')
            with open(features, 'w') as f:
                f.writelines(f'{line}\n' for line in filtered_lines)
            skip_list.add(image)
    finally:
        # Labeling is interactive and often interrupted; keep the progress.
        write_skip_list(skip_list)