pytorch

Форк
0
/
pipeline_test.py 
72 строки · 2.5 Кб
1

2

3

4

5

6
from caffe2.python.schema import (
7
    Struct, FetchRecord, NewRecord, FeedRecord, InitEmptyRecord)
8
from caffe2.python import core, workspace
9
from caffe2.python.session import LocalSession
10
from caffe2.python.dataset import Dataset
11
from caffe2.python.pipeline import pipe
12
from caffe2.python.queue_util import Queue
13
from caffe2.python.task import TaskGroup
14
from caffe2.python.test_util import TestCase
15
from caffe2.python.net_builder import ops
16
import numpy as np
17
import math
18

19

20
class TestPipeline(TestCase):
21
    def test_dequeue_many(self):
22
        init_net = core.Net('init')
23
        N = 17
24
        NUM_DEQUEUE_RECORDS = 3
25
        src_values = Struct(
26
            ('uid', np.array(range(N))),
27
            ('value', 0.1 * np.array(range(N))))
28
        expected_dst = Struct(
29
            ('uid', 2 * np.array(range(N))),
30
            ('value', np.array(N * [0.0])))
31

32
        with core.NameScope('init'):
33
            src_blobs = NewRecord(init_net, src_values)
34
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
35
            counter = init_net.Const(0)
36
            ONE = init_net.Const(1)
37

38
        def proc1(rec):
39
            with core.NameScope('proc1'):
40
                out = NewRecord(ops, rec)
41
            ops.Add([rec.uid(), rec.uid()], [out.uid()])
42
            out.value.set(blob=rec.value(), unsafe=True)
43
            return out
44

45
        def proc2(rec):
46
            with core.NameScope('proc2'):
47
                out = NewRecord(ops, rec)
48
            out.uid.set(blob=rec.uid(), unsafe=True)
49
            ops.Sub([rec.value(), rec.value()], [out.value()])
50
            ops.Add([counter, ONE], [counter])
51
            return out
52

53
        src_ds = Dataset(src_blobs)
54
        dst_ds = Dataset(dst_blobs)
55

56
        with TaskGroup() as tg:
57
            out1 = pipe(
58
                src_ds.reader(),
59
                output=Queue(
60
                    capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
61
                processor=proc1)
62
            out2 = pipe(out1, processor=proc2)
63
            pipe(out2, dst_ds.writer())
64

65
        ws = workspace.C.Workspace()
66
        FeedRecord(src_blobs, src_values, ws)
67
        session = LocalSession(ws)
68
        session.run(init_net)
69
        session.run(tg)
70
        output = FetchRecord(dst_blobs, ws=ws)
71
        num_dequeues = ws.blobs[str(counter)].fetch()
72

73
        self.assertEqual(
74
            num_dequeues, int(math.ceil(float(N) / NUM_DEQUEUE_RECORDS)))
75

76
        for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
77
            np.testing.assert_array_equal(a, b)
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.