Start using warnings from stdlib.
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 import glob
4 import logging
5 import os
6 from typing import Callable, List, NamedTuple, Optional, Set
7 import warnings
8
9 import argparse_utils
10 import config
11
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Command-line flags owned by this module, registered through the shared
# config module so they show up alongside every other module's flags.
parser = config.add_commandline_args(
    f"ML Quick Labeler ({__file__})",
    "Args related to quick labeling of ML training data",
)

# Location of the file that remembers which inputs were already labeled.
parser.add_argument(
    "--ml_quick_label_skip_list_path",
    type=argparse_utils.valid_filename,
    metavar="FILENAME",
    default="./qlabel_skip_list.txt",
    help="Path to file in which to store already labeled data.",
)

# Whether the skip list file is read and written at all (on by default).
parser.add_argument(
    "--ml_quick_label_use_skip_lists",
    action=argparse_utils.ActionNoYes,
    default=True,
    help="Should we use a skip list file to speed up execution?",
)

# Whether inputs that already carry a label are prompted for again (off by
# default).
parser.add_argument(
    "--ml_quick_label_overwrite_labels",
    action=argparse_utils.ActionNoYes,
    default=False,
    help="Enable overwriting existing labels; default is to not relabel.",
)
36
37
class InputSpec(NamedTuple):
    """Description of one interactive labeling job, as consumed by label().

    At least one of image_file_glob / image_file_prepopulated_list must be
    non-None; label() raises ValueError otherwise.
    """

    # Glob pattern expanded with glob.glob() to find the images to label.
    # When this is not None it is used and the prepopulated list is ignored.
    image_file_glob: Optional[str]

    # Explicit list of image paths; only consulted when image_file_glob is
    # None.
    image_file_prepopulated_list: Optional[List[str]]

    # Maps an image path to the path of its features file.  Images whose
    # features file is None or does not exist are skipped with a warning.
    image_file_to_features_file: Callable[[str], str]

    # Name of the label.  Any features-file line containing this text is
    # treated as an existing label line; new labels are written as
    # "<label>: <value>".
    label: str

    # Keystrokes accepted from the user when prompted about an image.
    valid_keystrokes: List[str]

    # Prompt text shown to the user for each image.
    prompt: str

    # Converts the keystroke the user pressed into the label value to store.
    keystroke_to_label: Callable[[str], str]
46
47
def read_skip_list() -> Set[str]:
    """Load the set of already-labeled filenames from the skip list file.

    The skip list is only consulted when the ml_quick_label_use_skip_lists
    config flag is set; its location comes from the
    ml_quick_label_skip_list_path config value.

    Returns:
        The entries found in the skip list file: one per non-blank line,
        with surrounding whitespace removed.  The set is empty when skip
        lists are disabled or the file does not exist.
    """
    ret: Set[str] = set()
    if not config.config['ml_quick_label_use_skip_lists']:
        return ret

    quick_skip_file = config.config['ml_quick_label_skip_list_path']
    if os.path.exists(quick_skip_file):
        with open(quick_skip_file, 'r') as f:
            for line in f:
                # str.strip() returns a new string, so its result must be
                # kept.  Stripping (instead of slicing off the last
                # character) also leaves a final line that has no trailing
                # newline intact.
                entry = line.strip()
                # Blank lines are not filenames; write_skip_list() never
                # emits them either.
                if entry:
                    ret.add(entry)
    logger.debug(f'Read {quick_skip_file} and found {len(ret)} entries.')
    return ret
61
62
def write_skip_list(skip_list) -> None:
    """Persist skip_list to the configured skip list file, one entry per line.

    Does nothing unless the ml_quick_label_use_skip_lists config flag is set.
    The file named by ml_quick_label_skip_list_path is opened for writing, so
    any previous contents are replaced.  Each entry is stripped of
    surrounding whitespace, and entries that are empty after stripping are
    left out.
    """
    if not config.config['ml_quick_label_use_skip_lists']:
        return

    skip_path = config.config['ml_quick_label_skip_list_path']
    with open(skip_path, 'w') as out:
        for raw_entry in skip_list:
            entry = raw_entry.strip()
            if not entry:
                continue
            out.write(entry + '\n')
    logger.debug(f'Updated {skip_path}')
72
73
def label(in_spec: InputSpec) -> None:
    """Interactively label every image described by in_spec.

    For each candidate image that is not already in the skip list, the
    image's features file is located via in_spec.image_file_to_features_file.
    If that file has no line containing in_spec.label (or the
    ml_quick_label_overwrite_labels config flag is set), the image is shown
    with the external `xv` viewer, the user is prompted for one of
    in_spec.valid_keystrokes, and the features file is rewritten with any old
    label lines removed and a single "<label>: <value>" line appended.
    Labeled images are added to the skip list, which is written back once all
    images have been processed.

    Args:
        in_spec: describes where the images come from, how to find each
            image's features file, and how to turn a keystroke into a label.

    Raises:
        ValueError: if in_spec provides neither an image glob nor a
            prepopulated image list.
    """
    import shlex

    import input_utils

    images = []
    if in_spec.image_file_glob is not None:
        images += glob.glob(in_spec.image_file_glob)
    elif in_spec.image_file_prepopulated_list is not None:
        images += in_spec.image_file_prepopulated_list
    else:
        raise ValueError(
            'One of image_file_glob or image_file_prepopulated_list is required'
        )

    skip_list = read_skip_list()
    for image in images:
        if image in skip_list:
            logger.debug(f'Skipping {image} because of the skip list')
            continue
        features = in_spec.image_file_to_features_file(image)
        if features is None or not os.path.exists(features):
            msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
            logger.warning(msg)
            warnings.warn(msg)
            continue

        # Read the features file without line terminators.  rstrip('\n')
        # (rather than slicing off the last character) keeps a final line
        # that has no trailing newline intact.
        with open(features, "r") as f:
            lines = [line.rstrip('\n') for line in f]

        # Drop any existing label lines; everything else is preserved as-is.
        filtered_lines = [line for line in lines if in_spec.label not in line]
        saw_label = len(filtered_lines) != len(lines)

        # Already labeled and not asked to relabel: leave the file untouched.
        if saw_label and not config.config['ml_quick_label_overwrite_labels']:
            continue

        logger.info(features)
        # Quote the path so names containing spaces or shell metacharacters
        # reach xv as a single argument.  The viewer runs in the background
        # while we wait for the keystroke.
        os.system(f'xv {shlex.quote(image)} &')
        keystroke = input_utils.single_keystroke_response(
            in_spec.valid_keystrokes,
            prompt=in_spec.prompt,
        )
        # NOTE: this closes every running xv, not just the one started above.
        os.system('killall xv')
        label_value = in_spec.keystroke_to_label(keystroke)

        # No newline here: exactly one is added per line when writing below,
        # so the file does not gain a blank line each time it is labeled.
        filtered_lines.append(f"{in_spec.label}: {label_value}")
        with open(features, 'w') as f:
            f.writelines(f"{line}\n" for line in filtered_lines)
        skip_list.add(image)
    write_skip_list(skip_list)