pytorch

Форк
0
/
normalize_op_test.py 
50 строк · 1.6 Кб
1

2

3

4

5
import functools
6

7
import numpy as np
8
from hypothesis import given, settings
9
from caffe2.python import core
10
import caffe2.python.hypothesis_test_util as hu
11
import copy
12

13

14
class TestNormalizeOp(hu.HypothesisTestCase):
15
    @given(
16
        X=hu.tensor(
17
            min_dim=1, max_dim=5, elements=hu.floats(min_value=0.5, max_value=1.0)
18
        ),
19
        **hu.gcs
20
    )
21
    @settings(max_examples=10, deadline=None)
22
    def test_normalize(self, X, gc, dc):
23
        def ref_normalize(X, axis):
24
            x_normed = X / np.maximum(
25
                np.sqrt((X ** 2).sum(axis=axis, keepdims=True)), 1e-12
26
            )
27
            return (x_normed,)
28

29
        for axis in range(-X.ndim, X.ndim):
30
            x = copy.copy(X)
31
            op = core.CreateOperator("Normalize", "X", "Y", axis=axis)
32
            self.assertReferenceChecks(
33
                gc, op, [x], functools.partial(ref_normalize, axis=axis)
34
            )
35
            self.assertDeviceChecks(dc, op, [x], [0])
36
            self.assertGradientChecks(gc, op, [x], 0, [0])
37

38
    @given(
39
        X=hu.tensor(
40
            min_dim=1, max_dim=5, elements=hu.floats(min_value=0.5, max_value=1.0)
41
        ),
42
        **hu.gcs
43
    )
44
    @settings(max_examples=10, deadline=None)
45
    def test_normalize_L1(self, X, gc, dc):
46
        def ref(X, axis):
47
            norm = abs(X).sum(axis=axis, keepdims=True)
48
            return (X / norm,)
49

50
        for axis in range(-X.ndim, X.ndim):
51
            print("axis: ", axis)
52
            op = core.CreateOperator("NormalizeL1", "X", "Y", axis=axis)
53
            self.assertReferenceChecks(gc, op, [X], functools.partial(ref, axis=axis))
54
            self.assertDeviceChecks(dc, op, [X], [0])
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.