# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import contextlib
import re
import warnings

import dnnlib
import numpy as np
import torch

# ----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (
        value.shape,
        value.dtype,
        value.tobytes(),
        shape,
        dtype,
        device,
        memory_format,
    )
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

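# Example usage (a sketch, not part of the original file): repeated calls with
# identical arguments hit the cache and return the very same tensor object, so
# the constant is materialized and uploaded only once.
#
#   ones = constant(1.0, shape=[4])
#   assert constant(1.0, shape=[4]) is ones
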
# ----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num  # 1.8.0a0
except AttributeError:

    def nan_to_num(
        input, nan=0.0, posinf=None, neginf=None, *, out=None
    ):  # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(
            input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out
        )

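# Example (a sketch, not part of the original file): for float32 input, NaNs
# become 0 and infinities are clamped to the finfo range by default.
#
#   x = torch.tensor([float("nan"), float("inf"), -1.0])
#   nan_to_num(x)  # -> tensor([0.0000e+00, 3.4028e+38, -1.0000e+00])
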
# ----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert  # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert  # 1.7.0

# ----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().


class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
        return self

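# Example (a sketch, not part of the original file): converting a tensor to a
# Python number inside a traced function normally emits a TracerWarning; the
# context manager silences it for the duration of the trace.
#
#   with suppress_tracer_warnings():
#       traced = torch.jit.trace(lambda x: x * float(x.max()), torch.ones(2))
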
# ----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().


def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
        )
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(torch.as_tensor(size), ref_size),
                    f"Wrong size for dimension {idx}",
                )
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(size, torch.as_tensor(ref_size)),
                    f"Wrong size for dimension {idx}: expected {ref_size}",
                )
        elif size != ref_size:
            raise AssertionError(
                f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
            )

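# Example (a sketch, not part of the original file): None leaves a dimension
# unconstrained, so any batch size passes.
#
#   x = torch.zeros(8, 3, 256, 256)
#   assert_shape(x, [None, 3, 256, 256])  # OK
#   assert_shape(x, [None, 1, 256, 256])  # raises AssertionError
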
# ----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().


def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    decorator.__name__ = fn.__name__
    return decorator

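# Example (a sketch, not part of the original file): the decorated call shows
# up under its own name in autograd profiler traces.
#
#   @profiled_function
#   def scale(x):
#       return x * 2
#
#   with torch.autograd.profiler.profile() as prof:
#       scale(torch.ones(4))
#   # "scale" now appears as a labeled range in prof.key_averages()
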
# ----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.


class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(
        self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
    ):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1

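# Example (a sketch with a hypothetical dataset): the sampler never raises
# StopIteration, so the loader can be drawn from indefinitely.
#
#   sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
#   batch = next(iter(loader))
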
# ----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.


def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())


def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())


def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {
        name: tensor for name, tensor in named_params_and_buffers(src_module)
    }
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(
                tensor.requires_grad
            )

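# Example (a sketch with hypothetical module names): copy EMA weights into a
# freshly constructed network; require_all=True demands that every destination
# tensor has a matching source tensor.
#
#   copy_params_and_buffers(src_module=G_ema, dst_module=G, require_all=True)
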
# ----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.


@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

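# Example (a sketch, not part of the original file): during gradient
# accumulation, synchronize gradients only on the final microbatch.
#
#   for i, x in enumerate(microbatches):
#       with ddp_sync(ddp_module, sync=(i == len(microbatches) - 1)):
#           ddp_module(x).sum().backward()
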
# ----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.


def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + "." + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname

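# Example (a sketch with a hypothetical ignore pattern): verify that all
# replicas agree on every parameter and buffer, skipping names that match
# the regex.
#
#   check_ddp_consistency(model, ignore_regex=r".*\.w_avg")
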
# ----------------------------------------------------------------------------
# Print summary table of module hierarchy.


def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]

    def pre_hook(_mod, _inputs):
        nesting[0] += 1

    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))

    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    # Filter out redundant entries.
    if skip_redundant:
        entries = [
            e
            for e in entries
            if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
        ]

    # Construct table.
    rows = [
        [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]
    ]
    rows += [["---"] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = "<top-level>" if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]  # one shape string per output tensor
        output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
        rows += [
            [
                name + (":0" if len(e.outputs) >= 2 else ""),
                str(param_size) if param_size else "-",
                str(buffer_size) if buffer_size else "-",
                (output_shapes + ["-"])[0],
                (output_dtypes + ["-"])[0],
            ]
        ]
        for idx in range(1, len(e.outputs)):
            rows += [
                [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]
            ]
        param_total += param_size
        buffer_total += buffer_size
    rows += [["---"] * len(rows[0])]
    rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(
            "  ".join(
                cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
            )
        )
    print()
    return outputs

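# Example (a sketch with hypothetical names): summarize a generator and reuse
# its outputs, since the inputs are forwarded through the module exactly once.
#
#   img = print_module_summary(G, [z, c])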