pytorch

Форк
0
/
unsafe_coalesce_test.py 
76 строк · 2.9 Кб
1
#!/usr/bin/env python3
2

3
import caffe2.python.hypothesis_test_util as hu
4
import hypothesis.strategies as st
5
import numpy as np
6
import numpy.testing as npt
7
from caffe2.python import core, workspace
8
from hypothesis import given
9

10

11
class TestUnsafeCoalesceOp(hu.HypothesisTestCase):
12
    @given(
13
        n=st.integers(1, 5),
14
        shape=st.lists(st.integers(0, 5), min_size=1, max_size=3),
15
        **hu.gcs
16
    )
17
    def test_unsafe_coalesce_op(self, n, shape, dc, gc):
18
        workspace.ResetWorkspace()
19
        test_inputs = [(100 * np.random.random(shape)).astype(np.float32) for _ in range(n)]
20
        test_input_blobs = ["x_{}".format(i) for i in range(n)]
21

22
        coalesce_op = core.CreateOperator(
23
            "UnsafeCoalesce",
24
            test_input_blobs,
25
            test_input_blobs + ["shared_memory_blob"],
26
            device_option=gc,
27
        )
28

29
        def reference_func(*args):
30
            self.assertEqual(len(args), n)
31
            return list(args) + [np.concatenate([x.flatten() for x in args])]
32

33
        self.assertReferenceChecks(gc, coalesce_op, test_inputs, reference_func)
34

35
    @given(
36
        n=st.integers(1, 5),
37
        shape=st.lists(st.integers(1, 5), min_size=1, max_size=3),
38
        seed=st.integers(0, 65535),
39
        **hu.gcs
40
    )
41
    def test_unsafe_coalesce_op_blob_sharing(self, n, shape, seed, dc, gc):
42
        workspace.ResetWorkspace()
43
        # Can make debugging of the test more predictable
44
        np.random.seed(seed)
45
        test_inputs = [(np.random.random(shape)).astype(np.float32) for _ in range(n)]
46
        test_input_blobs = ["x_{}".format(i) for i in range(n)]
47

48
        coalesce_op = core.CreateOperator(
49
            "UnsafeCoalesce",
50
            test_input_blobs,
51
            test_input_blobs + ["shared_memory_blob"],
52
            device_option=gc,
53
        )
54
        for name, value in zip(test_input_blobs, test_inputs):
55
            workspace.FeedBlob(name, value, device_option=gc)
56

57
        workspace.RunOperatorOnce(coalesce_op)
58
        blob_value = workspace.blobs["shared_memory_blob"]
59
        npt.assert_almost_equal(
60
            blob_value,
61
            np.concatenate([x.flatten() for x in test_inputs]),
62
            decimal=4
63
        )
64
        # np.random generates values in range [0, 1), so -2 is outside of range
65
        blob_value.fill(-2.0)
66
        self.assertTrue((blob_value != workspace.blobs["shared_memory_blob"]).all())
67
        workspace.FeedBlob("shared_memory_blob", blob_value, device_option=gc)
68

69
        # All blobs preserved shape, but got overwritted to -2
70
        for name, value in zip(test_input_blobs, test_inputs):
71
            self.assertEqual(value.shape, workspace.blobs[name].shape)
72
            self.assertTrue((value != workspace.blobs[name]).all())
73
            self.assertTrue((workspace.blobs[name] == -2).all())
74

75
        # It should be OK to reuse operator as long as it's blob shapes are not changing
76
        workspace.RunOperatorOnce(coalesce_op)
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.