# pytorch
# 72 lines · 2.1 KB
import torch
import torch.nn.functional as F
from torch.testing._internal.common_nn import wrap_functional

"""
`sample_functional` is used by `test_cpp_api_parity.py` to test that the Python / C++
API parity test harness works for `torch.nn.functional` functions.

When `has_parity=True` is passed to `sample_functional`, its behavior (`x * 2`) is the
same as that of the C++ equivalent.

When `has_parity=False` is passed to `sample_functional`, its behavior (`x * 4`)
differs from that of the C++ equivalent, which always returns `x * 2`.
"""


def sample_functional(x, has_parity):
    """Scale ``x`` so that C++ parity holds only when ``has_parity`` is truthy.

    Args:
        x: value to scale; anything supporting ``*`` with an ``int``
            (presumably a tensor when driven by the parity harness — TODO confirm).
        has_parity: when truthy, the result is ``x * 2``, identical to the C++
            ``sample_functional`` in ``SAMPLE_FUNCTIONAL_CPP_SOURCE``; when falsy,
            the result is ``x * 4`` so Python and C++ deliberately disagree.

    Returns:
        ``x`` multiplied by 2 or 4.
    """
    factor = 2 if has_parity else 4
    return x * factor


# Attach the function to the `torch.nn.functional` namespace so it can be
# referenced as `F.sample_functional`, just like a built-in functional.
torch.nn.functional.sample_functional = sample_functional

# C++ counterpart of `sample_functional`, consumed by the C++ API parity
# harness (presumably compiled together with the generated C++ test code —
# confirm in test_cpp_api_parity.py). The C++ function ignores `options` and
# always returns `x * 2`, so it agrees with the Python `sample_functional`
# only when `has_parity` is True.
SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n
namespace torch {
namespace nn {
namespace functional {

struct C10_EXPORT SampleFunctionalFuncOptions {
  SampleFunctionalFuncOptions(bool has_parity) : has_parity_(has_parity) {}

  TORCH_ARG(bool, has_parity);
};

Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) {
  return x * 2;
}

} // namespace functional
} // namespace nn
} // namespace torch
"""

# Parity-harness entries for `sample_functional`. Each entry wraps the Python
# function with `wrap_functional` and gives the C++ options expression in
# `cpp_options_args`. The `has_parity` key presumably tells the harness whether
# the Python and C++ results are expected to agree — confirm in
# test_cpp_api_parity.py.
functional_tests = [
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=True),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(true)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_has_parity",
        "has_parity": True,
    },
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=False),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(false)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_no_parity",
        "has_parity": False,
    },
    # `test_cpp_api_parity=False` must make the harness skip the C++ API parity
    # check for this entry; if it did not, the check would run and raise a
    # parity error (Python returns x * 4 here while C++ returns x * 2).
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=False),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(false)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_THIS_TEST_SHOULD_BE_SKIPPED",
        "test_cpp_api_parity": False,
    },
]