pytorch

Форк
0
/
memory_tracker_example.py 
32 строки · 858.0 Байт
1
import torch
2
import torchvision
3

4
from torch.distributed._tools import MemoryTracker
5

6

7
def run_one_model(net: torch.nn.Module, input: torch.Tensor):
8
    net.cuda()
9
    input = input.cuda()
10

11
    # Create the memory Tracker
12
    mem_tracker = MemoryTracker()
13
    # start_monitor before the training iteration starts
14
    mem_tracker.start_monitor(net)
15

16
    # run one training iteration
17
    net.zero_grad(True)
18
    loss = net(input)
19
    if isinstance(loss, dict):
20
        loss = loss["out"]
21
    loss.sum().backward()
22
    net.zero_grad(set_to_none=True)
23

24
    # stop monitoring after the training iteration ends
25
    mem_tracker.stop()
26
    # print the memory stats summary
27
    mem_tracker.summary()
28
    # plot the memory traces at operator level
29
    mem_tracker.show_traces()
30

31

32
run_one_model(torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda"))
33

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.