pytorch
29 строк · 776.0 Байт
"""Repeatedly call ``torch.add`` on instances of a chosen tensor class.

The class is picked by name on the command line from ``TENSOR_CLASSES``.
"""

import argparse

import torch
from common import SubTensor, SubWithTorchFunction, WithTorchFunction

# The plain-tensor case uses the ``torch.tensor`` factory, exposed under a
# class-like name so it can be selected alongside the other classes.
Tensor = torch.tensor

NUM_REPEATS = 1000000

# Explicit registry of selectable classes. Looking names up here (instead of
# in ``globals()``) lets argparse reject unknown names with a clear message.
TENSOR_CLASSES = {
    "Tensor": Tensor,
    "SubTensor": SubTensor,
    "WithTorchFunction": WithTorchFunction,
    "SubWithTorchFunction": SubWithTorchFunction,
}


def _parse_args(argv=None):
    """Parse command-line arguments; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class",
        metavar="TensorClass",
        type=str,
        choices=sorted(TENSOR_CLASSES),
        help="The class to benchmark.",
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    return parser.parse_args(argv)


def run(tensor_class, nreps=NUM_REPEATS):
    """Call ``torch.add`` *nreps* times on two one-element *tensor_class* values.

    Args:
        tensor_class: A callable accepting a list of floats, e.g. a value
            from ``TENSOR_CLASSES``.
        nreps: Number of ``torch.add`` calls; values <= 0 run no iterations.
    """
    t1 = tensor_class([1.0])
    t2 = tensor_class([2.0])
    # Loop body is kept to the single call under test so timings/profiles
    # are dominated by ``torch.add`` itself.
    for _ in range(nreps):
        torch.add(t1, t2)


def main(argv=None):
    """Command-line entry point: parse *argv* and run the benchmark."""
    args = _parse_args(argv)
    run(TENSOR_CLASSES[args.tensor_class], args.nreps)


if __name__ == "__main__":
    main()