pytorch

Форк
0
/
lengths_tile_op_test.py 
46 строк · 1.3 Кб
1

2

3

4

5

6
from caffe2.python import core
7
import caffe2.python.hypothesis_test_util as hu
8
import caffe2.python.serialized_test.serialized_test_util as serial
9
import hypothesis.strategies as st
10
import numpy as np
11

12

13
class TestLengthsTileOp(serial.SerializedTestCase):
14

15
    @serial.given(
16
        inputs=st.integers(min_value=1, max_value=20).flatmap(
17
            lambda size: st.tuples(
18
                hu.arrays([size], dtype=np.float32),
19
                hu.arrays([size], dtype=np.int32,
20
                          elements=st.integers(min_value=0, max_value=20)),
21
            )
22
        ),
23
        **hu.gcs)
24
    def test_lengths_tile(self, inputs, gc, dc):
25
        data, lengths = inputs
26

27
        def lengths_tile_op(data, lengths):
28
            return [np.concatenate([
29
                [d] * l for d, l in zip(data, lengths)
30
            ])]
31

32
        op = core.CreateOperator(
33
            "LengthsTile",
34
            ["data", "lengths"],
35
            ["output"],
36
        )
37

38
        self.assertReferenceChecks(
39
            device_option=gc,
40
            op=op,
41
            inputs=[data, lengths],
42
            reference=lengths_tile_op,
43
        )
44

45
        self.assertGradientChecks(
46
            device_option=gc,
47
            op=op,
48
            inputs=[data, lengths],
49
            outputs_to_check=0,
50
            outputs_with_grads=[0]
51
        )
52

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.