# pytorch — MemoryTracker usage example
# (original paste header: 32 lines · 858.0 bytes)
import torch
import torchvision

from torch.distributed._tools import MemoryTracker


def run_one_model(net: torch.nn.Module, input: torch.Tensor) -> None:
    """Run one training iteration of ``net`` on ``input`` under MemoryTracker.

    Moves the model and the input to the GPU, performs a single
    forward/backward pass while the tracker is monitoring, then prints a
    memory-stats summary and plots the per-operator memory traces.

    Args:
        net: the model to profile; moved to CUDA in place.
        input: a batch of input data; moved to CUDA before the forward pass.
            (NOTE(review): the parameter shadows the ``input`` builtin — kept
            unchanged for backward compatibility with keyword callers.)

    Side effects:
        Requires a CUDA device. Prints the tracker summary and shows the
        trace plot; ``net``'s gradients are left cleared (set to None).
    """
    net.cuda()
    input = input.cuda()

    # Create the memory tracker and start monitoring *before* the
    # training iteration starts so all allocations are captured.
    mem_tracker = MemoryTracker()
    mem_tracker.start_monitor(net)

    # Run one training iteration.
    # Use the keyword form consistently (the original mixed positional
    # `zero_grad(True)` here with `zero_grad(set_to_none=True)` below).
    net.zero_grad(set_to_none=True)
    loss = net(input)
    if isinstance(loss, dict):
        # Some models return a dict of outputs; reduce over the "out" entry.
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # Stop monitoring after the training iteration ends, then report:
    mem_tracker.stop()
    # print the memory stats summary
    mem_tracker.summary()
    # plot the memory traces at operator level
    mem_tracker.show_traces()
if __name__ == "__main__":
    # Script entry point: profile one ResNet-34 training step on a random
    # batch (N=32, 3x224x224). Guarded so importing this module does not
    # immediately run the GPU workload. Requires a CUDA-capable device.
    run_one_model(
        torchvision.models.resnet34(),
        torch.rand(32, 3, 224, 224, device="cuda"),
    )