pytorch

Форк
0
/
test_unflatten.py 
75 строк · 2.1 Кб
1
# Copyright (c) Meta Platforms, Inc. and affiliates
2
# Owner(s): ["oncall: distributed"]
3
import torch
4
from torch.distributed.pipelining import pipe_split, pipeline
5
from torch.testing._internal.common_utils import run_tests, TestCase
6

7

8
# Building block for model
9
class Block(torch.nn.Module):
10
    def __init__(self) -> None:
11
        super().__init__()
12
        self.conv = torch.nn.Conv2d(
13
            in_channels=16, out_channels=16, kernel_size=3, padding=1
14
        )
15
        self.lin0 = torch.nn.Linear(256, 256)
16
        self.relu = torch.nn.ReLU()
17
        self.lin1 = torch.nn.Linear(256, 256)
18

19
    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
20
        x = self.conv(x)
21
        x = self.lin0(x)
22
        pipe_split()
23
        x.add_(constant)
24
        x = self.lin1(x)
25
        return self.relu(x)
26

27

28
# Full model
29
class M(torch.nn.Module):
30
    def __init__(self) -> None:
31
        super().__init__()
32
        self.block0 = Block()
33
        self.block1 = Block()
34

35
    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
36
        x = self.block0(x, constant=constant)
37
        pipe_split()
38
        x = self.block1(x, constant=constant)
39
        return x
40

41

42
class UnflattenTests(TestCase):
43
    def test_unflatten(self):
44
        x = torch.randn(1, 16, 256, 256)
45
        constant = torch.ones(1, 16, 256, 256)
46

47
        mod = M()
48

49
        pipe = pipeline(
50
            mod,
51
            (x,),
52
            {"constant": constant},
53
        )
54

55
        assert pipe.num_stages == 4
56
        orig_state_dict = mod.state_dict()
57

58
        # Check qualnames
59
        for stage_idx in range(pipe.num_stages):
60
            stage_mod = pipe.get_stage_module(stage_idx)
61
            for param_name, param in stage_mod.named_parameters():
62
                assert (
63
                    param_name in orig_state_dict
64
                ), f"{param_name} not in original state dict"
65
        print("Param qualname test passed")
66

67
        # Check equivalence
68
        ref = mod(x, constant)
69
        out = pipe(x, constant)[0]
70
        torch.testing.assert_close(out, ref)
71
        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
72

73

74
if __name__ == "__main__":
75
    run_tests()
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.