Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / ml / quick_label.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A helper to facilitate quick manual labeling of ML training data."""
6
import glob
import logging
import os
import warnings
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Set

import argparse_utils
import config
16
logger = logging.getLogger(__name__)

# Command line flags for this module, registered through the project's
# config helper (which hands back an object supporting add_argument).
parser = config.add_commandline_args(
    f"ML Quick Labeler ({__file__})",
    "Args related to quick labeling of ML training data",
)

# File in which label() records the paths of images it has labeled so that
# later runs can skip them.
# NOTE(review): argparse also runs `type` on string defaults; if
# argparse_utils.valid_filename requires the file to already exist, a first
# run without a skip list file may be rejected at parse time -- confirm.
parser.add_argument(
    "--ml_quick_label_skip_list_path",
    type=argparse_utils.valid_filename,
    metavar="FILENAME",
    default="./qlabel_skip_list.txt",
    help="Path to file in which to store already labeled data.",
)

# Whether the skip list file is read at startup and rewritten at the end.
parser.add_argument(
    "--ml_quick_label_use_skip_lists",
    action=argparse_utils.ActionNoYes,
    default=True,
    help='Should we use a skip list file to speed up execution?',
)

# Whether images whose features file already mentions the label are
# prompted for again (and their label line replaced).
parser.add_argument(
    "--ml_quick_label_overwrite_labels",
    action=argparse_utils.ActionNoYes,
    default=False,
    help='Enable overwriting existing labels; default is to not relabel.',
)
41
42
@dataclass
class InputSpec:
    """A wrapper around the input data we need to operate; should be
    populated by the caller.

    Exactly one image source is used: image_file_glob if set, otherwise
    image_file_prepopulated_list.  image_file_to_features_file maps an image
    path to the text file that receives the "<label>: <value>" line, and
    keystroke_to_label maps the pressed key (one of valid_keystrokes) to the
    value written after the label.
    """

    # Glob pattern selecting the images to label (takes precedence).
    image_file_glob: Optional[str] = None
    # Explicit list of image paths, used when no glob is given.
    image_file_prepopulated_list: Optional[List[str]] = None
    # Maps an image path to the path of its features file.
    image_file_to_features_file: Optional[Callable[[str], str]] = None
    # Name of the label written into (and searched for in) features files.
    label: str = ''
    # Keys accepted at the prompt.  A default_factory is required: a bare
    # `= []` default is rejected by @dataclass with a ValueError at class
    # creation time (and would otherwise be shared between instances).
    valid_keystrokes: List[str] = field(default_factory=list)
    # Text shown when asking for a keystroke.
    prompt: str = ''
    # Maps the pressed key to the label value that gets recorded.
    keystroke_to_label: Optional[Callable[[str], str]] = None
55
56
def read_skip_list() -> Set[str]:
    """Load the set of image paths recorded as already labeled.

    Returns an empty set when skip lists are disabled via
    --ml_quick_label_use_skip_lists or when the skip list file does not
    exist.  Each non-blank line of the file is one path; surrounding
    whitespace is stripped, matching what write_skip_list() writes.
    """
    ret: Set[str] = set()
    if not config.config['ml_quick_label_use_skip_lists']:
        return ret
    quick_skip_file = config.config['ml_quick_label_skip_list_path']
    if os.path.exists(quick_skip_file):
        with open(quick_skip_file, 'r') as f:
            for line in f:
                # strip() rather than line[:-1]: a final line without a
                # trailing newline keeps its last character, and blank lines
                # are not added to the set as ''.
                entry = line.strip()
                if entry:
                    ret.add(entry)
    logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
    return ret
70
71
def write_skip_list(skip_list) -> None:
    """Overwrite the skip list file with the given image paths, one per line.

    Does nothing when --ml_quick_label_use_skip_lists is off.  Every entry is
    whitespace-stripped before writing and entries that end up empty are
    dropped.
    """
    if not config.config['ml_quick_label_use_skip_lists']:
        return
    destination = config.config['ml_quick_label_skip_list_path']
    cleaned = (entry.strip() for entry in skip_list)
    with open(destination, 'w') as out:
        out.writelines(f'{entry}\n' for entry in cleaned if entry)
    logger.debug('Updated %s', destination)
81
82
def _collect_images(in_spec: InputSpec) -> List[str]:
    """Return candidate image paths from image_file_glob, else the prepopulated list."""
    if in_spec.image_file_glob is not None:
        return glob.glob(in_spec.image_file_glob)
    if in_spec.image_file_prepopulated_list is not None:
        # Copy so nothing downstream can mutate the caller's list.
        return list(in_spec.image_file_prepopulated_list)
    raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')


def _read_feature_lines(features: str) -> List[str]:
    """Return the lines of a features file with their trailing newlines removed."""
    with open(features, 'r') as f:
        # rstrip('\n') rather than line[:-1]: a last line lacking a newline
        # must not lose its final character.
        return [line.rstrip('\n') for line in f]


def _show_image_and_get_keystroke(in_spec: InputSpec, image: str) -> str:
    """Show image in xv, read one valid keystroke, then close that viewer."""
    import subprocess

    import input_utils

    viewer: Optional[subprocess.Popen] = None
    try:
        # Argument list and no shell: paths containing spaces or shell
        # metacharacters reach xv intact and cannot inject commands.
        viewer = subprocess.Popen(['xv', image])
    except OSError as e:
        # Best effort, as before: still prompt if xv cannot be started.
        logger.warning('Could not launch xv for %s: %s', image, e)
    try:
        return input_utils.single_keystroke_response(
            in_spec.valid_keystrokes,
            prompt=in_spec.prompt,
        )
    finally:
        # Close only the viewer we started (`killall xv` also closed
        # unrelated xv windows), even if the prompt is interrupted.
        if viewer is not None:
            viewer.terminate()
            try:
                viewer.wait(timeout=5)
            except subprocess.TimeoutExpired:
                viewer.kill()
                viewer.wait()


def label(in_spec: InputSpec) -> None:
    """Interactively label the images described by in_spec.

    Candidate images come from in_spec.image_file_glob, or else
    in_spec.image_file_prepopulated_list.  Images on the skip list are
    ignored.  Each remaining image is mapped to its features file.  If that
    file already has a line containing in_spec.label, the image is skipped
    unless --ml_quick_label_overwrite_labels is set.  Otherwise the image is
    shown in xv, a keystroke from in_spec.valid_keystrokes is read, and the
    features file is rewritten.  Any lines containing the label are dropped
    and a "<label>: <value>" line is appended, with the value coming from
    in_spec.keystroke_to_label.  Labeled images are added to the skip list.
    The skip list is saved even if the loop is interrupted.

    Raises:
        ValueError: if neither image source is given, the label is empty,
            or image_file_to_features_file / keystroke_to_label is missing.
    """
    images = _collect_images(in_spec)
    if not in_spec.label:
        # Every line contains '', so an empty label would match (and drop)
        # every line of every features file.
        raise ValueError('in_spec.label must be a non-empty string')
    to_features = in_spec.image_file_to_features_file
    if to_features is None:
        raise ValueError('in_spec.image_file_to_features_file is required')
    to_label = in_spec.keystroke_to_label
    if to_label is None:
        raise ValueError('in_spec.keystroke_to_label is required')

    skip_list = read_skip_list()
    try:
        for image in images:
            if image in skip_list:
                logger.debug('Skipping %s because of the skip list', image)
                continue
            features = to_features(image)
            if features is None or not os.path.exists(features):
                msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
                logger.warning(msg)
                warnings.warn(msg)
                continue

            # NOTE(review): this is a substring match, so a label that also
            # appears inside other feature lines will remove them too.
            lines = _read_feature_lines(features)
            kept = [line for line in lines if in_spec.label not in line]
            saw_label = len(kept) != len(lines)
            if saw_label and not config.config['ml_quick_label_overwrite_labels']:
                continue

            logger.info(features)
            keystroke = _show_image_and_get_keystroke(in_spec, image)
            label_value = to_label(keystroke)
            # No '\n' here: the write below terminates every line, and
            # adding one here left a stray blank line in the file.
            kept.append(f"{in_spec.label}: {label_value}")
            with open(features, 'w') as f:
                f.writelines(f'{line}\n' for line in kept)
            skip_list.add(image)
    finally:
        # Persist progress even if the session is interrupted (e.g. Ctrl-C).
        write_skip_list(skip_list)