pytorch

Форк
0
/
__init__.py 
599 строк · 24.5 Кб
1
# mypy: allow-untyped-defs
2
"""
3
``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions.
4

5
It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s
6
for which gradients should be computed with the ``requires_grad=True`` keyword.
7
As of now, we only support autograd for floating point :class:`Tensor` types (
8
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
9
"""
10

11
import warnings
12
from typing import cast, List, Optional, Sequence, Tuple, Union
13

14
import torch
15
from torch import _vmap_internals
16
from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
17
from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
18

19
from . import forward_ad, functional, graph
20
from .anomaly_mode import detect_anomaly, set_detect_anomaly
21
from .function import Function, NestedIOFunction
22
from .grad_mode import (
23
    _force_original_view_tracking,
24
    _unsafe_preserve_version_counter,
25
    enable_grad,
26
    inference_mode,
27
    no_grad,
28
    set_grad_enabled,
29
    set_multithreading_enabled,
30
)
31
from .gradcheck import gradcheck, gradgradcheck
32
from .graph import _engine_run_backward
33
from .variable import Variable
34

35

36
# Names exported by ``from torch.autograd import *``, kept in their historical
# order. Most are imported above (``grad_mode`` is the submodule loaded by the
# ``from .grad_mode import`` statement); ``backward``, ``grad`` and
# ``variable`` are defined further down in this module.
__all__ = [
    "Variable", "Function", "backward", "grad_mode",
    "NestedIOFunction", "detect_anomaly", "enable_grad", "grad",
    "gradcheck", "gradgradcheck", "inference_mode", "no_grad",
    "set_detect_anomaly", "set_grad_enabled", "set_multithreading_enabled",
    "variable",
]
54

55
# A Tensor or ``None``; used for per-output gradients, where ``None`` asks
# ``_make_grads`` to create (or skip) the gradient implicitly.
_OptionalTensor = Union[torch.Tensor, None]
# A regular shape (``_size``), a sequence of shapes, or a Tensor. The last two
# presumably describe nested tensors (``_calculate_shape`` returns
# ``_nested_tensor_size()`` for those) -- TODO confirm.
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
57

58

