pytorch

Форк
0
/
checkpoint.py 
833 строки · 31.3 Кб
1
## @package checkpoint
2
# Module caffe2.python.checkpoint
3

4

5

6

7

8
import os
9
import logging
10
from caffe2.python import core, context
11
from caffe2.python.net_builder import ops
12
from caffe2.python.task import (
13
    final_output,
14
    Node,
15
    Task,
16
    TaskGroup,
17
    TaskOutput,
18
    WorkspaceType,
19
)
20

21
logger = logging.getLogger(__name__)
22

23

24

25
class Job(context.Managed):
    """
    A Job defines four TaskGroups: the `init_group`, the `epoch_group`, the
    `download_group` and the `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_condition` is True
    at the end of an epoch.

    The `download_group` will be run only once, after all the executions of
    epoch_group finish. Its role is to collect the distributed, scattered
    parameters back after training.

    The `exit_group` will be run only once at the very end of the job. The
    role of this group is to save the results of training at the end of the
    job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

    def build_reader(partitions):
        with Job.current().init_group:
            reader = HiveReader(init_reader, ..., partitions)
            Task(step=init_reader)
        with Job.current().epoch_group:
            limited_reader = ReaderWithLimit(reader, num_iter=10000)
            data_queue = pipe(limited_reader, num_threads=8)
            Job.current().add_stop_condition(limited_reader.data_finished())
        return data_queue

    def build_hogwild_trainer(reader, model):
        with Job.current().init_group:
            Task(step=model.param_init_net)
        with Job.current().epoch_group:
            pipe(reader, processor=model, num_threads=8)
        with Job.current().exit_group:
            Task(step=model.save_model_net)

    with Job() as job:
        reader = build_reader(partitions)
        model = build_model(params)
        build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_conditions=None, nodes_to_checkpoint=None):
        """
        Args:
            init_group: Optional TaskGroup; defaults to a new TaskGroup with
                a GLOBAL workspace type.
            epoch_group: Optional TaskGroup; defaults to a new TaskGroup.
            download_group: Optional TaskGroup; defaults to a new TaskGroup.
            exit_group: Optional TaskGroup; defaults to a new TaskGroup.
            stop_conditions: Optional list of stop conditions. A non-empty
                list passed here is stored as-is (not copied), so
                `add_stop_condition` appends to the caller's list.
            nodes_to_checkpoint: Optional explicit collection of nodes to
                checkpoint; see `nodes_to_checkpoint()`.
        """
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.download_group = download_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_conditions = stop_conditions or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        """
        Return the nodes to checkpoint: the explicitly configured ones when
        set (and non-empty), otherwise the nodes used by `init_group`.
        """
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        return self.init_group.used_nodes()

    def compile(self, session_class):
        """
        Compile all four task groups with `session_class.compile`, replacing
        each group attribute with its compiled form.
        """
        # Resolve the nodes first: the fallback reads `init_group.used_nodes()`
        # and must run before `init_group` is replaced by its compiled form.
        self._nodes_to_checkpoint = self.nodes_to_checkpoint()
        self.init_group = session_class.compile(self.init_group)
        self.epoch_group = session_class.compile(self.epoch_group)
        self.download_group = session_class.compile(self.download_group)
        self.exit_group = session_class.compile(self.exit_group)

    def __enter__(self):
        """Make this Job current and enter its `epoch_group`."""
        super().__enter__()
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        """Leave the `epoch_group`, then this Job's own context."""
        try:
            self.epoch_group.__exit__()
        finally:
            # Always unwind the Job's own context, even if leaving the epoch
            # group raised; otherwise this Job would stay registered as the
            # current one.
            super().__exit__(*args)

    def add_stop_condition(self, output):
        """
        Register a stop condition for the epoch loop.

        Args:
            output: Either a TaskOutput, or a core.BlobReference. A blob
                reference is wrapped in a new Task added to `epoch_group`
                and that task's first output is registered instead.
        """
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput), (
            'Stop condition must be a TaskOutput or a BlobReference.')
        self.stop_conditions.append(output)
