import logging
import os

from caffe2.python import core, context
from caffe2.python.net_builder import ops
from caffe2.python.task import (
    final_output,
    Node,
    Task,
    TaskGroup,
    TaskOutput,
    WorkspaceType,
)

logger = logging.getLogger(__name__)

class Job(context.Managed):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group` which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_condition` is True
    at the end of an epoch.

    The download_group will be run only once, after all the executions of
    epoch_group finish. Its role is to collect the distribute scattered
    parameters back after training.

    The `exit_group` will be run only once at the very end of the job, the
    role of this group is to save the results of training in the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_condition(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_conditions=None, nodes_to_checkpoint=None):
        # init_group runs in the GLOBAL workspace so its blobs persist
        # across the epoch loop.
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.download_group = download_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_conditions = stop_conditions or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        # Explicitly-set node list wins; otherwise fall back to the nodes
        # that the init_group actually touched.
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
        """Compile all four TaskGroups with the given session class.

        Capture nodes_to_checkpoint() first, since compilation replaces the
        TaskGroups with compiled objects that no longer expose used_nodes().
        """
        self._nodes_to_checkpoint = self.nodes_to_checkpoint()
        self.init_group = session_class.compile(self.init_group)
        self.epoch_group = session_class.compile(self.epoch_group)
        self.download_group = session_class.compile(self.download_group)
        self.exit_group = session_class.compile(self.exit_group)

    def __enter__(self):
        # Tasks created inside a `with job:` block are added to the
        # epoch_group by default.
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()
        super().__exit__(*args)

    def add_stop_condition(self, output):
        """Register a boolean output checked at the end of each epoch.

        A raw BlobReference is wrapped in a Task on the epoch_group so its
        value can be fetched as a TaskOutput.
        """
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_conditions.append(output)
def get_ckpt_filename(node_name, epoch):
    """Returns the checkpoint filename.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint.
    """
    return node_name + '.' + str(epoch)


def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Returns the full db name where checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or path
            where checkpoint files are stored.

    Returns:
        db_name: A string. The absolute path of full_db_name where checkpoint
            files are saved.
    """
    if path_prefix:
        # path_prefix already carries any trailing separator/scheme, so the
        # filename is appended by plain concatenation.
        db_name = path_prefix + get_ckpt_filename(node_name, epoch)
    else:
        ckpt_filename = get_ckpt_filename(node_name, epoch)
        db_name = os.path.join(db_prefix, ckpt_filename)
    return db_name
class CheckpointManager:
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `read` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # The '!!' prefix sorts this net's blobs first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)
        self._names_output = None
        self._path_prefix = None
        self._path_type = None
        self._current_db_name = None
        self._current_checkpoint_duration = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Initialize the checkpoint manager. Determines all blobs that need to be
        saved or loads them from a checkpoint.

        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is running. Should
                only contain a single node.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                # Fresh start: enumerate all blobs currently in the workspace.
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                # Resume: read the blob-name list back from the checkpoint db.
                full_db_name = db_name(retrieve_from_epoch,
                                       self._node_name, self._db_prefix, path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s"
                            % full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True,
                    keep_device=True,
                )
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """Returns the list of blob names determined by init()."""
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that will measure the time span of checkpoint operations,
        once operation is done, time can be read from _current_checkpoint_duration.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as task:
            with ops.task_init():
                timer = ops.TimerBegin([], counter_name=self._node_name)
            add_op()
            with ops.task_exit():
                time_span_blob = ops.TimerGetAndEnd(timer)
            self._current_checkpoint_duration = final_output(time_span_blob)
        return task

    def collect_checkpoint_stats(self, stats):
        """
        Add one checkpoint stats into the stats.

        Args:
            stats: A dict of checkpoint stats that will be reported.
        """
        if self._current_db_name and self._current_checkpoint_duration:
            stats[self._current_db_name] = self._current_checkpoint_duration.fetch()[0]
        else:
            logger.info(
                "Failed to collect checkpoint stats: {}".format(
                    self._current_db_name
                )
            )

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s" % self._current_db_name)

        def add_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s' % self._current_db_name)

        def add_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                # allow_incomplete: tolerate blobs absent from the checkpoint.
                allow_incomplete=True)

        return self._timed_task('checkpoint_partial_load', add_op)

    def check_db_exists(self, epoch):
        """Builds a Task whose single output is True iff the checkpoint db
        for the given epoch exists."""
        logger.info('Check existence of %s' %
                    db_name(epoch, self._node_name, self._db_prefix))
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=db_name(epoch, self._node_name, self._db_prefix),
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for current node.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        self.collect_checkpoint_stats(all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s' % self._current_db_name)

        def add_op():
            ops.Save(
                self.blob_list(), [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from

        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            # No metadata handler configured: assume the checkpoint path works.
            return True
class MultiNodeCheckpointManager:
    """
    Coordinates checkpointing and checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, db_type, metadata_handler=None):
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        self._path_prefix = None
        self._path_type = None

    def _task_group(self, func, *args, **kw):
        """Run `func(manager, ...)` under each node's context, collected
        into one GLOBAL-workspace TaskGroup."""
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def init(
        self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
    ):
        """
        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            # Already initialized: nothing to run.
            return TaskGroup(WorkspaceType.GLOBAL)
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = CheckpointManager(
                    db_prefix=self._db_prefix,
                    node_name=str(node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            CheckpointManager.init,
            nodes=[node],
            retrieve_from_epoch=retrieve_from_epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load(self, epoch, path_prefix=None, path_type=None):
        return self._task_group(
            CheckpointManager.load,
            epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
        else:
            self._node_managers = []
            for node in nodes:
                with Node(node):
                    manager = CheckpointManager(
                        db_prefix=self._db_prefix,
                        node_name=str(node),
                        db_type=self._db_type)
                    self._node_managers.append((node, manager))
        assert self._node_managers is not None, 'must initialize node managers'
        for _, manager in self._node_managers:
            # Verify the checkpoint db exists before attempting the load.
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info('DB %s does not exist!' %
                            db_name(epoch, manager._node_name, manager._db_prefix))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def get_ckpt_db_name(self, node_name, epoch):
        """Returns the DB name of the given node and the given epoch.

        The DB name is effectively the checkpoint path of the given node and
        the given epoch.

        Args:
            node_name: A string. The node name of interest.
            epoch: An integer. The epoch of the checkpoint.

        Returns:
            checkpoint_db_name: A string. The checkpoint path of the given
                node and the given epoch.
        """
        for node, manager in self._node_managers:
            if str(node) == node_name:
                return db_name(epoch, manager._node_name, manager._db_prefix)

    def report_checkpoint_stats(self, action_name):
        """
        Report the checkpoint stats for all the nodes, we need to aggregate all
        the node's stats together so that we know which node's checkpoint
        operation dominates.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        for _, manager in self._node_managers:
            manager.collect_checkpoint_stats(all_stats)
        logger.debug("checkpoint stats: {}".format(all_stats))
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        return self._task_group(CheckpointManager.save, epoch)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from

        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        self._node_names = [str(node) for node in nodes]
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=self._node_names,
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            # No metadata handler configured: assume the checkpoint path works.
            return True
