pytorch

Форк
0
/
test_memory_tracker.py 
68 строк · 2.4 Кб
1
# Owner(s): ["oncall: distributed"]
2
import os
3
import unittest
4

5
import torch
6
import torch.nn as nn
7
from torch.distributed._tools import MemoryTracker
8
from torch.testing._internal.common_cuda import TEST_CUDA
9
from torch.testing._internal.common_utils import run_tests, TestCase
10

11

12
class TestMemoryTracker(TestCase):
13
    @unittest.skipIf(not TEST_CUDA, "no cuda")
14
    def test_local_model(self):
15
        """
16
        Minimal test case to check the memory tracker can collect the expected
17
        memory stats at operator level, as well as can print the summary result
18
        without crash.
19
        """
20
        # Create a model with a hierarchy of modules
21
        torch.manual_seed(0)
22
        model = nn.Sequential(
23
            nn.Sequential(
24
                nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
25
                nn.BatchNorm2d(64),
26
                nn.ReLU(inplace=False),
27
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
28
            ),
29
            nn.Flatten(start_dim=1),
30
            nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)),
31
        ).cuda()
32

33
        # Run one iteration of forward and backward pass
34
        tracker = MemoryTracker()
35
        tracker.start_monitor(model)
36

37
        x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda"))
38
        # torch.LongTensor expects cpu device type, not cuda device type in
39
        # constructor, so calling .cuda() outside constructor here.
40
        target = torch.LongTensor([0, 1]).cuda()
41
        criterion = nn.CrossEntropyLoss()
42
        criterion(model(x), target).backward()
43

44
        self.assertTrue(len(tracker._hooks) > 0)
45

46
        tracker.stop()
47

48
        self.assertTrue(len(tracker._hooks) == 0)
49

50
        path = "memory.trace"
51
        tracker.save_stats(path)
52
        tracker.load(path)
53
        tracker.summary()
54
        if os.path.exists(path):
55
            os.remove(path)
56

57
        self.assertTrue(tracker._op_index > 0)
58
        self.assertTrue(len(tracker._operator_names) > 0)
59
        self.assertEqual(len(tracker.memories_allocated), tracker._op_index)
60
        self.assertEqual(len(tracker.memories_active), tracker._op_index)
61
        self.assertEqual(len(tracker.memories_reserved), tracker._op_index)
62
        self.assertTrue(len(tracker._markers) == 2)
63
        self.assertTrue(tracker._cur_module_name != "")
64
        self.assertTrue(hasattr(tracker, "_num_cuda_retries"))
65

66

67
if __name__ == "__main__":
68
    run_tests()
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.