pytorch

Форк
0
/
hyperbolic_ops_test.py 
40 строк · 1.4 Кб
1

2

3

4

5

6
from caffe2.python import core
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
import hypothesis.strategies as st
10
import numpy as np
11

12

13
class TestHyperbolicOps(serial.SerializedTestCase):
14
    def _test_hyperbolic_op(self, op_name, np_ref, X, in_place, engine, gc, dc):
15
        op = core.CreateOperator(
16
            op_name,
17
            ["X"],
18
            ["X"] if in_place else ["Y"],
19
            engine=engine,)
20

21
        def ref(X):
22
            return [np_ref(X)]
23

24
        self.assertReferenceChecks(
25
            device_option=gc,
26
            op=op,
27
            inputs=[X],
28
            reference=ref,
29
            ensure_outputs_are_inferred=True,
30
        )
31
        self.assertDeviceChecks(dc, op, [X], [0])
32
        self.assertGradientChecks(gc, op, [X], 0, [0], ensure_outputs_are_inferred=True)
33

34
    @serial.given(X=hu.tensor(dtype=np.float32), **hu.gcs)
35
    def test_sinh(self, X, gc, dc):
36
        self._test_hyperbolic_op("Sinh", np.sinh, X, False, "", gc, dc)
37

38
    @serial.given(X=hu.tensor(dtype=np.float32), **hu.gcs)
39
    def test_cosh(self, X, gc, dc):
40
        self._test_hyperbolic_op("Cosh", np.cosh, X, False, "", gc, dc)
41

42
    @serial.given(X=hu.tensor(dtype=np.float32), in_place=st.booleans(),
43
           engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
44
    def test_tanh(self, X, in_place, engine, gc, dc):
45
        self._test_hyperbolic_op("Tanh", np.tanh, X, in_place, engine, gc, dc)
46

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.