# @package layer_model_helper
# Module caffe2.python.layer_model_helper

from caffe2.python import core, model_helper, schema, scope, utils, muji
from caffe2.python.modeling.parameter_info import (
    ParameterInfo,
)
from caffe2.python.modeling.parameter_sharing import (
    parameter_sharing_context,
)
from caffe2.python.modeling.net_modifier import NetModifier

from caffe2.python.optimizer import get_param_device, Optimizer
from caffe2.python.regularizer import Regularizer, RegularizationBy
from caffe2.python.layers import layers

import logging
import numpy as np
import copy
logger = logging.getLogger(__name__)


class LayerModelHelper(model_helper.ModelHelper):
    """
    Model helper for building models on top of layer abstractions.

    Each layer is an abstraction at a higher level than an Operator. A layer
    is responsible for the ownership of its own parameters and can easily be
    instantiated in multiple nets, possibly with different sets of ops.
    As an example: one can easily instantiate predict and train nets from
    the same set of layers, where the predict net will have a subset of the
    operators from the train net.
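
    A rough usage sketch (hedged: the feature name, dimensions, and the FC
    layer call below are illustrative assumptions, not requirements of this
    class; layer calls are resolved dynamically via __getattr__):

        input_record = schema.Struct(
            ('float_features', schema.Scalar((np.float32, (8, )))),
        )
        model = LayerModelHelper(
            'sample_model',
            input_feature_schema=input_record,
            trainer_extra_schema=schema.Struct(),
        )
        fc_out = model.FC(model.input_feature_schema.float_features, 4)
        model.output_schema = fc_out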
    """

    def __init__(self, name, input_feature_schema, trainer_extra_schema,
                 keep_blobs=False,
                 use_attribution=True):
        ''' TODO(amalevich): more documentation on input args

        use_attribution:
            if True, will generate the attribution net for feature importance
            calculation; needs to be set to False when FC is quantized as FP16.
            This attribute access will be consistent with the MTML model.
        '''

        super().__init__(name=name)
        self._layer_names = set()
        self._layers = []
        self._param_to_shape = {}

        # seed default
        self._seed = None
        self._sequence_seed = True

        # optimizer bookkeeping
        self.param_to_optim = {}
        self.param_to_reg = {}

        self._default_optimizer = None
        self._loss = None
        self._prediction = []
        self._output_schema = None

        self._post_grad_net_modifiers = []
        self._final_net_modifiers = []

        # breakdown map; breakdown features are categorical (like dense) but not
        # necessarily used to represent data for training
        self._breakdown_map = None

        # Connect Schema to self.net. That particular instance of schema will be
        # used for generation of the Layers across the network and for the
        # connection with Readers.
        self._input_feature_schema = schema.NewRecord(
            self.net,
            input_feature_schema
        ) if not keep_blobs else input_feature_schema.clone()
        self._trainer_extra_schema = schema.NewRecord(
            self.net,
            trainer_extra_schema
        ) if not keep_blobs else trainer_extra_schema.clone()
        self._metrics_schema = schema.Struct()

        self._preproc_output_schema = None

        self._init_global_constants()
        self.param_init_net = self.create_init_net('param_init_net')
        self._initialize_params = True

        self._transfer_learning_blob_name_mappings = None

        # additional (hard-coded) diagnose_options to report based on the model
        # TODO(xlwang): it's a hack!
        self.ad_hoc_diagnose_blobs_and_operations = []
        self.ad_hoc_plot_blobs = []
        self.use_attribution = use_attribution

    def clear_output_schema(self):
        self._output_schema = None

    def set_initialize_params(self, initialize_params):
        self._initialize_params = initialize_params

    def add_metric_field(self, name, value):
        assert name not in self._metrics_schema.fields, (
            "Trying to add metric field twice: {}".format(name))
        self._metrics_schema = self._metrics_schema + schema.Struct(
            (name, value)
        )

    # an empty white_set will skip everything
    def filter_metrics_schema(self, white_set):
        logger.info("Filter metric schema with white_set {}".format(white_set))
        field_names = self._metrics_schema.field_names()
        for name in field_names:
            if name not in white_set:
                self._metrics_schema = self._metrics_schema - schema.Struct((name, schema.Scalar()))
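
    # A hedged example of the whitelist semantics above (metric names are
    # illustrative assumptions):
    #
    #     model.add_metric_field('auc', schema.Scalar((np.float32, (1, ))))
    #     model.add_metric_field('calibration', schema.Scalar((np.float32, (1, ))))
    #     model.filter_metrics_schema({'auc'})   # keeps only 'auc'
    #     model.filter_metrics_schema(set())     # an empty set removes every field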

    def add_ad_hoc_plot_blob(self, blob, dtype=None):
        assert isinstance(
            blob, (str, core.BlobReference)
        ), "expect type str or BlobReference, but got {}".format(type(blob))
        dtype = dtype or (np.float64, (1, ))
        self.add_metric_field(str(blob), schema.Scalar(dtype, blob))
        self.ad_hoc_plot_blobs.append(blob)

    @staticmethod
    def _get_global_constant_initializer_op(
        blob_name, array=None, dtype=None, initializer=None
    ):
        # to add a global constant to the model, one first needs to get the
        # initializer
        if array is not None:
            assert initializer is None,\
                "Only one of array and initializer should be specified"
            if dtype is None:
                array = np.array(array)
            else:
                array = np.array(array, dtype=dtype)

            # TODO: make GivenTensor generic
            op_name = None
            if array.dtype == np.int32:
                op_name = 'GivenTensorIntFill'
            elif array.dtype == np.int64:
                op_name = 'GivenTensorInt64Fill'
            elif array.dtype == str:
                op_name = 'GivenTensorStringFill'
            elif array.dtype == bool:
                op_name = 'GivenTensorBoolFill'
            else:
                op_name = 'GivenTensorFill'

            def initializer(blob_name):
                return core.CreateOperator(
                    op_name, [],
                    blob_name,
                    shape=array.shape,
                    values=array.flatten().tolist()
                )
        else:
            assert initializer is not None
        initializer_op = initializer(blob_name)
        return initializer_op

    def add_global_constant(
        self, name, array=None, dtype=None, initializer=None
    ):
        assert isinstance(name, str), (
            'name should be a string as we are using it as a map key')
        # This is the global namescope for constants. They will be created in
        # all init_nets and there should be very few of them.
        assert name not in self.global_constants, \
            "%s already added in global_constants" % name
        blob_name = self.net.NextBlob(name)
        self.global_constants[name] = blob_name
        initializer_op = LayerModelHelper._get_global_constant_initializer_op(
            blob_name, array, dtype, initializer
        )
        assert blob_name not in self.global_constant_initializers, \
            "there is already an initializer op associated with blob %s" % \
            blob_name
        self.global_constant_initializers[blob_name] = initializer_op
        return blob_name

    def maybe_add_global_constant(self, name, *args, **kwargs):
        # Adds a global constant ad hoc, without duplication: if the name was
        # already registered in global_constants, it will not be added again,
        # even if the intended value is different from its original value.

        if name in self.global_constants:
            blob_name = self.global_constants[name]
            initializer_op = \
                LayerModelHelper._get_global_constant_initializer_op(
                    blob_name, *args, **kwargs
                )
            # check if the original initializer is the same as the one intended
            # now
            assert utils.OpAlmostEqual(
                initializer_op,
                self.global_constant_initializers[blob_name],
                'debug_info'
            ), \
                "conflicting initializers for global constant %s, " \
                "previous %s, now %s" % (
                    blob_name,
                    str(self.global_constant_initializers[blob_name]),
                    str(initializer_op))
            return blob_name
        return self.add_global_constant(name, *args, **kwargs)
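
    # A hedged usage sketch for the constants API above (the constant name
    # 'MINUS_ONE' is an illustrative assumption):
    #
    #     minus_one = model.add_global_constant('MINUS_ONE', -1.0)
    #     # Re-registering the same name with the same value does not add a
    #     # second initializer; a different value would trip the assert:
    #     minus_one = model.maybe_add_global_constant('MINUS_ONE', -1.0)
    #     # Both calls return the same blob for 'MINUS_ONE'.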

    def _init_global_constants(self):
        self.global_constants = {}
        self.global_constant_initializers = {}
        self.add_global_constant('ONE', 1.0)
        self.add_global_constant('NAN', float("NaN"))
        self.add_global_constant('ZERO', 0.0)
        self.add_global_constant('ZERO_RANGE', [0, 0], dtype='int32')

    def _add_global_constants(self, init_net):
        for initializer_op in self.global_constant_initializers.values():
            init_net._net.op.extend([initializer_op])

    def create_init_net(self, name):
        init_net = core.Net(name)
        self._add_global_constants(init_net)
        return init_net

    def _validate_param_shape(self, param_name, shape):
        if param_name not in self._param_to_shape:
            return

        ref_shape = self._param_to_shape[param_name]

        if shape != ref_shape:
            raise ValueError(
                "Got inconsistent shapes between shared parameters "
                "when trying to map a blob in scope {0} to {1}. ref_shape : "
                " {2}, shape : {3}".format(
                    scope.CurrentNameScope(), param_name, ref_shape, shape)
            )

    def _validate_param_optim(self, param_name, optim):
        # there are three possible values for optim:
        # 1) None (which will use self._default_optimizer after this layer is instantiated)
        # 2) self.NoOptim
        # 3) an instance of the Optimizer class, such as AdagradOptimizer

        # this implies this parameter is not shared with any other parameter so far
        if param_name not in self.param_to_optim:
            return

        logger.info("{} shares the same parameter with another parameter. "
                    "Validating if the same optimizer has been specified for them.".format(
                        param_name,
                    ))

        ref_optim = self.param_to_optim[param_name]

        if optim is None:
            assert ref_optim == self._default_optimizer, (
                "Optim for {} is None, which will fall back to the default_optimizer. "
                "However, the optimizer that has been specified for this shared parameter "
                "is {} which is different from default_optimizer {}. "
                "Please check the optimizers specified for parameters shared "
                "with {} and the default_optimizer to ensure consistency.".format(
                    param_name, ref_optim, self._default_optimizer, param_name
                )
            )
        elif optim == self.NoOptim:
            assert ref_optim == self.NoOptim, (
                "Optim for {} is NoOptim. However, the optimizer for the parameters "
                "shared with {} is {} which is different from NoOptim. "
                "Please check the optimizer specified for other parameters in the "
                "shared group to ensure consistency.".format(
                    param_name, param_name, ref_optim
                )
            )
        elif isinstance(optim, Optimizer):
            assert isinstance(ref_optim, Optimizer), (
                "Optim for {} is an instance of Optimizer. However, the optimizer "
                "for the parameters shared with {} is {} which is not an instance "
                "of Optimizer. Please check the optimizer specified for other "
                "parameters in the shared group to ensure consistency.".format(
                    param_name, param_name, ref_optim
                )
            )

            assert type(optim) is type(ref_optim) and optim.attributes == ref_optim.attributes, (
                "Optim for {} is an instance of Optimizer. However, the optimizer "
                "for the parameters shared with {} is {}. "
                "This optimizer either doesn't have the same type as the current optimizer: "
                "{} vs {}, or its attributes such as the learning rate are different from "
                "those of the current optimizer: {} vs {}. "
                "Please check the optimizer specified for other parameters in the "
                "shared group to ensure consistency.".format(
                    param_name, param_name, ref_optim, type(optim), type(ref_optim), optim.attributes, ref_optim.attributes
                )
            )
        else:
            raise ValueError("optim should be either None, NoOptim, or an instance of Optimizer; got {}".format(optim))

    def create_param(self, param_name, shape, initializer, optimizer=None,
                     ps_param=None, regularizer=None):
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, str):
            # The parameter name will be equal to the current NameScope,
            # resolved with respect to the parameter sharing of the scopes.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            raise ValueError("Unsupported type for param_name")

        param_blob = core.BlobReference(param_name)

        if len(initializer) == 1:
            init_op_args = {}
        else:
            assert len(initializer) == 2
            init_op_args = copy.deepcopy(initializer[1])
        if shape is not None:
            assert 'shape' not in init_op_args
            init_op_args.update({'shape': shape})

        initializer_op = None
        if self._initialize_params:
            initializer_op = core.CreateOperator(
                initializer[0],
                [],
                param_blob,
                **init_op_args
            )

        param = layers.LayerParameter(
            parameter=param_blob,
            initializer=initializer_op,
            optimizer=optimizer,
            ps_param=ps_param,
            regularizer=regularizer
        )

        self._validate_param_shape(param_name, shape)

        self._validate_param_optim(param_name, optimizer)

        self._param_to_shape[param_name] = shape

        return param
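
    # A hedged sketch of how a layer implementation typically calls
    # create_param (the blob name, shape, and initializer values are
    # illustrative assumptions):
    #
    #     w = model.create_param(
    #         param_name='fc_w',
    #         shape=[output_dims, input_dims],
    #         # initializer is a tuple: (init_op_name,) or (init_op_name, kwargs);
    #         # the 'shape' above is merged into those kwargs.
    #         initializer=('UniformFill', {'min': -0.1, 'max': 0.1}),
    #         optimizer=None,  # None falls back to default_optimizer later
    #     )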

    def next_layer_name(self, prefix):
        base_name = core.ScopedName(prefix)
        name = base_name
        index = 0
        while name in self._layer_names:
            name = base_name + '_auto_' + str(index)
            index += 1

        self._layer_names.add(name)
        return name

    def add_layer(self, layer):
        self._layers.append(layer)
        for param in layer.get_parameters():
            assert isinstance(param.parameter, core.BlobReference)

            self.param_to_optim[str(param.parameter)] = \
                param.optimizer or self.default_optimizer

            self.params.append(param.parameter)
            if isinstance(param, layers.LayerParameter):
                logger.info("Add parameter regularizer {0}".format(param.parameter))
                self.param_to_reg[param.parameter] = param.regularizer
            elif isinstance(param, ParameterInfo):
                # TODO:
                # Currently, LSTM and RNN cells, which use ModelHelper instead
                # of LayerModelHelper as the super class, are called in
                # pooling_methods. In ModelHelper, regularization is not
                # supported in create_param. We will unify the create_param of
                # ModelHelper and LayerModelHelper in the future.
                logger.info('regularization is unsupported for ParameterInfo object')
            else:
                raise ValueError(
                    'unknown object type besides ParameterInfo and LayerParameter: {}'
                    .format(param)
                )

        # The primary value of adding everything to self.net is generation of
        # the operators right away, i.e. if an error happens it'll be detected
        # immediately. Other than this, create_x_net should be called.
        layer.add_operators(self.net, self.param_init_net)
        return layer.output_schema

    def get_parameter_blobs(self):
        param_blobs = []
        for layer in self._layers:
            for param in layer.get_parameters():
                param_blobs.append(param.parameter)

        return param_blobs

    def add_post_grad_net_modifiers(self, modifier):
        assert modifier not in self._post_grad_net_modifiers,\
            "{0} is already in {1}".format(modifier, self._post_grad_net_modifiers)
        assert isinstance(modifier, NetModifier),\
            "{} has to be a NetModifier instance".format(modifier)
        self._post_grad_net_modifiers.append(modifier)

    def add_final_net_modifiers(self, modifier):
        assert modifier not in self._final_net_modifiers,\
            "{0} is already in {1}".format(modifier, self._final_net_modifiers)
        assert isinstance(modifier, NetModifier),\
            "{} has to be a NetModifier instance".format(modifier)
        self._final_net_modifiers.append(modifier)

    @property
    def seed(self):
        return self._seed

    @property
    def sequence_seed(self):
        return self._sequence_seed

    def store_seed(self, seed, sequence_seed=True):
        # Store the seed config that will be applied to each op in the net.
        self._seed = seed
        # If sequence_seed is True, the i-th op has rand_seed=`seed + i`
        self._sequence_seed = sequence_seed

    def apply_seed(self, net):
        if self._seed:
            net.set_rand_seed(self._seed, self._sequence_seed)
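
    # A brief hedged sketch of the seed plumbing above (the seed value is an
    # illustrative assumption):
    #
    #     model.store_seed(1701, sequence_seed=True)
    #     model.apply_seed(train_net)  # op i gets rand_seed = 1701 + i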

    @property
    def default_optimizer(self):
        return self._default_optimizer

    @default_optimizer.setter
    def default_optimizer(self, optimizer):
        self._default_optimizer = optimizer

    @property
    def input_feature_schema(self):
        return self._input_feature_schema

    @property
    def trainer_extra_schema(self):
        return self._trainer_extra_schema

    @property
    def metrics_schema(self):
        """
        Returns the schema that represents model output that should be used for
        metric reporting.

        During the training/evaluation this schema will be appended to the
        schema that represents model output.
        """
        return self._metrics_schema

    @property
    def output_schema(self):
        assert self._output_schema is not None
        return self._output_schema

    @output_schema.setter
    def output_schema(self, schema):
        assert self._output_schema is None
        self._output_schema = schema

    @property
    def preproc_output_schema(self):
        assert self._preproc_output_schema is not None
        return self._preproc_output_schema

    @preproc_output_schema.setter
    def preproc_output_schema(self, schema):
        assert self._preproc_output_schema is None
        self._preproc_output_schema = schema

    @property
    def prediction(self):
        assert self._prediction, "model prediction is empty"
        return self._prediction

    def add_prediction(self, prediction, weight=1.0):
        assert prediction is not None, "Added prediction should not be None"
        self._prediction.append((prediction, weight))

    @property
    def transfer_learning_blob_name_mappings(self):
        return self._transfer_learning_blob_name_mappings

    @transfer_learning_blob_name_mappings.setter
    def transfer_learning_blob_name_mappings(self, blob_name_mappings):
        assert blob_name_mappings is not None, "Transfer learning blob name mappings should not be None"
        self._transfer_learning_blob_name_mappings = blob_name_mappings

    @property
    def loss(self):
        assert self._loss is not None
        return self._loss

    @loss.setter
    def loss(self, loss):
        assert self._loss is None
        self._loss = loss

    def has_loss(self):
        return self._loss is not None

    def add_loss(self, loss, name='unnamed'):
        assert loss is not None, "Added loss should not be None"
        assert isinstance(loss, schema.Scalar) or isinstance(
            loss, schema.Struct
        ), "Added loss should be a scalar or a struct"
        if self._loss is None:
            self._loss = schema.Struct((name, loss))
        else:
            # loss could've been set through model.loss directly which could be
            # a scalar
            if isinstance(self._loss, schema.Scalar):
                self._loss = schema.Struct(('unnamed', self._loss))

            prefix_base = name + '_auto_'
            index = 0
            prefix = name
            while prefix in self._loss:
                prefix = prefix_base + str(index)
                index += 1
            loss_struct = schema.Struct((prefix, loss))
            self._loss = self._loss + loss_struct
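
    # A hedged sketch of how add_loss aggregates losses (the loss names are
    # illustrative assumptions):
    #
    #     model.add_loss(bce_loss, name='label_loss')
    #     model.add_loss(reg_loss, name='label_loss')
    #     # model.loss is now a Struct with fields
    #     # ('label_loss', 'label_loss_auto_0'): duplicate names receive an
    #     # '_auto_<index>' suffix instead of overwriting each other.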

    def add_output_schema(self, name, value):
        assert value is not None, \
            'Added output schema {} should not be None'.format(name)
        assert isinstance(value, schema.Scalar) or \
            isinstance(value, schema.Struct), \
            'Added output schema {} should be a scalar or a struct.\n\
            Now it is {}.'.format(name, type(value))
        if self._output_schema is None:  # be the first field
            self._output_schema = schema.Struct((name, value))
        else:  # merge with other fields
            assert name not in self._output_schema.fields, \
                'Output Schema Field {} already exists'.format(name)
            self._output_schema = \
                self._output_schema + schema.Struct((name, value))

    def add_trainer_extra_schema(self, trainer_extra_schema):
        trainer_extra_record = schema.NewRecord(self.net, trainer_extra_schema)
        self._trainer_extra_schema += trainer_extra_record

    def __getattr__(self, layer):
        def is_functional_layer(layer):
            if core.IsOperator(layer):
                return True
            elif layer.startswith('FunctionalLayer'):
                return True
            else:
                return False

        def resolve_functional_layer(layer):
            if core.IsOperator(layer):
                return layer
            elif layer.startswith('FunctionalLayer'):
                return layer[len('FunctionalLayer'):]
            else:
                raise ValueError(
                    '%s cannot be resolved as a functional layer' % layer
                )

        if layer.startswith('__'):
            raise AttributeError(layer)

        # TODO(amalevich): Add support for ifbpy inline documentation
        if layers.layer_exists(layer):
            def wrapper(*args, **kwargs):
                new_layer = layers.create_layer(layer, self, *args, **kwargs)
                if kwargs.get("output_to_metrics", False):
                    new_layer.export_output_for_metrics()
                if kwargs.get("params_to_metrics", False):
                    new_layer.export_params_for_metrics()
                return self.add_layer(new_layer)
            return wrapper
        elif is_functional_layer(layer):
            # TODO(xlwang): A designated layer shadows the usage of an op as a
            # single layer. To enforce using an op (e.g. Split) as a functional
            # layer, one can call 'model.FunctionalLayerSplit'
            layer = resolve_functional_layer(layer)

            def wrapper(*args, **kwargs):
                def apply_operator(net, in_record, out_record, **kwargs):
                    # TODO(amalevich): Switch to net.operator as soon as it gets
                    # landed
                    net.__getattr__(layer)(in_record.field_blobs(),
                                           out_record.field_blobs(),
                                           **kwargs)

                if 'name' not in kwargs:
                    kwargs['name'] = layer

                new_layer = layers.create_layer(
                    'Functional',
                    self, *args, function=apply_operator,
                    **kwargs
                )

                if kwargs.get("output_to_metrics", False):
                    new_layer.export_output_for_metrics()
                if kwargs.get("params_to_metrics", False):
                    new_layer.export_params_for_metrics()

                return self.add_layer(new_layer)
            return wrapper
        else:
            # this needs to be an AttributeError to fit hasattr semantics
            raise AttributeError(
                "Trying to create a non-registered layer: {}".format(layer))

    @property
    def layers(self):
        return self._layers
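
    # A hedged sketch of the dynamic layer dispatch implemented in __getattr__
    # above (the layer and op names are illustrative assumptions):
    #
    #     # For a registered layer (e.g. FC), the call creates the layer, adds
    #     # it to the model, and returns its output schema:
    #     fc_out = model.FC(model.input_feature_schema.float_features, 16)
    #
    #     # An operator name resolves to a 'Functional' layer wrapping that op;
    #     # prefixing with 'FunctionalLayer' forces this path explicitly:
    #     relu_out = model.FunctionalLayerRelu(fc_out, 1)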

    def apply_regularizers_on_loss(
        self,
        train_net,
        train_init_net,
        blob_to_device=None,
    ):
        logger.info("apply regularizer on loss")
        for param, regularizer in self.param_to_reg.items():
            if regularizer is None:
                continue
            logger.info("add regularizer {0} for param {1} to loss".format(regularizer, param))
            assert isinstance(regularizer, Regularizer)
            added_loss_blob = regularizer(train_net, train_init_net, param, grad=None,
                                          by=RegularizationBy.ON_LOSS)
            logger.info(added_loss_blob)
            if added_loss_blob is not None:
                self.add_loss(
                    schema.Scalar(blob=added_loss_blob),
                    str(added_loss_blob)
                )

    def apply_regularizers_after_optimizer(
        self,
        train_net,
        train_init_net,
        grad_map,
        blob_to_device=None,
    ):
        logger.info("apply regularizer after optimizer")
        CPU = muji.OnCPU()
        # if given, blob_to_device is a map from blob to device_option
        blob_to_device = blob_to_device or {}
        for param, regularizer in self.param_to_reg.items():
            if regularizer is None:
                continue
            assert isinstance(regularizer, Regularizer)
            logger.info("add regularizer {0} for param {1} to optimizer".format(regularizer, param))
            device = get_param_device(
                param,
                grad_map.get(str(param)),
                param_to_device=blob_to_device,
                default_device=CPU,
            )
            with core.DeviceScope(device):
                regularizer(
                    train_net, train_init_net, param, grad=grad_map.get(str(param)),
                    by=RegularizationBy.AFTER_OPTIMIZER
                )

    def apply_post_grad_net_modifiers(
        self,
        trainer_net,
        trainer_init_net,
        grad_map,
        blob_to_device=None,
        modify_output_record=False,
    ):
        param_grad_map = {param: grad_map[param]
                          for param in self.param_to_optim.keys() if param in grad_map}

        for modifier in self._post_grad_net_modifiers:
            modifier(trainer_net, trainer_init_net, param_grad_map,
                     blob_to_device=blob_to_device,
                     modify_output_record=modify_output_record)

    def apply_final_net_modifiers(
        self,
        trainer_net,
        trainer_init_net,
        grad_map,
        blob_to_device=None,
        modify_output_record=False,
    ):
        for modifier in self._final_net_modifiers:
            modifier(trainer_net, trainer_init_net, grad_map,
                     blob_to_device=blob_to_device,
                     modify_output_record=modify_output_record)

    def apply_optimizers(
        self,
        train_net,
        train_init_net,
        grad_map,
        blob_to_device=None,
    ):
        CPU = muji.OnCPU()
        # if given, blob_to_device is a map from blob to device_option
        blob_to_device = blob_to_device or {}
        for param, optimizer in self.param_to_optim.items():
            assert optimizer is not None, \
                "default optimizer must have been set in add_layer"
            # note that not all params have gradients, and thus we send None if
            # the gradient does not exist
            device = get_param_device(
                param,
                grad_map.get(str(param)),
                param_to_device=blob_to_device,
                default_device=CPU,
            )
            if device is not None:
                # extra info is not applicable for optimizers
                del device.extra_info[:]

            with core.DeviceScope(device):
                optimizer(
                    train_net, train_init_net, param, grad_map.get(str(param)))

    def _GetOne(self):
        return self.global_constants['ONE']

    # An optimizer which allows us to do NO optimization
    def NoOptim(self, *args, **kwargs):
        pass

    @property
    def breakdown_map(self):
        return self._breakdown_map

    @breakdown_map.setter
    def breakdown_map(self, breakdown_map):
        # TODO(xlwang): provide richer feature information in breakdown_map
        # and change the assertion accordingly
        assert isinstance(breakdown_map, dict)
        assert all(isinstance(k, str) for k in breakdown_map)
        assert sorted(breakdown_map.values()) == list(range(len(breakdown_map)))
        self._breakdown_map = breakdown_map