## @package embedding_generation_benchmark
# Module caffe2.python.embedding_generation_benchmark
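#
# Example invocation (illustrative; the flags are defined in
# GetArgumentParser below and are shown here with their default values):
#   python embedding_generation_benchmark.py --implementation sinusoid \
#       --batch_size 16 --seq_length 128 --embedding_size 512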


from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, core, utils, model_helper

import argparse
import numpy as np
import time

import logging

logging.basicConfig()
log = logging.getLogger("embedding_generation_benchmark")
log.setLevel(logging.DEBUG)


def generate_data(T, batch_size, max_seq_length):
    '''
    Fill a queue with T batches of input data. Each enqueued blob is a
    (max_seq_length, batch_size) integer array of position indices
    0..max_seq_length-1.
    '''
    log.info("Generating T={} batches".format(T))

    generate_input_init_net = core.Net('generate_input_init')
    queue = generate_input_init_net.CreateBlobsQueue(
        [], "inputqueue", num_blobs=1, capacity=T,
    )
    workspace.RunNetOnce(generate_input_init_net)

    generate_input_net = core.Net('generate_input')
    generate_input_net.EnqueueBlobs([queue, "scratch"], ["scratch"])
    np.random.seed(2603)

    for t in range(T):
        if t % max(10, T // 10) == 0:
            log.info("Generating data {}/{}".format(t, T))
        # Position indices, tiled across the batch and transposed to
        # (max_seq_length, batch_size).
        X = np.tile(np.arange(max_seq_length), [batch_size, 1]).transpose()
        workspace.FeedBlob("scratch", X)
        workspace.RunNetOnce(generate_input_net.Proto())

    log.info("Finished data generation")
    return queue


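# Builds the table used by the 'table' implementation: Gaussian-initialized
# values of shape [vocab_size, embedding_size]. The benchmark below sizes it
# as [seq_length, embedding_size], since the "vocabulary" here is the set of
# position indices.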
def generate_embedding_table(vocab_size, embedding_size):
    log.info("Generating embedding table with dimensions {}"
             .format([vocab_size, embedding_size]))

    generate_table_net = core.Net('generate_table')
    table = generate_table_net.GaussianFill(
        [],
        ['embedding_table'],
        shape=[vocab_size, embedding_size],
    )

    workspace.RunNetOnce(generate_table_net)
    return table


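# Two implementations are compared:
#   'sinusoid' - positions are encoded on the fly by the
#                SinusoidPositionEncoding operator (presumably the
#                transformer-style sin/cos encoding; see the operator
#                documentation for the exact formula).
#   'table'    - positions are looked up in the pre-generated embedding
#                table with Gather.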
def create_model(args, queue, embedding_table, embedding_size):
    model = model_helper.ModelHelper(name='embedding_generation_bench')
    input_blob = model.net.DequeueBlobs(queue, 'input_data')

    if args.implementation == 'sinusoid':
        model.net.SinusoidPositionEncoding(
            [input_blob],
            ['output'],
            embedding_size=embedding_size
        )
    else:
        model.net.Gather(
            [embedding_table, input_blob],
            ['output'],
        )

    return model


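# Benchmark driver: pre-fills the input queue, builds the selected model,
# runs the net once to warm up, then times it in blocks of --iters_to_report
# runs, reporting throughput in thousands of embeddings generated per second.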
def Caffe2EmbeddingGeneration(args):
    T = args.data_size // args.batch_size

    queue = generate_data(T, args.batch_size, args.seq_length)

    embedding_table = None
    if args.implementation == 'table':
        embedding_table = generate_embedding_table(
            args.seq_length,
            args.embedding_size,
        )

    model = create_model(args, queue, embedding_table, args.embedding_size)

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    start_time = time.time()
    num_iters = T
    total_iters = 0

    # Run the Benchmark
    log.info("------ Warming up ------")
    workspace.RunNet(model.net.Proto().name)

    log.info("------ Starting benchmark ------")
    start_time = time.time()
    last_time = time.time()
    for iteration in range(1, num_iters, args.iters_to_report):
        iters_once = min(args.iters_to_report, num_iters - iteration)
        total_iters += iters_once
        workspace.RunNet(model.net.Proto().name, iters_once)

        new_time = time.time()
        log.info(
            "Iter: {} / {}. Embeddings Generated Per Second: {}k.".format(
                iteration,
                num_iters,
                # Embeddings/sec in thousands, floored to one decimal place
                # (x // 100 / 10).
                (iters_once * args.batch_size * args.seq_length) /
                (new_time - last_time) // 100 / 10,
            )
        )
        last_time = new_time

    total_per_sec = (num_iters - 1) * args.batch_size * args.seq_length
    # Same 0.1k rounding as above, over the whole timed run.
    total_per_sec = total_per_sec / (time.time() - start_time) // 100 / 10

    log.info("Done. Total embeddings generated per second " +
             "excluding 1st iteration: {}k".format(total_per_sec))

    return time.time() - start_time


@utils.debug
def Benchmark(args):
    return Caffe2EmbeddingGeneration(args)


def GetArgumentParser():
    parser = argparse.ArgumentParser(
        description="Embedding generation benchmark."
    )

    parser.add_argument(
        "--embedding_size",
        type=int,
        default=512,
        help="Embedding size",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="The batch size."
    )
    parser.add_argument(
        "--data_size",
        type=int,
        default=10000,
        help="Number of sequences to generate"
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=128,
        help="Max sequence length"
    )
    parser.add_argument(
        "--iters_to_report",
        type=int,
        default=20,
        help="Number of iterations to report progress"
    )
    parser.add_argument(
        "--implementation",
        type=str,
        default="sinusoid",
        help="'table' or 'sinusoid'",
    )
    return parser


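# CLI flags not recognized by GetArgumentParser (parse_known_args) are
# forwarded to Caffe2 via workspace.GlobalInit.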
if __name__ == '__main__':
    args, extra_args = GetArgumentParser().parse_known_args()

    workspace.GlobalInit([
        'caffe2',
        '--caffe2_log_level=0',
        '--caffe2_print_blob_sizes_at_exit=0'] + extra_args)

    device = core.DeviceOption(caffe2_pb2.CPU)

    with core.DeviceScope(device):
        Benchmark(args)