59
def _calculate_shape(
60
    output: Union[torch.Tensor, graph.GradientEdge],
61
    grad: torch.Tensor,
62
    is_grads_batched: bool,
63
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
64
    # is_same_size ensures that both tensors are either nested or non nested
65
    # circular import
66
    from torch.nested._internal.nested_tensor import NestedTensor
67

68
    if isinstance(output, graph.GradientEdge):
69
        # We have already checked that we are not a C++ NestedTensor
70
        if is_grads_batched:
71
            raise RuntimeError("Batched grads are not supported with GradientEdge")
72
        out_metadata = output.node._input_metadata[output.output_nr]
73
        return torch.Size(out_metadata.shape), grad.shape
74

75
    if output.is_nested and not isinstance(output, NestedTensor):
76
        if is_grads_batched:
77
            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
78
        out_shape = output._nested_tensor_size()
79
        grad_shape = grad._nested_tensor_size()
80

81
        return out_shape, grad_shape
82

83
    reg_out_shape = output.shape
84
    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
85
    return reg_out_shape, reg_grad_shape
86

87

88
def _make_grads(
89
    outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
90
    grads: Sequence[_OptionalTensor],
91
    is_grads_batched: bool,
92
) -> Tuple[_OptionalTensor, ...]:
93
    new_grads: List[_OptionalTensor] = []
94
    for out, grad in zip(outputs, grads):
95
        out = cast(Union[torch.Tensor, graph.GradientEdge], out)
96
        out_size = None
97
        out_device = None
98

99
        if isinstance(out, graph.GradientEdge):
100
            out_metadata = out.node._input_metadata[out.output_nr]
101
            out_size = torch.Size(out_metadata.shape)
102
            out_dtype = out_metadata.dtype
103
            out_device = out_metadata.device
104
            out_is_nested = out_metadata.is_nested_tensor
105
            if out_metadata.is_cpp_nested_tensor:
106
                raise RuntimeError(
107
                    "C++ NestedTensor are not supported with GradientEdge"
108
                )
109
            out_is_cpp_nested = False
110
        else:
111
            # circular import
112
            from torch.nested._internal.nested_tensor import NestedTensor
113

114
            assert isinstance(out, torch.Tensor)
115
            out_dtype = out.dtype
116
            out_is_nested = out.is_nested
117
            out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
118
            if not out_is_cpp_nested:
119
                out_size = out.shape
120

121
        if isinstance(grad, torch.Tensor):
122
            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
123

124
            first_grad = grad if not is_grads_batched else grad[0]
125

126
            # TODO: We can remove this conditional once we uniformly use
127
            # singleton int to represent jagged dimension, so that size() call
128
            # on nested tensor works.
129
            if out_is_cpp_nested:
130
                assert isinstance(out, torch.Tensor)
131
                shape_matches = torch.is_same_size(out, first_grad)
132
            else:
133
                # We need to do a regular size check, without going through
134
                # the operator, to be able to handle unbacked symints
135
                # (expect_true ensures we can deal with unbacked)
136
                assert out_size is not None
137
                shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
138

139
            if not shape_matches:
140
                out = cast(Union[torch.Tensor, graph.GradientEdge], out)
141
                out_shape, grad_shape = _calculate_shape(
142
                    out, first_grad, is_grads_batched
143
                )
144
                if is_grads_batched:
145
                    raise RuntimeError(
146
                        "If `is_grads_batched=True`, we interpret the first "
147
                        "dimension of each grad_output as the batch dimension. "
148
                        "The sizes of the remaining dimensions are expected to match "
149
                        "the shape of corresponding output, but a mismatch "
150
                        "was detected: grad_output["
151
                        + str(grads.index(grad))
152
                        + "] has a shape of "
153
                        + str(grad_shape)
154
                        + " and output["
155
                        + str(outputs.index(out))
156
                        + "] has a shape of "
157
                        + str(out_shape)
158
                        + ". "
159
                        "If you only want some tensors in `grad_output` to be considered "
160
                        "batched, consider using vmap."
161
                    )
162
                else:
163
                    raise RuntimeError(
164
                        "Mismatch in shape: grad_output["
165
                        + str(grads.index(grad))
166
                        + "] has a shape of "
167
                        + str(grad_shape)
168
                        + " and output["
169
                        + str(outputs.index(out))
170
                        + "] has a shape of "
171
                        + str(out_shape)
172
                        + "."
173
                    )
174
            if out_dtype.is_complex != grad.dtype.is_complex:
175
                raise RuntimeError(
176
                    "For complex Tensors, both grad_output and output"
177
                    " are required to have the same dtype."
178
                    " Mismatch in dtype: grad_output["
179
                    + str(grads.index(grad))
180
                    + "] has a dtype of "
181
                    + str(grad.dtype)
182
                    + " and output["
183
                    + str(outputs.index(out))
184
                    + "] has a dtype of "
185
                    + str(out_dtype)
186
                    + "."
187
                )
188
            new_grads.append(grad)
189
        elif grad is None:
190
            if isinstance(out, graph.GradientEdge) or out.requires_grad:  # type: ignore[attr-defined]
191
                if isinstance(out, graph.GradientEdge):
192
                    assert out_size is not None
193
                    out_numel_is_1 = all(o == 1 for o in out_size)
194
                else:
195
                    assert isinstance(out, torch.Tensor)
196
                    out_numel_is_1 = out.numel() == 1
197
                if not out_numel_is_1:
198
                    raise RuntimeError(
199
                        "grad can be implicitly created only for scalar outputs"
200
                    )
201
                if not out_dtype.is_floating_point:
202
                    msg = (
203
                        "grad can be implicitly created only for real scalar outputs"
204
                        f" but got {out_dtype}"
205
                    )
206
                    raise RuntimeError(msg)
207
                if isinstance(out, graph.GradientEdge):
208
                    assert out_size is not None
209
                    assert out_device is not None
210
                    new_grads.append(
211
                        torch.ones(
212
                            out_size,
213
                            dtype=out_dtype,
214
                            device=out_device,
215
                        )
216
                    )
217
                else:
218
                    assert isinstance(out, torch.Tensor)
219
                    new_grads.append(
220
                        torch.ones_like(out, memory_format=torch.preserve_format)
221
                    )
222
            else:
223
                new_grads.append(None)
224
        else:
225
            raise TypeError(
226
                "gradients can be either Tensors or None, but got "
227
                + type(grad).__name__
228
            )
229
    return tuple(new_grads)
230

231

232
def _tensor_or_tensors_to_tuple(
    tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
    """Normalize ``tensors`` into a tuple.

    * A single Tensor becomes a 1-tuple; ``length`` is ignored.
    * ``None`` becomes a tuple of ``length`` ``None`` entries.
    * Any other iterable is materialized with ``tuple()``.
    """
    if isinstance(tensors, torch.Tensor):
        return (tensors,)
    if tensors is not None:
        return tuple(tensors)
    return (None,) * length
240

241

242
def backward(
243
    tensors: _TensorOrTensors,
244
    grad_tensors: Optional[_TensorOrTensors] = None,
245
    retain_graph: Optional[bool] = None,
246
    create_graph: bool = False,
247
    grad_variables: Optional[_TensorOrTensors] = None,
248
    inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
249
) -> None:
250
    r"""Compute the sum of gradients of given tensors with respect to graph leaves.
251

252
    The graph is differentiated using the chain rule. If any of ``tensors``
253
    are non-scalar (i.e. their data has more than one element) and require
254
    gradient, then the Jacobian-vector product would be computed, in this
255
    case the function additionally requires specifying ``grad_tensors``.
256
    It should be a sequence of matching length, that contains the "vector"
257
    in the Jacobian-vector product, usually the gradient of the differentiated
258
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
259
    all tensors that don't need gradient tensors).
260

261
    This function accumulates gradients in the leaves - you might need to zero
262
    ``.grad`` attributes or set them to ``None`` before calling it.
263
    See :ref:`Default gradient layouts<default-grad-layouts>`
264
    for details on the memory layout of accumulated gradients.
265

266
    .. note::
267
        Using this method with ``create_graph=True`` will create a reference cycle
268
        between the parameter and its gradient which can cause a memory leak.
269
        We recommend using ``autograd.grad`` when creating the graph to avoid this.
270
        If you have to use this function, make sure to reset the ``.grad`` fields of your
271
        parameters to ``None`` after use to break the cycle and avoid the leak.
272

273
    .. note::
274

275
        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
276
        in a user-specified CUDA stream context, see
277
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
278

279
    .. note::
280

281
        When ``inputs`` are provided and a given input is not a leaf,
282
        the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
283
        It is an implementation detail on which the user should not rely.
284
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
285

286
    Args:
287
        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
288
            computed.
289
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
290
            the Jacobian-vector product, usually gradients w.r.t. each element of
291
            corresponding tensors. None values can be specified for scalar Tensors or
292
            ones that don't require grad. If a None value would be acceptable for all
293
            grad_tensors, then this argument is optional.
294
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
295
            will be freed. Note that in nearly all cases setting this option to ``True``
296
            is not needed and often can be worked around in a much more efficient
297
            way. Defaults to the value of ``create_graph``.
298
        create_graph (bool, optional): If ``True``, graph of the derivative will
299
            be constructed, allowing to compute higher order derivative products.
300
            Defaults to ``False``.
301
        inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
302
            be will accumulated into ``.grad``. All other Tensors will be ignored. If
303
            not provided, the gradient is accumulated into all the leaf Tensors that
304
            were used to compute the :attr:`tensors`.
305
    """
306
    if torch._C._are_functorch_transforms_active():
307
        raise RuntimeError(
308
            "backward() called inside a functorch transform. This is not "
309
            "supported, please use functorch.grad or functorch.vjp instead "
310
            "or call backward() outside of functorch transforms."
311
        )
312

313
    if grad_variables is not None:
314
        warnings.warn(
315
            "`grad_variables` is deprecated. Use `grad_tensors` instead.",
316
            FutureWarning,
317
            stacklevel=2,
318
        )
319
        if grad_tensors is None:
320
            grad_tensors = grad_variables
321
        else:
322
            raise RuntimeError(
323
                "`grad_tensors` and `grad_variables` (deprecated) "
324
                "arguments both passed to `backward()`. Please only "
325
                "use `grad_tensors`."
326
            )
327
    if inputs is not None and len(inputs) == 0:
328
        raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
329

330
    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
331
    inputs = (
332
        (inputs,)
333
        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
334
        else tuple(inputs)
335
        if inputs is not None
336
        else ()
337
    )
338

339
    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
340
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
341
    if retain_graph is None:
342
        retain_graph = create_graph
343

344
    # The reason we repeat the same comment below is that
345
    # some Python versions print out the first line of a multi-line function
346
    # calls in the traceback and some print out the last line
347
    _engine_run_backward(
348
        tensors,
349
        grad_tensors_,
350
        retain_graph,
351
        create_graph,
352
        inputs,
353
        allow_unreachable=True,
354
        accumulate_grad=True,
355
    )
356

357

358
def grad(
359
    outputs: _TensorOrTensorsOrGradEdge,
360
    inputs: _TensorOrTensorsOrGradEdge,
361
    grad_outputs: Optional[_TensorOrTensors] = None,
362
    retain_graph: Optional[bool] = None,
363
    create_graph: bool = False,
364
    only_inputs: bool = True,
365
    allow_unused: Optional[bool] = None,
366
    is_grads_batched: bool = False,
367
    materialize_grads: bool = False,
368
) -> Tuple[torch.Tensor, ...]:
369
    r"""Compute and return the sum of gradients of outputs with respect to the inputs.
370

371
    ``grad_outputs`` should be a sequence of length matching ``output``
372
    containing the "vector" in vector-Jacobian product, usually the pre-computed
373
    gradients w.r.t. each of the outputs. If an output doesn't require_grad,
374
    then the gradient can be ``None``).
375

376
    .. note::
377

378
        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
379
        in a user-specified CUDA stream context, see
380
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
381

382
    .. note::
383

384
        ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
385
        To accumulate gradient for other parts of the graph, please use
386
        ``torch.autograd.backward``.
387

388
    Args:
389
        outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
390
        inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
391
            returned (and not accumulated into ``.grad``).
392
        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
393
            Usually gradients w.r.t. each output. None values can be specified for scalar
394
            Tensors or ones that don't require grad. If a None value would be acceptable
395
            for all grad_tensors, then this argument is optional. Default: None.
396
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
397
            will be freed. Note that in nearly all cases setting this option to ``True``
398
            is not needed and often can be worked around in a much more efficient
399
            way. Defaults to the value of ``create_graph``.
400
        create_graph (bool, optional): If ``True``, graph of the derivative will
401
            be constructed, allowing to compute higher order derivative products.
402
            Default: ``False``.
403
        allow_unused (Optional[bool], optional): If ``False``, specifying inputs
404
            that were not used when computing outputs (and therefore their grad is
405
            always zero) is an error. Defaults to the value of ``materialize_grads``.
406
        is_grads_batched (bool, optional): If ``True``, the first dimension of each
407
            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
408
            Instead of computing a single vector-Jacobian product, we compute a
409
            batch of vector-Jacobian products for each "vector" in the batch.
410
            We use the vmap prototype feature as the backend to vectorize calls
411
            to the autograd engine so that this computation can be performed in a
412
            single call. This should lead to performance improvements when compared
413
            to manually looping and performing backward multiple times. Note that
414
            due to this feature being experimental, there may be performance
415
            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
416
            to show any performance warnings and file an issue on github if warnings exist
417
            for your use case. Defaults to ``False``.
418
        materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
419
            to zero instead of None. This is useful when computing higher-order derivatives.
420
            If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
421
            will be raised. Defaults to ``False``.
422

423
    """
424
    if materialize_grads and allow_unused is False:
425
        raise ValueError(
426
            "Expected allow_unused to be True or not passed when materialize_grads=True, "
427
            "but got: allow_unused=False."
428
        )
429
    if allow_unused is None:
430
        allow_unused = materialize_grads
431
    if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
432
        outputs = cast(
433
            Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
434
        )
435
    else:
436
        outputs = tuple(outputs)
437
    if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
438
        inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
439
    else:
440
        inputs = tuple(inputs)
441
    t_outputs = tuple(i for i in outputs if is_tensor_like(i))
442
    t_inputs = tuple(i for i in inputs if is_tensor_like(i))
443
    overridable_args = t_outputs + t_inputs
444
    if has_torch_function(overridable_args):
445
        return handle_torch_function(
446
            grad,
447
            overridable_args,
448
            outputs,
449
            inputs,
450
            grad_outputs=grad_outputs,
451
            retain_graph=retain_graph,
452
            create_graph=create_graph,
453
            only_inputs=only_inputs,
454
            allow_unused=allow_unused,
455
            is_grads_batched=is_grads_batched,
456
            materialize_grads=materialize_grads,
457
        )
458

459
    if not only_inputs:
460
        warnings.warn(
461
            "only_inputs argument is deprecated and is ignored now "
462
            "(defaults to True). To accumulate gradient for other "
463
            "parts of the graph, please use torch.autograd.backward.",
464
            FutureWarning,
465
            stacklevel=2,
466
        )
467

468
    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
469
    grad_outputs_ = _make_grads(
470
        outputs, grad_outputs_, is_grads_batched=is_grads_batched
471
    )
472

473
    if retain_graph is None:
474
        retain_graph = create_graph
475

476
    # The reason we repeat the same comment several times below is because
477
    # some Python versions print out the first line of multi-line function
478
    # calls in the traceback and some print out the last line
479
    if is_grads_batched:
480

481
        def vjp(gO):
482
            return _engine_run_backward(
483
                outputs,
484
                gO,
485
                retain_graph,
486
                create_graph,
487
                inputs,
488
                allow_unused,
489
                accumulate_grad=False,
490
            )
491

492
        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
493
            grad_outputs_
494
        )
495
    else:
496
        result = _engine_run_backward(
497
            outputs,
498
            grad_outputs_,
499
            retain_graph,
500
            create_graph,
501
            inputs,
502
            allow_unused,
503
            accumulate_grad=False,
504
        )
505
    if materialize_grads:
506
        if any(
507
            result[i] is None and not is_tensor_like(inputs[i])
508
            for i in range(len(inputs))
509
        ):
510
            raise RuntimeError(
511
                "materialize_grads cannot be used when the given input is a GradientEdge"
512
            )
513
        result = tuple(
514
            output
515
            if output is not None
516
            else torch.zeros_like(input, requires_grad=True)
517
            for (output, input) in zip(result, inputs)
518
        )
519
    return result
520

521

522
# Gradient checkpointing (a memory optimization) is only supported when the
# autograd engine was entered through ``torch.autograd.backward()`` *without*
# its ``inputs`` argument. It is not supported for ``torch.autograd.grad()``:
# once inputs are specified, gradients are not computed for anything else,
# e.g. model parameters such as weights and biases.
#
# The check lives in torch/csrc/autograd/engine.cpp, which keeps a thread-local
# variable tracking the NodeTask on the stack. Before a NodeTask is executed in
# evaluate_function, the engine checks whether the reentrant backward is
# imperative or not.
# Background: https://github.com/pytorch/pytorch/pull/4594
def _is_checkpoint_valid():
    """Return whether checkpointing is valid in the backward pass being run.

    ``True`` corresponds to the ``torch.autograd.backward`` case (no
    ``inputs``); ``False`` corresponds to ``torch.autograd.grad``.
    """
    execution_engine = Variable._execution_engine
    return execution_engine.is_checkpoint_valid()
537

538

539
def variable(*args, **kwargs):
    """Removed legacy constructor that always fails.

    Raises:
        RuntimeError: unconditionally, directing callers to ``torch.tensor``.
    """
    message = "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
    raise RuntimeError(message)


# FX codegen emits calls roughly as f"{fn.__module__}.{fn.__name__}(...)", which
# for Variable produces ``torch.autograd.variable.Variable(...)``. The module
# path ``torch.autograd.variable`` is shadowed by the deprecated function above,
# so hang ``Variable`` off that function to keep the generated code resolvable.
setattr(variable, "Variable", Variable)
550

551
# Initialize the native autograd extension. This runs before the
# ``torch._C._autograd`` imports that follow; a falsy result aborts the
# import of this module.
if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")
553

554
# Import all native method/classes
555
from torch._C._autograd import (
556
    _add_metadata_json,
557
    _disable_profiler,
558
    _disable_profiler_legacy,
559
    _enable_profiler,
560
    _enable_profiler_legacy,
561
    _enable_record_function,
562
    _get_sequence_nr,
563
    _kineto_step,
564
    _KinetoEvent,
565
    _pop_saved_tensors_default_hooks,
566
    _prepare_profiler,
567
    _profiler_enabled,
568
    _ProfilerResult,
569
    _push_saved_tensors_default_hooks,
570
    _record_function_with_args_enter,
571
    _record_function_with_args_exit,
572
    _set_empty_test_observer,
573
    _supported_activities,
574
    _toggle_collection_dynamic,
575
    DeviceType,
576
    kineto_available,
577
    ProfilerEvent,
578
    SavedTensor,
579
)
580
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
581

582
from . import profiler
583

584

585
def _register_py_tensor_class_for_device(device, cls):
    """Forward ``(device, cls)`` to ``torch._C._register_py_class_for_device``.

    Raises:
        RuntimeError: if ``cls`` is not a class object (``type`` instance).
    """
    if isinstance(cls, type):
        torch._C._register_py_class_for_device(device, cls)
        return
    raise RuntimeError("cls isn't a typeinfo object")
589

590

591
# Re-export two native predicates under public names and attach their
# docstrings (``_add_docstr`` mutates the native function objects in place).
is_multithreading_enabled = torch._C._is_multithreading_enabled
is_view_replay_enabled = torch._C._is_view_replay_enabled

torch._C._add_docstr(
    is_multithreading_enabled,
    "Returns True if multithreading is currently enabled.",
)
torch._C._add_docstr(
    is_view_replay_enabled,
    "Returns True if view-replay is currently enabled.",
)
600

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.