pytorch

Форк
0
/
test_fsdp_input.py 
76 строк · 2.1 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import sys
4

5
import torch
6
from torch import distributed as dist
7
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8
from torch.nn import Linear, Module
9
from torch.optim import SGD
10
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
11
from torch.testing._internal.common_fsdp import FSDPTest
12
from torch.testing._internal.common_utils import (
13
    instantiate_parametrized_tests,
14
    parametrize,
15
    run_tests,
16
    subtest,
17
    TEST_WITH_DEV_DBG_ASAN,
18
)
19

20
if not dist.is_available():
21
    print("Distributed not available, skipping tests", file=sys.stderr)
22
    sys.exit(0)
23

24
if TEST_WITH_DEV_DBG_ASAN:
25
    print(
26
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
27
        file=sys.stderr,
28
    )
29
    sys.exit(0)
30

31

32
class TestInput(FSDPTest):
33
    @property
34
    def world_size(self):
35
        return 1
36

37
    @skip_if_lt_x_gpu(1)
38
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
39
    def test_input_type(self, input_cls):
40
        """Test FSDP with input being a list or a dict, only single GPU."""
41

42
        class Model(Module):
43
            def __init__(self):
44
                super().__init__()
45
                self.layer = Linear(4, 4)
46

47
            def forward(self, input):
48
                if isinstance(input, list):
49
                    input = input[0]
50
                else:
51
                    assert isinstance(input, dict), input
52
                    input = input["in"]
53
                return self.layer(input)
54

55
        model = FSDP(Model()).cuda()
56
        optim = SGD(model.parameters(), lr=0.1)
57

58
        for _ in range(5):
59
            in_data = torch.rand(64, 4).cuda()
60
            in_data.requires_grad = True
61
            if input_cls is list:
62
                in_data = [in_data]
63
            else:
64
                self.assertTrue(input_cls is dict)
65
                in_data = {"in": in_data}
66

67
            out = model(in_data)
68
            out.sum().backward()
69
            optim.step()
70
            optim.zero_grad()
71

72

73
instantiate_parametrized_tests(TestInput)
74

75
if __name__ == "__main__":
76
    run_tests()
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.