import os
import sys
import time
-import warnings
from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set, Tuple
import argparse_utils
import config
action=argparse_utils.ActionNoYes,
help='Enable overwriting existing labels; default is to not relabel.',
)
+parser.add_argument(
+ '--ml_quick_label_skip_where_model_agrees',
+ default=False,
+ action=argparse_utils.ActionNoYes,
+ help='Do not filter examples where the model disagrees with the current label.',
+)
+parser.add_argument(
+ 'ml_quick_label_delete_invalid_examples',
+ default=False,
+ action='store_true',
+ help='If set we will delete invalid training examples.',
+)
class QuickLabelHelper:
'''Unrender a raw file with its features (if necessary)...'''
pass
+ @abstractmethod
+ def is_valid_example(self, filename: str, features: str, lines: List[str]) -> bool:
+ '''Returns true iff the example is valid (all features are valid, there
+ are the correct number of features, etc...'''
+ pass
+
@abstractmethod
def ask_current_model_about_example(
self,
logger.debug('Updated %s', quick_skip_file)
-def quick_label(helper: QuickLabelHelper) -> None:
- # Ask helper for an initial set of files.
- images = helper.get_candidate_files()
- if len(images) == 0:
- logger.warning('No images files to operate on.')
- return
-
- # Filter out any that can't be converted to features or already have a
- # label (unless they used --ml_qukck_label_overwrite_labels).
+def _filter_images(
+ images: List[str], skip_list: Set[str], helper: QuickLabelHelper
+) -> List[Tuple[str, str]]:
filtered_images = []
- skip_list = _maybe_read_skip_list()
+ label_label = helper.get_label_feature()
for image in images:
if image in skip_list:
logger.debug('Skipping %s because of the skip list', image)
features = helper.get_features_for_file(image)
if features is None or not os.path.exists(features):
- msg = f'{image}/{features}: {features} doesn\'t exist, SKIPPING.'
- logger.warning(msg)
- warnings.warn(msg)
+ logger.warning('%s/%s: features doesn\'t exist, SKIPPING.', image, features)
continue
- label_label = helper.get_label_feature()
label = None
+ filtered_lines = []
with open(features, 'r') as rf:
lines = rf.readlines()
for line in lines:
line = line[:-1]
if line.startswith(label_label):
- label = line
+ label = ''.join(line.split(':')[1:])
+ label = label.strip()
+ else:
+ filtered_lines.append(line)
+
+ if not helper.is_valid_example(image, features, filtered_lines):
+ logger.warning('%s/%s: Invalid example.', image, features)
+ if config.config['ml_quick_label_delete_invalid_examples']:
+ os.remove(image)
+ os.remove(features)
+ continue
+
if label and not config.config['ml_quick_label_overwrite_labels']:
- msg = f'{image}/{features}: already has label, SKIPPING'
- logger.warning(msg)
- warnings.warn(msg)
+ logger.warning('%s/%s: already has label, SKIPPING.', image, features)
continue
+
+ if config.config['ml_quick_label_skip_where_model_agrees']:
+ model_says = helper.ask_current_model_about_example(image, features, filtered_lines)
+ if model_says and label:
+ if model_says[0] == int(label):
+ continue
+ print(f'{image}/{features}: The model disagrees with the current label.')
+ print(f' ...model says {model_says[0]} with probability {model_says[1]}.')
+ print(f' ...the example is currently labeled {label}')
filtered_images.append((image, features))
+ return filtered_images
+
+
+def quick_label(helper: QuickLabelHelper) -> None:
+ skip_list = _maybe_read_skip_list()
+ # Ask helper for an initial set of files.
+ images = helper.get_candidate_files()
+ if len(images) == 0:
+ logger.warning('No images files to operate on.')
+ return
+
+ # Filter out any that can't be converted to features or already have a
+ # label (unless they used --ml_qukck_label_overwrite_labels).
+ filtered_images = _filter_images(images, skip_list, helper)
if len(filtered_images) == 0:
logger.warning('No image files to operate on (post filter).')
return
- cursor = 0
+ # Allow the user to label the non-filtered images one by one.
import input_utils
+ cursor = 0
+ label_label = helper.get_label_feature()
while True:
assert 0 <= cursor < len(filtered_images)