# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Helper wrapper for a TensorFlow optimizer."""

import platform
import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import List, Union

from . import autosummary
from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

_collective_ops_warning_printed = False
_collective_ops_group_key       = 831766147
_collective_ops_instance_key    = 436340067

class Optimizer:
    """A wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.
    """

    def __init__(self,
        name:                   str             = "Train",                  # Name string that will appear in TensorFlow graph.
        tf_optimizer:           str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate:          TfExpressionEx  = 0.001,                    # Learning rate. Can vary over time.
        minibatch_multiplier:   TfExpressionEx  = None,                     # Treat N consecutive minibatches as one by accumulating gradients.
        share:                  "Optimizer"     = None,                     # Share internal state with a previously created optimizer?
        use_loss_scaling:       bool            = False,                    # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init:      float           = 64.0,                     # Log2 of initial loss scaling factor.
        loss_scaling_inc:       float           = 0.0005,                   # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec:       float           = 1.0,                      # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage:       bool            = False,                    # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs):

        # Public fields.
        self.name                   = name
        self.learning_rate          = learning_rate
        self.minibatch_multiplier   = minibatch_multiplier
        self.id                     = self.name.replace("/", ".")
        self.scope                  = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class        = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs       = dict(kwargs)
        self.use_loss_scaling       = use_loss_scaling
        self.loss_scaling_init      = loss_scaling_init
        self.loss_scaling_inc       = loss_scaling_inc
        self.loss_scaling_dec       = loss_scaling_dec

        # Private fields.
        self._updates_applied       = False
        self._devices               = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers     = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes       = None          # [shape, ...]
        self._report_mem_usage      = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device."""
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name             = device_name
        device.optimizer        = None          # Underlying optimizer:     optimizer_class
        device.loss_scaling_var = None          # Log2 of loss scaling:     tf.Variable
        device.grad_raw         = OrderedDict() # Raw gradients:            var => [grad, ...]
        device.grad_clean       = OrderedDict() # Clean gradients:          var => grad
        device.grad_acc_vars    = OrderedDict() # Accumulation sums:        var => tf.Variable
        device.grad_acc_count   = None          # Accumulation counter:     tf.Variable
        device.grad_acc         = OrderedDict() # Accumulated gradients:    var => grad

        # Setup TensorFlow objects.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                pass

        # Compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

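    # Note (added commentary, not from the original source): every call to
    # register_gradients() appends one gradient per variable to device.grad_raw.
    # apply_updates() below sums those entries and divides by the number of registered
    # gradients per variable and by the number of devices, so repeated registrations
    # on the same device are averaged rather than double-counted.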
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]              # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)       # Multiple gradients => sum.

                    # Scale as needed.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                if platform.system() == "Windows":    # Windows => NCCL ops are not available.
                    self._broadcast_fallback()
                elif tf.VERSION.startswith("1.15."):  # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                    self._broadcast_fallback()
                else:                                 # Otherwise => NCCL ops are safe to use.
                    self._broadcast_nccl()

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

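    # Note on apply_updates() above (added commentary, not from the original source):
    # with minibatch_multiplier = N, the returned TrainingOp is still run once per
    # minibatch; the first N-1 runs only add gradients into grad_acc_vars, and the N-th
    # run applies the accumulated result and resets the accumulators, which is what the
    # acc_ok / tf.cond logic implements.
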
    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer."""
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor."""
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type

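    # Note on the two loss-scaling helpers above (added commentary, not from the original
    # source): apply_loss_scaling() multiplies the loss by 2**loss_scaling_var before
    # gradients are computed, and undo_loss_scaling() folds 2**(-loss_scaling_var) into
    # the per-variable scale in apply_updates(), so the applied update is numerically
    # unchanged while intermediate FP16 gradients are kept away from underflow.
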
    def _broadcast_nccl(self):
        """Sum gradients across devices using NCCL ops (fast path)."""
        from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
        for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
            if any(x.shape.num_elements() > 0 for x in all_vars):
                all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                all_grads = nccl_ops.all_sum(all_grads)
                for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                    device.grad_clean[var] = grad

    def _broadcast_fallback(self):
        """Sum gradients across devices using TensorFlow collective ops (slow fallback path)."""
        from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
        global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
        if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
            return
        if not _collective_ops_warning_printed:
            print("------------------------------------------------------------------------")
            print("WARNING: Using slow fallback implementation for inter-GPU communication.")
            print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
            print("------------------------------------------------------------------------")
            _collective_ops_warning_printed = True
        for device in self._devices.values():
            with tf.device(device.name):
                combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
                combo = tf.concat(combo, axis=0)
                combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
                    group_size=len(self._devices), group_key=_collective_ops_group_key,
                    instance_key=_collective_ops_instance_key)
                cur_ofs = 0
                for var, grad_old in device.grad_clean.items():
                    grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
                    cur_ofs += grad_old.shape.num_elements()
                    device.grad_clean[var] = grad_new
        _collective_ops_instance_key += 1


class SimpleAdam:
    """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""

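    # Illustrative selection sketch (hypothetical; the exact import path passed as
    # tf_optimizer depends on how this module is packaged, and the hyperparameters
    # below are example values, not defaults from this repository):
    #
    #   opt = Optimizer(name="TrainD",
    #                   tf_optimizer="dnnlib.tflib.optimizer.SimpleAdam",
    #                   learning_rate=0.002, beta1=0.0, beta2=0.99, epsilon=1e-8)
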
    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.name = name
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.all_state_vars = []

    def variables(self):
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        return list(zip(tf.gradients(loss, var_list), var_list))

    def apply_gradients(self, grads_and_vars):
        with tf.name_scope(self.name):
            state_vars = []
            update_ops = []

            # Adjust learning rate to deal with startup bias.
            with tf.control_dependencies(None):
                b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                state_vars += [b1pow_var, b2pow_var]
            b1pow_new = b1pow_var * self.beta1
            b2pow_new = b2pow_var * self.beta2
            update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
            lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)

            # Construct ops to update each variable.
            for grad, var in grads_and_vars:
                with tf.control_dependencies(None):
                    m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    state_vars += [m_var, v_var]
                m_new = self.beta1 * m_var + (1 - self.beta1) * grad
                v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
                var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
                update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]

            # Group everything together.
            self.all_state_vars += state_vars
            return tf.group(*update_ops)
