pytorch

Форк
0
29 строк · 776.0 Байт
1
import argparse
2

3
import torch
4
from common import SubTensor, SubWithTorchFunction, WithTorchFunction  # noqa: F401
5

6
Tensor = torch.tensor
7

8
NUM_REPEATS = 1000000
9

10
if __name__ == "__main__":
11
    parser = argparse.ArgumentParser(
12
        description="Run the torch.add for a given class a given number of times."
13
    )
14
    parser.add_argument(
15
        "tensor_class", metavar="TensorClass", type=str, help="The class to benchmark."
16
    )
17
    parser.add_argument(
18
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
19
    )
20
    args = parser.parse_args()
21

22
    TensorClass = globals()[args.tensor_class]
23
    NUM_REPEATS = args.nreps
24

25
    t1 = TensorClass([1.0])
26
    t2 = TensorClass([2.0])
27

28
    for _ in range(NUM_REPEATS):
29
        torch.add(t1, t2)
30

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.