# Scrape metadata from the original listing, kept as a comment:
# pytorch — 68 lines · 2.4 KB
# Owner(s): ["oncall: distributed"]
import os
import unittest

import torch
import torch.nn as nn
from torch.distributed._tools import MemoryTracker
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase


class TestMemoryTracker(TestCase):
    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_local_model(self):
        """
        Minimal test case to check the memory tracker can collect the expected
        memory stats at operator level, as well as can print the summary result
        without crash.
        """
        # Create a model with a hierarchy of modules so the tracker sees
        # nested-module boundaries, not just a flat list of ops.
        torch.manual_seed(0)
        model = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=False),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ),
            nn.Flatten(start_dim=1),
            nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
        ).cuda()

        # Run one iteration of forward and backward pass
        tracker = MemoryTracker()
        tracker.start_monitor(model)

        x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda"))
        # torch.LongTensor expects cpu device type, not cuda device type in
        # constructor, so calling .cuda() outside constructor here.
        target = torch.LongTensor([0, 1]).cuda()
        criterion = nn.CrossEntropyLoss()
        criterion(model(x), target).backward()

        # Hooks must be installed while monitoring is active ...
        self.assertTrue(len(tracker._hooks) > 0)

        tracker.stop()

        # ... and fully removed once monitoring stops.
        self.assertTrue(len(tracker._hooks) == 0)

        # Round-trip the collected stats through disk and render the summary.
        # try/finally guarantees the temp file is removed even if load() or
        # summary() raises, so repeated test runs don't leave stale files.
        path = "memory.trace"
        tracker.save_stats(path)
        try:
            tracker.load(path)
            tracker.summary()
        finally:
            if os.path.exists(path):
                os.remove(path)

        # Sanity-check the tracker's collected state: one entry per traced op,
        # both markers (start/stop) recorded, and retry counter present.
        self.assertTrue(tracker._op_index > 0)
        self.assertTrue(len(tracker._operator_names) > 0)
        self.assertEqual(len(tracker.memories_allocated), tracker._op_index)
        self.assertEqual(len(tracker.memories_active), tracker._op_index)
        self.assertEqual(len(tracker.memories_reserved), tracker._op_index)
        self.assertTrue(len(tracker._markers) == 2)
        self.assertTrue(tracker._cur_module_name != "")
        self.assertTrue(hasattr(tracker, "_num_cuda_retries"))
65
66
# Standard PyTorch test entry point: dispatch to the shared test runner.
if __name__ == "__main__":
    run_tests()