colossalai

Форк
0
/
test_runtime_mem_tracer.py 
54 строки · 1.8 Кб
1
from copy import deepcopy
2

3
import numpy as np
4
import pytest
5
import torch
6

7
from colossalai.testing import DummyDataloader, clear_cache_before_run
8
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
9
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
10

11

12
@pytest.mark.skip("this is not used")
13
@clear_cache_before_run()
14
def test_runtime_mem_tracer():
15
    test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"]
16

17
    for model_name in test_models:
18
        model_builder, data_gen_fn, output_transform_fn, *_ = next(
19
            iter(model_zoo.get_sub_registry(model_name).values())
20
        )
21

22
        model = model_builder().cuda()
23

24
        model_bk = deepcopy(model)
25
        runtime_mem_tracer = RuntimeMemTracer(model)
26

27
        train_dataloader = DummyDataloader(data_gen_fn)
28
        for i, data in enumerate(train_dataloader):
29
            if i > 1:
30
                break
31
            data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
32

33
            run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
34

35
        for p1, p2 in zip(model_bk.parameters(), model.parameters()):
36
            torch.allclose(p1.to(torch.half), p2)
37

38
        non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list("cuda")
39
        cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2
40
        print("cuda_non_model_data_list", len(cuda_non_model_data_list))
41
        print(non_model_data_list)
42

43
        cnt1 = 0
44
        for p in runtime_mem_tracer.parameters_in_runtime_order():
45
            cnt1 += 1
46
        cnt2 = 0
47
        for p in model.parameters():
48
            cnt2 += 1
49
        assert cnt2 == cnt1, f"visited param number {cnt1} vs real param number {cnt2}"
50
        del model
51

52

53
if __name__ == "__main__":
54
    test_runtime_mem_tracer()
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.