class UploadTaskGroupBuilder:
    """A simple class to upload checkpoints."""
    def build(self, epoch, checkpoint_manager):
        """Builds the task group to upload checkpoints.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: This base class only has the interface,
                the implementation will be in the subclasses.
        """
        raise NotImplementedError()
class JobRunner:
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choises are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup a time.
        """
        # Identify the epoch to resume from, possibly overriding the
        # user-provided value with persisted checkpoint metadata.
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = self.checkpoint_manager.\
                get_resume_from_epoch_id(self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))

        # Initialize the job: only run init_group when starting from scratch.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            # Save a checkpoint of the initial state, or load the resume epoch.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info('Loading checkpoints for epoch {} ...'.format(
                    self.resume_from_epoch))
                session.run(
                    self.checkpoint_manager.load(self.resume_from_epoch))
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')

        logger.info("Finished initializing")

        # Run the epoch loop until any stop condition fires.
        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d' % epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d' % epoch)
            stop_conditions = [o.fetch() for o in self.job.stop_conditions]

            if self.checkpoint_manager:
                self.save_checkpoints(epoch, session)

            if any(stop_conditions):
                logger.info('Stopping')
                break
            epoch += 1
        logger.info('Finished training')
        # Upload the checkpoints, if a builder was provided.
        if (self.upload_task_group_builder):
            upload_task_group = self.upload_task_group_builder.build(
                epoch, self.checkpoint_manager)
            session.run(upload_task_group)
            logger.info('Finished uploading the checkpoints')

        # Download the trained parameters.
        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        # Run the exit step, e.g. to save the trained model.
        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purpose. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all the
        nodes, but only loads the blobs specified in the blob_names to the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        result = self.checkpoint_manager.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)
        self.checkpoint_manager.report_checkpoint_stats('checkpoint_partial_load')
        return result

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspaace.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        try:
            is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
            if is_accessible:
                logger.info('Saving checkpoints for epoch {}'.format(epoch))
                session.run(self.checkpoint_manager.save(epoch))
                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_save')
            else:
                logger.warning("Checkpoint files cannot be accessed!")
        except Exception as ex:
            # Best-effort: a failed checkpoint must not kill the training loop.
            logger.warning("Unable to write checkpoint for epoch {}. Error={}".
                           format(epoch, ex))
def epoch_limiter(job, num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
    """
    with job.init_group:
        init_net = core.Net('epoch_counter_init')
        # init_count is num_epochs - 1 because CountDown returns True on the
        # step that takes the counter below zero.
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)

    with job.epoch_group:
        epoch_net = core.Net('epoch_countdown')
        finished = epoch_net.CountDown(counter)
        output = Task(step=epoch_net, outputs=finished).outputs()[0]
    job.add_stop_condition(output)