StyleFeatureEditor
370 строк · 19.8 Кб
1# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2#
3# This work is made available under the Nvidia Source Code License-NC.
4# To view a copy of this license, visit
5# https://nvlabs.github.io/stylegan2/license.html
6
7"""Helper wrapper for a Tensorflow optimizer."""
8
9import platform
10import numpy as np
11import tensorflow as tf
12
13from collections import OrderedDict
14from typing import List, Union
15
16from . import autosummary
17from . import tfutil
18from .. import util
19
20from .tfutil import TfExpression, TfExpressionEx
21
# Process-wide state for the collective-ops fallback used by
# Optimizer._broadcast_fallback(); it is shared by every Optimizer instance.
_collective_ops_warning_printed = False  # set to True once the slow-path warning has been printed
_collective_ops_group_key = 831766147  # group_key passed to every collective_ops.all_reduce call
_collective_ops_instance_key = 436340067  # instance_key for the next all-reduce round; incremented after each round
25
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Usage protocol:
    1. Construct the wrapper.
    2. Call register_gradients() once per loss/device.
    3. Call apply_updates() exactly once to obtain the training op.

    After apply_updates() has been called, register_gradients() may no longer
    be called; both methods assert this.
    """

    def __init__(self,
        name: str = "Train", # Name string that will appear in TensorFlow graph.
        tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
        minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
        share: "Optimizer" = None, # Share internal state with a previously created optimizer?
        use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
        loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs) -> None:
        """Store the configuration. Per-device ops and variables are created lazily in _get_device().

        `tf_optimizer` is resolved to a class with util.get_obj_by_name(). Any extra keyword
        arguments (**kwargs) are forwarded to that class's constructor, together with `name`
        and `learning_rate`.

        When `share` is given, the per-device underlying optimizer instances are reused from
        it. This requires:
        - the same optimizer class,
        - the *same* learning_rate object (an identity check, not an equality check),
        - equal kwargs.
        """

        # Public fields.
        self.name = name
        self.learning_rate = learning_rate
        self.minibatch_multiplier = minibatch_multiplier
        self.id = self.name.replace("/", ".") # '/'-free identifier used in name scopes and autosummary tags.
        self.scope = tf.get_default_graph().unique_name(self.id) # Graph-unique scope name for this optimizer's ops.
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec

        # Private fields.
        self._updates_applied = False # Set by apply_updates(); guards against further registration.
        self._devices = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes = None # [shape, ...]
        self._report_mem_usage = report_mem_usage # Cleared after the first report (see register_gradients).

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        # The dict object itself is shared, so optimizers created later by either wrapper are seen by both.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device.

        The state is created on the first call for a device and cached in self._devices.
        Creation does the following:
        - It instantiates the underlying optimizer on that device, unless a shared one
          already exists for the device.
        - If loss scaling is enabled, it creates the non-trainable log2 loss-scaling variable,
          initialized to loss_scaling_init.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name = device_name
        device.optimizer = None # Underlying optimizer: optimizer_class
        device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
        device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
        device.grad_clean = OrderedDict() # Clean gradients: var => grad
        device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
        device.grad_acc_count = None # Accumulation counter: tf.Variable
        device.grad_acc = OrderedDict() # Accumulated gradients: var => grad

        # Setup TensorFlow objects.
        # control_dependencies(None) keeps these objects free of whatever control dependencies the caller is under.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        The device is taken from `loss.device`, and every variable in `trainable_vars` must
        live on that same device.
        - `trainable_vars` may be a list, or a dict whose values are used; this allows
          passing Network.trainables directly.
        - Every call must pass variables with the same shapes, in the same order, as the
          first call.
        - The loss is cast to float32 and multiplied by the current loss-scaling factor
          before differentiation.
        - The raw gradients are stored per variable. If several losses are registered on
          one device, apply_updates() sums them and divides by their count.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        # The first call fixes the expected shape list; later calls (other devices) must match it exactly.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        # Done only once: the flag is cleared on first use. The summary op runs after the loss,
        # and gradient computation below is made to depend on it.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                # Best-effort reporting: presumably raised when the memory_stats op is unavailable
                # in this TF build/device (TODO confirm); skip the statistic in that case.
                pass

        # Compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        # None gradients (disconnected vars) are kept here too; apply_updates() filters them
        # but still counts them when averaging.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once per Optimizer.

        Steps:
        1. Per device, drop None gradients, cast to float32, sum, and scale. The scale
           averages over the registered losses and over devices, divides by
           minibatch_multiplier, and undoes loss scaling.
        2. With more than one device, sum the cleaned gradients across devices. NCCL is
           used, except on Windows or TF 1.15, where the collective-ops fallback is used.
        3. Per device, optionally accumulate gradients over minibatch_multiplier steps.
           Apply them only at the end of an accumulation window, and only if every
           accumulated gradient is finite. Then adjust the loss scale.
        4. On the last device, add autosummary statistics.
        5. Run the initializers of the optimizer state, loss-scaling and accumulation
           variables through tfutil.run.

        Args:
            allow_no_op: if True and no gradients were ever registered, return a plain no-op.

        Returns:
            A single op named "TrainingOp" grouping all update and summary ops.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape) # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0] # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad) # Multiple gradients => sum.

                    # Scale as needed.
                    # - len(grad_raw[var]) counts every registered loss for this var, including
                    #   None gradients, so this averages over losses.
                    # - Dividing by the device count turns the cross-device sum below into a mean.
                    # - Dividing by minibatch_multiplier does the same for the accumulation sum.
                    # - Finally, the loss-scaling factor applied in register_gradients() is removed.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        # tf.device(None) clears any enclosing device scope; the helpers handle placement themselves.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                if platform.system() == "Windows": # Windows => NCCL ops are not available.
                    self._broadcast_fallback()
                elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                    self._broadcast_fallback()
                else: # Otherwise => NCCL ops are safe to use.
                    self._broadcast_nccl()

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop
                # The lambdas below capture loop variables, but tf.cond calls them immediately while
                # building the graph, so late binding does not cause a problem here.

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is the end of a window.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # acc_ok marks the last step of a window. The counter is then reset; otherwise it
                    # stores the incremented value.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # grad_acc[var] is the running sum *including* this step's gradient. That sum is
                    # applied at the end of the window, and the stored accumulator is reset. The reset
                    # happens even if the update is later skipped for non-finite values, so that
                    # window's gradients are dropped.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires both the end of a window and all-finite accumulated gradients.
                # Gradients are cast back to each variable's dtype before being applied.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # This happens only at the end of a window: the log2 scale grows by loss_scaling_inc
                # when the update went through, and shrinks by loss_scaling_dec when it was skipped.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                # overflow_frequency is recorded only on window-end steps (condition=acc_ok).
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        # Optimizer slot variables were just created by apply_gradients() above. If the optimizer
        # is shared with another wrapper (see `share`), its state is reset for that wrapper too.
        # init_uninitialized_vars presumably skips already-initialized loss-scaling vars (TODO confirm in tfutil).
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by each device's optimizer.variables().
        Because the underlying optimizers may be shared (see `share` in __init__), this also
        resets their state for any other wrapper using them.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Creating the per-device state via _get_device() is a side effect. Returns None when
        loss scaling is disabled.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        When loss scaling is disabled, returns `value` unchanged. Otherwise, multiplies it by
        tfutil.exp2() of the log2 loss-scaling variable on `value.device`.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Inverse of apply_loss_scaling(): multiplies by tfutil.exp2() of the *negated* log2
        loss-scaling variable on `value.device`. When scaling is disabled, this is a no-op.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type

    def _broadcast_nccl(self) -> None:
        """Sum gradients across devices using NCCL ops (fast path).

        Variables are paired across devices by position in each device's grad_clean. Each
        grad_clean entry is replaced in place with the cross-device sum. Groups whose
        variables are all empty (zero elements) are skipped.
        """
        from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
        for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
            if any(x.shape.num_elements() > 0 for x in all_vars):
                all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                all_grads = nccl_ops.all_sum(all_grads)
                for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                    device.grad_clean[var] = grad

    def _broadcast_fallback(self) -> None:
        """Sum gradients across devices using TensorFlow collective ops (slow fallback path).

        On each device, all gradients are flattened and concatenated into one vector. A single
        all-reduce (Add) is run on that vector, and the result is split back into the original
        shapes. All devices in one call share the same group/instance keys. The instance key is
        incremented once afterwards, so the next call forms a new collective round. A warning is
        printed the first time this path is used in the process.
        """
        from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
        global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
        if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
            return
        if not _collective_ops_warning_printed:
            print("------------------------------------------------------------------------")
            print("WARNING: Using slow fallback implementation for inter-GPU communication.")
            print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
            print("------------------------------------------------------------------------")
            _collective_ops_warning_printed = True
        for device in self._devices.values():
            with tf.device(device.name):
                combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
                combo = tf.concat(combo, axis=0)
                combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
                    group_size=len(self._devices), group_key=_collective_ops_group_key,
                    instance_key=_collective_ops_instance_key)
                # Slice the reduced vector back into per-variable gradients, in the same order as the concat.
                cur_ofs = 0
                for var, grad_old in device.grad_clean.items():
                    grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
                    cur_ofs += grad_old.shape.num_elements()
                    device.grad_clean[var] = grad_new
        _collective_ops_instance_key += 1
322
323
class SimpleAdam:
    """Minimal Adam optimizer meant to stand in for tf.train.AdamOptimizer inside dnnlib.tflib.Optimizer.

    Only the parts of the tf.train.Optimizer interface that the wrapper relies on are provided:
    - construction with `name` / `learning_rate` keywords,
    - compute_gradients(),
    - apply_gradients(),
    - variables().

    Every call to apply_gradients() creates a fresh set of float32 state variables (two
    bias-correction powers plus one first- and one second-moment slot per variable). These are
    recorded so that variables() can expose them for (re)initialization.
    """

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        # Hyperparameters are stored verbatim and read when the update ops are built.
        self.name = name
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # All state variables created by apply_gradients(), in creation order.
        self.all_state_vars = []

    def variables(self):
        """Return the list of state variables created so far (the internal list itself, not a copy)."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Return a list of (gradient, variable) pairs for `loss` with respect to each entry of `var_list`.

        Only GATE_NONE gating is supported; any other value fails the assertion.
        """
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return [(g, v) for g, v in zip(gradients, var_list)]

    def apply_gradients(self, grads_and_vars):
        """Build one Adam step for every (gradient, variable) pair and return the grouped update op.

        beta1^t and beta2^t are tracked in two scalar variables that start at 1 and are multiplied
        by beta1/beta2 on every step. The bias-corrected step size is
            lr_t = learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t),
        and each variable moves by lr_t * m_t / (sqrt(v_t) + epsilon).
        """
        with tf.name_scope(self.name):
            new_state = []
            updates = []

            # Bias-correction accumulators. They are created with control dependencies cleared
            # because the wrapper Optimizer calls this method from inside a tf.cond branch.
            with tf.control_dependencies(None):
                beta1_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                beta2_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                new_state.extend([beta1_power, beta2_power])
            beta1_power_next = beta1_power * self.beta1
            beta2_power_next = beta2_power * self.beta2
            updates.extend([
                tf.assign(beta1_power, beta1_power_next),
                tf.assign(beta2_power, beta2_power_next),
            ])
            step_size = self.learning_rate * tf.sqrt(1 - beta2_power_next) / (1 - beta1_power_next)

            # First (m) and second (v) moment slots for each variable, plus the parameter update itself.
            for grad, var in grads_and_vars:
                with tf.control_dependencies(None):
                    m_slot = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    v_slot = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
                    new_state.extend([m_slot, v_slot])
                m_next = self.beta1 * m_slot + (1 - self.beta1) * grad
                v_next = self.beta2 * v_slot + (1 - self.beta2) * tf.square(grad)
                delta = step_size * m_next / (tf.sqrt(v_next) + self.epsilon)
                updates.extend([tf.assign(m_slot, m_next), tf.assign(v_slot, v_next), tf.assign_sub(var, delta)])

            # Publish the new state so variables() reports it, then bundle all updates into one op.
            self.all_state_vars.extend(new_state)
            return tf.group(*updates)
371