Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / ml / quick_label.py
index 120ff5fe92e644d22385741d2044d8c14d006dea..15256a30da2e8847f15bf114dc23eb7ff755c930 100644 (file)
@@ -1,10 +1,15 @@
 #!/usr/bin/env python3
 
+# © Copyright 2021-2022, Scott Gasch
+
+"""A helper to facilitate quick manual labeling of ML training data."""
+
 import glob
 import logging
 import os
-from typing import Callable, List, NamedTuple, Optional, Set
 import warnings
+from dataclasses import dataclass, field
+from typing import Callable, List, Optional, Set
 
 import argparse_utils
 import config
@@ -35,14 +40,18 @@ parser.add_argument(
 )
 
 
-class InputSpec(NamedTuple):
-    image_file_glob: Optional[str]
-    image_file_prepopulated_list: Optional[List[str]]
-    image_file_to_features_file: Callable[[str], str]
-    label: str
-    valid_keystrokes: List[str]
-    prompt: str
-    keystroke_to_label: Callable[[str], str]
+@dataclass
+class InputSpec:
+    """A wrapper around the input data we need to operate; should be
+    populated by the caller."""
+
+    image_file_glob: Optional[str] = None
+    image_file_prepopulated_list: Optional[List[str]] = None
+    image_file_to_features_file: Optional[Callable[[str], str]] = None
+    label: str = ''
+    valid_keystrokes: List[str] = field(default_factory=list)  # [] default raises ValueError
+    prompt: str = ''
+    keystroke_to_label: Optional[Callable[[str], str]] = None
 
 
 def read_skip_list() -> Set[str]:
@@ -56,7 +65,7 @@ def read_skip_list() -> Set[str]:
                 line = line[:-1]
                 line.strip()
                 ret.add(line)
-        logger.debug(f'Read {quick_skip_file} and found {len(ret)} entries.')
+        logger.debug('Read %s and found %d entries.', quick_skip_file, len(ret))
     return ret
 
 
@@ -68,7 +77,7 @@ def write_skip_list(skip_list) -> None:
                 filename = filename.strip()
                 if len(filename) > 0:
                     f.write(f'{filename}\n')
-        logger.debug(f'Updated {quick_skip_file}')
+        logger.debug('Updated %s', quick_skip_file)
 
 
 def label(in_spec: InputSpec) -> None:
@@ -80,15 +89,14 @@ def label(in_spec: InputSpec) -> None:
     elif in_spec.image_file_prepopulated_list is not None:
         images += in_spec.image_file_prepopulated_list
     else:
-        raise ValueError(
-            'One of image_file_glob or image_file_prepopulated_list is required'
-        )
+        raise ValueError('One of image_file_glob or image_file_prepopulated_list is required')
 
     skip_list = read_skip_list()
     for image in images:
         if image in skip_list:
-            logger.debug(f'Skipping {image} because of the skip list')
+            logger.debug('Skipping %s because of the skip list', image)
             continue
+        assert in_spec.image_file_to_features_file
         features = in_spec.image_file_to_features_file(image)
         if features is None or not os.path.exists(features):
             msg = f'File {image} yielded file {features} which does not exist, SKIPPING.'
@@ -116,9 +124,10 @@ def label(in_spec: InputSpec) -> None:
                 prompt=in_spec.prompt,
             )
             os.system('killall xv')
+            assert in_spec.keystroke_to_label
             label_value = in_spec.keystroke_to_label(keystroke)
             filtered_lines.append(f"{in_spec.label}: {label_value}\n")
             with open(features, 'w') as f:
-                f.writelines("%s\n" % line for line in filtered_lines)
+                f.writelines(line + '\n' for line in filtered_lines)
             skip_list.add(image)
     write_skip_list(skip_list)