colossalai

Форк
0
46 строк · 1.7 Кб
1
import torch
2

3
from colossalai.nn.optimizer import CPUAdam, HybridAdam
4
from colossalai.testing import clear_cache_before_run, parameterize
5
from tests.kit.model_zoo import model_zoo
6

7

8
def move_some_params_to_cuda(model, torch_model):
9
    model.embed.weight.data = model.embed.weight.cuda()
10
    torch_model.embed.weight.data = model.embed.weight.cuda()
11
    model.ln1.weight.data = model.ln1.weight.cuda()
12
    torch_model.ln1.weight.data = model.ln1.weight.cuda()
13

14

15
def check_params_equal(model, torch_model):
16
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
17
        assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}"
18

19

20
@clear_cache_before_run()
21
@parameterize("nvme_offload_fraction", [0.0, 0.5, 1.0])
22
@parameterize("nvme_offload_dir", ["./offload", None])
23
@parameterize("adam_cls", [CPUAdam, HybridAdam])
24
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
25
    model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values()))
26
    model = model_builder()
27
    torch_model = model_builder()
28
    move_some_params_to_cuda(model, torch_model)
29
    optimizer = adam_cls(
30
        model.parameters(), lr=0.1, nvme_offload_fraction=nvme_offload_fraction, nvme_offload_dir=nvme_offload_dir
31
    )
32
    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1)
33
    with torch.no_grad():
34
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
35
            torch_p.copy_(p)
36
            p.grad = torch.rand_like(p)
37
            torch_p.grad = p.grad
38

39
        for _ in range(3):
40
            optimizer.step()
41
            torch_optimizer.step()
42
            check_params_equal(model, torch_model)
43

44

45
if __name__ == "__main__":
46
    test_nvme_adam(0.5, "./offload", CPUAdam)
47

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.