pytorch

Форк
0
/
__init__.py 
53 строки · 2.1 Кб
1
from .modules import *  # noqa: F403
2
from .parameter import (
3
    Parameter as Parameter,
4
    UninitializedParameter as UninitializedParameter,
5
    UninitializedBuffer as UninitializedBuffer,
6
)
7
from .parallel import DataParallel as DataParallel
8
from . import init
9
from . import functional
10
from . import utils
11
from . import attention
12

13

14
def factory_kwargs(kwargs):
15
    r"""Return a canonicalized dict of factory kwargs.
16

17
    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
18
    to factory functions like torch.empty, or errors if unrecognized kwargs are present.
19

20
    This function makes it simple to write code like this::
21

22
        class MyModule(nn.Module):
23
            def __init__(self, **kwargs):
24
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
25
                self.weight = Parameter(torch.empty(10, **factory_kwargs))
26

27
    Why should you use this function instead of just passing `kwargs` along directly?
28

29
    1. This function does error validation, so if there are unexpected kwargs we will
30
    immediately report an error, instead of deferring it to the factory call
31
    2. This function supports a special `factory_kwargs` argument, which can be used to
32
    explicitly specify a kwarg to be used for factory functions, in the event one of the
33
    factory kwargs conflicts with an already existing argument in the signature (e.g.
34
    in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
35
    functions, as distinct from the dtype argument, by saying
36
    ``f(dtype1, factory_kwargs={"dtype": dtype2})``)
37
    """
38
    if kwargs is None:
39
        return {}
40
    simple_keys = {"device", "dtype", "memory_format"}
41
    expected_keys = simple_keys | {"factory_kwargs"}
42
    if not kwargs.keys() <= expected_keys:
43
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")
44

45
    # guarantee no input kwargs is untouched
46
    r = dict(kwargs.get("factory_kwargs", {}))
47
    for k in simple_keys:
48
        if k in kwargs:
49
            if k in r:
50
                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
51
            r[k] = kwargs[k]
52

53
    return r
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.