# pytorch
# 45 lines · 1.5 KB
import argparse
import numpy as np

from caffe2.python import core, workspace


def benchmark_mul_gradient(args):
    """Benchmark the Caffe2 ``MulGradient`` operator with a broadcast ``B``.

    Feeds random float32 blobs ``dC`` and ``A`` of shape ``(args.m, args.n)``
    and ``B`` of shape ``(args.n,)`` into the current workspace, builds a
    single-op net named ``"mynet"`` and times it with
    ``workspace.BenchmarkNet``.

    Args:
        args: parsed command-line namespace. Reads ``m`` (rows of A),
            ``n`` (columns of A), ``iteration``, ``inplace`` and
            ``allow_broadcast_fastpath``.
    """
    full_shape = (args.m, args.n)
    # Blobs are fed in the order dC, A, B so the RNG draws are reproducible
    # relative to any seeding done by the caller.
    workspace.FeedBlob("dC", np.random.rand(*full_shape).astype(np.float32))
    workspace.FeedBlob("A", np.random.rand(*full_shape).astype(np.float32))
    # B has one entry per column of A; presumably axis=1 below aligns it with
    # A's second dimension (Caffe2 legacy broadcast) — confirm.
    workspace.FeedBlob("B", np.random.rand(args.n).astype(np.float32))

    net = core.Net("mynet")
    # With --inplace the first output (dA) is written back into the "dC" blob.
    # NOTE(review): every benchmark iteration after the first then consumes the
    # previous iteration's output as its dC. If MulGradient computes
    # dA = dC * B with B in [0, 1), values drift toward zero/denormals over
    # many iterations, which could skew timings — confirm whether that matters.
    net.MulGradient(
        ["dC", "A", "B"],
        ["dC" if args.inplace else "dA", "dB"],
        broadcast=True,
        axis=1,
        allow_broadcast_fastpath=args.allow_broadcast_fastpath,
    )
    workspace.CreateNet(net)

    # Positional args are presumably (net name, warm-up runs, main runs,
    # run-individual-ops flag) — confirm against workspace.BenchmarkNet.
    workspace.BenchmarkNet(net.Name(), 1, args.iteration, True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="benchmark for MulGradient.")
    parser.add_argument(
        "-m", type=int, default=9508,
        help="The number of rows of A")
    parser.add_argument(
        "-n", type=int, default=80,
        help="The number of columns of A")
    parser.add_argument(
        "-i", "--iteration", type=int, default=100,
        help="The number of iterations.")
    parser.add_argument(
        "--inplace",
        action="store_true", help="Whether to perform the op in-place.")
    parser.add_argument(
        "--allow-broadcast-fastpath",
        action="store_true", help="Whether the broadcast fastpath is enabled.")
    # Unrecognized flags are not an error: they are forwarded to
    # core.GlobalInit (presumably Caffe2's own --caffe2_* flags).
    args, extra_args = parser.parse_known_args()

    # Reject non-positive sizes / iteration counts with a usage error up front,
    # instead of failing later inside NumPy or the benchmark runner.
    for flag, value in (("-m", args.m),
                        ("-n", args.n),
                        ("--iteration", args.iteration)):
        if value < 1:
            parser.error(
                "{} must be a positive integer, got {}".format(flag, value))

    core.GlobalInit(["python"] + extra_args)
    benchmark_mul_gradient(args)