# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import contextlib
import re
import warnings

import dnnlib
import numpy as np
import torch


# ----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (
        value.shape,
        value.dtype,
        value.tobytes(),
        shape,
        dtype,
        device,
        memory_format,
    )
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
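
# Usage sketch (illustrative, not part of the original file): repeated calls
# with the same value, shape, dtype, device and memory format return the same
# cached tensor object, so the CPU=>GPU copy happens only once.
#
#     ones_a = constant(1.0, shape=[1, 3, 256, 256], device=torch.device("cuda"))
#     ones_b = constant(1.0, shape=[1, 3, 256, 256], device=torch.device("cuda"))
#     assert ones_a is ones_b  # served from _constant_cache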


# ----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num  # 1.8.0a0
except AttributeError:

    def nan_to_num(
        input, nan=0.0, posinf=None, neginf=None, *, out=None
    ):  # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(
            input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out
        )


# ----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert  # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert  # 1.7.0

# ----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().


class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
        return self


# ----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().


def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
        )
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(torch.as_tensor(size), ref_size),
                    f"Wrong size for dimension {idx}",
                )
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(size, torch.as_tensor(ref_size)),
                    f"Wrong size for dimension {idx}: expected {ref_size}",
                )
        elif size != ref_size:
            raise AssertionError(
                f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
            )
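
# Usage sketch (illustrative, not part of the original file): None leaves a
# dimension unconstrained, so a batch of RGB images of any batch size passes.
#
#     x = torch.zeros(8, 3, 256, 256)
#     assert_shape(x, [None, 3, 256, 256])  # OK
#     assert_shape(x, [None, 1, 256, 256])  # raises AssertionError for dimension 1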


# ----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().


def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    decorator.__name__ = fn.__name__
    return decorator
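
# Usage sketch (illustrative, not part of the original file): the decorated
# call appears as a named range in autograd profiler traces.
#
#     @profiled_function
#     def synthesis_step(z):  # hypothetical function
#         return z * 2
#
#     with torch.autograd.profiler.profile() as prof:
#         synthesis_step(torch.randn(4, 512))
#     # prof.key_averages() now contains an entry named "synthesis_step"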


# ----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.


class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(
        self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
    ):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1
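
# Usage sketch (illustrative, not part of the original file): the sampler never
# raises StopIteration, so the DataLoader can be drawn from forever; rank and
# num_replicas shard the index stream across distributed processes.
#
#     sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
#     loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=16))
#     batch = next(loader)  # can be called indefinitely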


# ----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.


def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())


def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())


def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {
        name: tensor for name, tensor in named_params_and_buffers(src_module)
    }
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(
                tensor.requires_grad
            )
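
# Usage sketch (illustrative, not part of the original file): copying weights
# from a trained generator G into an EMA copy with the same architecture
# (G and G_ema here are hypothetical modules); require_all=True fails if any
# destination tensor has no matching source tensor.
#
#     with torch.no_grad():  # in-place copies on parameters, so keep autograd out
#         copy_params_and_buffers(src_module=G, dst_module=G_ema, require_all=True)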


# ----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.


@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield
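
# Usage sketch (illustrative, not part of the original file): when accumulating
# gradients over several micro-batches, only the last backward pass needs to
# trigger the DDP all-reduce.
#
#     for i, micro_batch in enumerate(micro_batches):
#         with ddp_sync(ddp_module, sync=(i == len(micro_batches) - 1)):
#             loss = ddp_module(micro_batch).mean()
#             loss.backward()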


# ----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.


def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + "." + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
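
# Usage sketch (illustrative, not part of the original file): call on every
# rank after torch.distributed has been initialized; tensors whose full name
# matches ignore_regex (a hypothetical pattern below) are skipped, e.g. state
# that is intentionally different per process.
#
#     check_ddp_consistency(G, ignore_regex=r".*\.w_avg")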


# ----------------------------------------------------------------------------
# Print summary table of module hierarchy.


def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]

    def pre_hook(_mod, _inputs):
        nesting[0] += 1

    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))

    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    # Filter out redundant entries.
    if skip_redundant:
        entries = [
            e
            for e in entries
            if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
        ]

    # Construct table.
    rows = [
        [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]
    ]
    rows += [["---"] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = "<top-level>" if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
        rows += [
            [
                name + (":0" if len(e.outputs) >= 2 else ""),
                str(param_size) if param_size else "-",
                str(buffer_size) if buffer_size else "-",
                (output_shapes + ["-"])[0],
                (output_dtypes + ["-"])[0],
            ]
        ]
        for idx in range(1, len(e.outputs)):
            rows += [
                [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]
            ]
        param_total += param_size
        buffer_total += buffer_size
    rows += [["---"] * len(rows[0])]
    rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(
            "  ".join(
                cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
            )
        )
    print()
    return outputs
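
# Usage sketch (illustrative, not part of the original file): prints one row
# per submodule (down to max_nesting levels) with parameter/buffer counts and
# output shapes, then returns the module's outputs unchanged.
#
#     G = MyGenerator()  # hypothetical torch.nn.Module
#     z = torch.randn(1, 512)
#     img = print_module_summary(G, [z], max_nesting=2)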