113

114

115
def get_ckpt_filename(node_name, epoch):
    """Returns the checkpoint filename.

    The filename is the node name and the epoch joined by a single dot,
    e.g. ``trainer_0.5``.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint.
    """
    return '{}.{}'.format(node_name, epoch)
126

127

128
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Returns the full db name where checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
            Only used when `path_prefix` is empty or None.
        path_prefix: A string. Optional param used to construct db name or path
            where checkpoint files are stored. When given, it takes precedence
            over `db_prefix`.
    Returns:
        db_name: A string. The absolute path of full_db_name where checkpoint
            files are saved
    """
    ckpt_filename = get_ckpt_filename(node_name, epoch)
    if path_prefix:
        # The prefix is concatenated verbatim: no separator is inserted, so
        # the caller must include any trailing separator it needs.
        return path_prefix + ckpt_filename
    return os.path.join(db_prefix, ckpt_filename)
147

148

149
class CheckpointManager:
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)
        # Output of the task built by `init`; holds the list of blob names.
        self._names_output = None
        self._path_prefix = None
        self._path_type = None
        # Db name and timer output of the most recently built load/save task.
        self._current_db_name = None
        self._current_checkpoint_duration = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Initialize the checkpoint manager. Determines all blobs that need to
        be saved, or loads the blob list from a checkpoint.

        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Should only contain a single node (or be None).
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored. When set, it is used instead of the manager's db_type.

        Returns:
            The Task whose first output is the list of blob names.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                full_db_name = db_name(
                    retrieve_from_epoch,
                    self._node_name,
                    self._db_prefix,
                    path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s", full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True,
                    keep_device=True,
                )
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """
        Return the blob names fetched from the output of the `init` task.
        `init` must have been called first.
        """
        assert self._names_output, 'init must be called first.'
        return self._names_output.fetch().tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that will measure the time span of checkpoint operations,
        once operation is done, time can be read from _current_checkpoint_duration.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as task:
            with ops.task_init():
                timer = ops.TimerBegin([], counter_name=self._node_name)
            add_op()
            with ops.task_exit():
                time_span_blob = ops.TimerGetAndEnd(timer)
            self._current_checkpoint_duration = final_output(time_span_blob)
        return task

    def collect_checkpoint_stats(self, stats):
        """
        Add one checkpoint stats into the stats.

        The entry is keyed by the current db name and holds the first element
        of the fetched duration output. If either is missing, nothing is
        added and a message is logged.

        Args:
            stats: A dict of checkpoint stats that will be reported.
        """
        if self._current_db_name and self._current_checkpoint_duration:
            stats[self._current_db_name] = self._current_checkpoint_duration.fetch()[0]
        else:
            logger.info(
                "Failed to collect checkpoint stats: {}".format(
                    self._current_db_name
                )
            )

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.

        Args:
            epoch: The checkpoint epoch to load from.
            path_prefix: Optional prefix used to construct the db name.
            path_type: Optional db type used instead of the manager's db_type.

        Returns:
            A timed Task named 'checkpoint_load'.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s", self._current_db_name)

        def add_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s', self._current_db_name)

        def add_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)

        return self._timed_task('checkpoint_partial_load', add_op)

    def check_db_exists(self, epoch):
        """
        Build a Task that checks whether the checkpoint db of the given epoch
        exists. The task's output is the existence flag written by DBExists.

        Args:
            epoch: The checkpoint epoch to check.

        Returns:
            A Task with the existence blob as its output.
        """
        full_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Check existence of %s', full_db_name)
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=full_db_name,
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for current node.

        Stats are only reported when a metadata handler is configured.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        self.collect_checkpoint_stats(all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.

        Args:
            epoch: The checkpoint epoch being saved.

        Returns:
            A timed Task named 'checkpoint_save'.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s', self._current_db_name)

        def add_op():
            ops.Save(
                self.blob_list(), [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Does nothing when no metadata handler is configured.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler this is simply `user_epoch`.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Not read by this implementation; the manager's own node name
                is forwarded to the metadata handler instead.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        # Falsy values leave the previously stored prefix/type untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible. Always True when no metadata handler is set.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        return True
430

431

432
class MultiNodeCheckpointManager:
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, db_type, metadata_handler=None):
        # List of (node, CheckpointManager) pairs; None until `init` or
        # `load_blobs_locally` creates it.
        self._node_managers = None
        # Node names recorded by `set_params`.
        self._node_names = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        self._path_prefix = None
        self._path_type = None

    def _create_node_managers(self, nodes):
        """Create one (node, CheckpointManager) pair per node, each manager
        being constructed inside that node's `Node` context."""
        node_managers = []
        for node in nodes:
            with Node(node):
                manager = CheckpointManager(
                    db_prefix=self._db_prefix,
                    node_name=str(node),
                    db_type=self._db_type)
                node_managers.append((node, manager))
        return node_managers

    def _task_group(self, func, *args, **kw):
        """
        Build a GLOBAL-workspace TaskGroup by calling
        `func(manager, *args, **kw)` for every node manager, each call made
        inside that manager's `Node` context.
        """
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def init(
        self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
    ):
        """
        Create the per-node checkpoint managers and build the TaskGroup that
        initializes each of them. On repeated calls the managers are kept and
        an empty TaskGroup is returned.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.

        Returns:
            A TaskGroup with one init task per node.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes, (
                'init was already called with a different list of nodes.')
            return TaskGroup(WorkspaceType.GLOBAL)
        self._node_managers = self._create_node_managers(nodes)
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    # Each single-node manager is told about its own node,
                    # rather than whichever node the creation loop ended on.
                    manager.init(
                        nodes=[node],
                        retrieve_from_epoch=retrieve_from_epoch,
                        path_prefix=path_prefix,
                        path_type=path_type)
            return task_group

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a TaskGroup that loads the checkpoint of the given epoch on
        every node.

        Args:
            epoch: The checkpoint epoch to load from.
            path_prefix: Optional prefix used to construct the db name.
            path_type: Optional db type to use instead of the default one.
        """
        return self._task_group(
            CheckpointManager.load,
            epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            nodes: An array of nodes whose checkpoints are read. Node managers
                are created from it if they do not exist yet; otherwise it
                must match the nodes they were created with.
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.

        Returns:
            True if every node's checkpoint db existed and was loaded; False
            as soon as one db is found missing (later nodes are not loaded).
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes, (
                'node managers were created with a different list of nodes.')
        else:
            self._node_managers = self._create_node_managers(nodes)
        for _, manager in self._node_managers:
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info(
                    'DB %s does not exist!',
                    db_name(epoch, manager._node_name, manager._db_prefix))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def get_ckpt_db_name(self, node_name, epoch):
        """Returns the DB name of the given node and the given epoch.

        The DB name is effectively the checkpoint path of the given node and
        the given epoch.

        Args:
            node_name: A string. The node name of interest.
            epoch: An integer. The epoch of the checkpoint.

        Returns:
            checkpoint_db_name: A string. The checkpoint path of the given
                node and the given epoch, or None if no manager exists for
                `node_name` (including when no managers were created yet).
        """
        for node, manager in self._node_managers or []:
            if str(node) == node_name:
                return db_name(epoch, manager._node_name, manager._db_prefix)
        return None

    def report_checkpoint_stats(self, action_name):
        """
        Report the checkpoint stats for all the nodes, we need to aggregate all
        the node's stats together so that we know which node's checkpoint
        operation dominates.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        # Tolerate being called before any node manager has been created.
        for _, manager in self._node_managers or []:
            manager.collect_checkpoint_stats(all_stats)
        logger.debug("checkpoint stats: {}".format(all_stats))
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a TaskGroup that will execute a Save ops on every node to
        serialize and persist blobs present in the global workspace.
        """
        return self._task_group(CheckpointManager.save, epoch)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Does nothing when no metadata handler is configured.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler this is simply `user_epoch`.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        self._node_names = [str(node) for node in nodes]
        # Falsy values leave the previously stored prefix/type untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=self._node_names,
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible. Always True when no metadata handler is set.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        return True
635

636

637
class UploadTaskGroupBuilder:
    """Interface for objects that build a task group to upload checkpoints."""
    def build(self, epoch, checkpoint_manager):
        """Builds the task group to upload checkpoints.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: This base class only has the interface,
                the implementation will be in the subclasses.
        """
        raise NotImplementedError()
653

654

655
class JobRunner:
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. The
    `train` method is meant to be called once from the master, passing a
    session as an argument. This call will block until the Job execution is
    complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup a time.

        Returns:
            The number of the last epoch that was run.
        """
        # identify the epoch we must resume from
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = (
                self.checkpoint_manager.get_resume_from_epoch_id(
                    self.resume_from_epoch))
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch %s', self.resume_from_epoch)

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            # Save the first checkpoint before training starts, or resume from
            # a previously saved checkpoint.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info(
                    'Loading checkpoints for epoch %s ...',
                    self.resume_from_epoch)
                session.run(
                    self.checkpoint_manager.load(self.resume_from_epoch))
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')

        logger.info("Finished initializing")

        # Start training.
        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d', epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d', epoch)
            # Read the stop signals right after the epoch, before the
            # checkpoint save runs further task groups in the session.
            stop_conditions = [o.fetch() for o in self.job.stop_conditions]

            if self.checkpoint_manager:
                self.save_checkpoints(epoch, session)

            if any(stop_conditions):
                logger.info('Stopping')
                break
            epoch += 1
        logger.info('Finished training')
        # Upload the checkpoints.
        if self.upload_task_group_builder:
            upload_task_group = self.upload_task_group_builder.build(
                epoch, self.checkpoint_manager)
            session.run(upload_task_group)
            logger.info('Finished uploading the checkpoints')

        # Download the parameters to save
        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        # Finally run the exit step to save nets
        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purpose. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all the
        nodes, but only loads the blobs specified in the blob_names to the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Returns:
            The result of the checkpoint manager's `load_blobs_locally`.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch %s ...', epoch)
        # NOTE(review): `load_blobs_locally` is defined on
        # MultiNodeCheckpointManager; confirm that a single-node
        # CheckpointManager is never used on this path.
        result = self.checkpoint_manager.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)
        self.checkpoint_manager.report_checkpoint_stats('checkpoint_partial_load')
        return result

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace. Saving is best-effort: any
        exception raised while saving is logged as a warning and not
        propagated.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        try:
            is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
            if is_accessible:
                logger.info('Saving checkpoints for epoch %s', epoch)
                session.run(self.checkpoint_manager.save(epoch))
                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_save')
            else:
                logger.warning("Checkpoint files cannot be accessed!")
        except Exception as ex:
            # Deliberately broad: a failed checkpoint must not abort the
            # training loop. The traceback is attached so the failure can
            # still be diagnosed from the logs.
            logger.warning(
                "Unable to write checkpoint for epoch %s. Error=%s",
                epoch, ex, exc_info=True)
817

818

819
def epoch_limiter(job, num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.

    A counter is created in the job's `init_group` with an initial count of
    `num_epochs - 1`, and a CountDown on that counter is run in the job's
    `epoch_group`. The output of the countdown task is registered as a stop
    condition of the job.

    Args:
        job: The Job to which the counter tasks and stop condition are added.
        num_epochs: An integer. The number of epochs after which the job
            should stop.
    """
    with job.init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)

    with job.epoch_group:
        epoch_net = core.Net('epoch_countdown')
        finished = epoch_net.CountDown(counter)
        output = Task(step=epoch_net, outputs=finished).outputs()[0]
    job.add_stop_condition(output)
834

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.