pytorch

Форк
0
/
sample_functional.py 
72 строки · 2.1 Кб
1
import torch
2
import torch.nn.functional as F
3
from torch.testing._internal.common_nn import wrap_functional
4

5

6
"""
7
`sample_functional` is used by `test_cpp_api_parity.py` to test that Python / C++ API
8
parity test harness works for `torch.nn.functional` functions.
9

10
When `has_parity=true` is passed to `sample_functional`, behavior of `sample_functional`
11
is the same as the C++ equivalent.
12

13
When `has_parity=false` is passed to `sample_functional`, behavior of `sample_functional`
14
is different from the C++ equivalent.
15
"""
16

17

18
def sample_functional(x, has_parity):
19
    if has_parity:
20
        return x * 2
21
    else:
22
        return x * 4
23

24

25
torch.nn.functional.sample_functional = sample_functional
26

27
SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n
28
namespace torch {
29
namespace nn {
30
namespace functional {
31

32
struct C10_EXPORT SampleFunctionalFuncOptions {
33
  SampleFunctionalFuncOptions(bool has_parity) : has_parity_(has_parity) {}
34

35
  TORCH_ARG(bool, has_parity);
36
};
37

38
Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) {
39
    return x * 2;
40
}
41

42
} // namespace functional
43
} // namespace nn
44
} // namespace torch
45
"""
46

47
functional_tests = [
48
    dict(
49
        constructor=wrap_functional(F.sample_functional, has_parity=True),
50
        cpp_options_args="F::SampleFunctionalFuncOptions(true)",
51
        input_size=(1, 2, 3),
52
        fullname="sample_functional_has_parity",
53
        has_parity=True,
54
    ),
55
    dict(
56
        constructor=wrap_functional(F.sample_functional, has_parity=False),
57
        cpp_options_args="F::SampleFunctionalFuncOptions(false)",
58
        input_size=(1, 2, 3),
59
        fullname="sample_functional_no_parity",
60
        has_parity=False,
61
    ),
62
    # This is to test that setting the `test_cpp_api_parity=False` flag skips
63
    # the C++ API parity test accordingly (otherwise this test would run and
64
    # throw a parity error).
65
    dict(
66
        constructor=wrap_functional(F.sample_functional, has_parity=False),
67
        cpp_options_args="F::SampleFunctionalFuncOptions(false)",
68
        input_size=(1, 2, 3),
69
        fullname="sample_functional_THIS_TEST_SHOULD_BE_SKIPPED",
70
        test_cpp_api_parity=False,
71
    ),
72
]
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.