3 """A simple utility to unpickle some code, run it, and pickle the
14 import cloudpickle # type: ignore
15 import psutil # type: ignore
20 from thread_utils import background_thread
# NOTE(review): this file fragment arrived with original line numbers fused
# onto each line and with many intervening lines elided; comments below
# annotate only what is visible.
# Module logger.  NOTE(review): keyed on __file__ rather than the
# conventional __name__ -- confirm this is intentional.
23 logger = logging.getLogger(__file__)
# Register this worker's command-line arguments with the shared config
# framework: apparently a code_file input, a result_file output, and a
# watch-for-cancellation yes/no toggle.  The .add_argument(...) opening
# lines are elided from this view; only selected keyword lines survive.
25 cfg = config.add_commandline_args(
26 f"Remote Worker ({__file__})",
27 "Helper to run pickled code remotely and return results",
34 help='The location of the bundle of code to execute.'
41 help='The location where we should write the computation results.'
45 action=argparse_utils.ActionNoYes,
47 help='Should we watch for the cancellation of our parent ssh process?'
# Watchdog: looks for an 'ssh' process among this worker's ancestors; if the
# parent ssh session has vanished (i.e. the remote submitter cancelled or
# disconnected), the worker kills itself.  Presumably runs on a background
# thread driven by the imported background_thread decorator -- the decorator
# line, the enclosing loop header, and several body lines are elided from
# this view, so the exact control flow cannot be confirmed here.
#
# Args:
#     terminate_event: set by the caller to tell this watcher to stop.
52 def watch_for_cancel(terminate_event: threading.Event) -> None:
53 p = psutil.Process(os.getpid())
# Walk every ancestor process, checking each name for an ssh session.
56 ancestors = p.parents()
57 for ancestor in ancestors:
58 name = ancestor.name()
59 if 'ssh' in name.lower():
# (elided: presumably records that an ssh ancestor was seen / breaks out)
# No ssh ancestor -> self-terminate: SIGTERM first for a graceful shutdown,
# then SIGKILL as the unconditional fallback.  NOTE(review): any delay
# between the two signals is elided from this view -- confirm one exists,
# otherwise SIGKILL would preempt graceful SIGTERM handling.
64 os.kill(os.getpid(), signal.SIGTERM)
66 os.kill(os.getpid(), signal.SIGKILL)
# Stop watching once the main thread signals that work is complete.
68 if terminate_event.is_set():
# Interior of the worker's main entry point (its def/decorator lines are
# elided from this view).  Pipeline: read the pickled code bundle ->
# deserialize -> invoke the user callable -> serialize the result -> write
# the result file.  Each stage logs critically on failure; the `try:`
# headers paired with the `except` clauses below, and the abort/return
# lines after each logger.critical, are elided from this view.
75 in_file = config.config['code_file']
76 out_file = config.config['result_file']
# Stage 1: read the serialized code bundle from disk.
78 logger.debug(f'Reading {in_file}.')
80 with open(in_file, 'rb') as rb:
81 serialized = rb.read()
82 except Exception as e:
84 logger.critical(f'Problem reading {in_file}. Aborting.')
# Stage 2: unpickle into (callable, positional args, keyword args).
# NOTE(review): cloudpickle.loads executes arbitrary code from the input
# file -- acceptable only because this worker exists to run code the
# operator submitted to itself; never point it at untrusted input.
87 logger.debug(f'Deserializing {in_file}.')
89 fun, args, kwargs = cloudpickle.loads(serialized)
90 except Exception as e:
92 logger.critical(f'Problem deserializing {in_file}. Aborting.')
# Stage 3: invoke the user code and report wall-clock duration (the
# start/end timestamp assignments are elided from this view).
95 logger.debug('Invoking user code...')
97 ret = fun(*args, **kwargs)
99 logger.debug(f'User code took {end - start:.1f}s')
# Stage 4: pickle the return value for transport back to the submitter.
101 logger.debug('Serializing results')
103 serialized = cloudpickle.dumps(ret)
104 except Exception as e:
106 logger.critical(f'Could not serialize result ({type(ret)}). Aborting.')
# Stage 5: write the serialized result where the submitter expects it
# (the wb.write(...) line is elided from this view).
109 logger.debug(f'Writing {out_file}.')
111 with open(out_file, 'wb') as wb:
113 except Exception as e:
115 logger.critical(f'Error writing {out_file}. Aborting.')
# Script entry point guard (the main() invocation line that follows is
# elided from this view).
119 if __name__ == '__main__':