pytorch

Форк
0
/
test_fsdp_uneven.py 
68 строк · 2.3 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import sys
4

5
import torch
6
from torch import distributed as dist
7
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8
from torch.nn import Linear
9
from torch.optim import SGD
10
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
11
from torch.testing._internal.common_fsdp import FSDPTest
12
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
13

14
if not dist.is_available():
15
    print("Distributed not available, skipping tests", file=sys.stderr)
16
    sys.exit(0)
17

18
if TEST_WITH_DEV_DBG_ASAN:
19
    print(
20
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
21
        file=sys.stderr,
22
    )
23
    sys.exit(0)
24

25

26
class TestUnevenParamShard(FSDPTest):
27
    def _get_ref_results(self, model, input, my_lr):
28
        with torch.no_grad():
29
            # Compute one iteration local output.
30
            weight = model.weight.T.clone().to(self.rank)
31
            v = torch.Tensor(input[self.rank]).to(self.rank)
32
            ref_forward_output_my_rank = torch.matmul(v, weight)
33
            # Compute one iteration global weight update.
34
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
35
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
36
            ref_weight_out = weight - grad.T * my_lr
37

38
        return ref_forward_output_my_rank, ref_weight_out
39

40
    @skip_if_lt_x_gpu(2)
41
    def test_one_iteration(self):
42
        """Test FSDP with uneven divide of parameter shards."""
43
        model = Linear(3, 3, bias=False)
44
        input = torch.rand(8, 3)
45
        my_lr = 0.1
46

47
        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
48
            model, input, my_lr
49
        )
50

51
        model.to(self.rank)
52
        model = FSDP(model)
53
        optim = SGD(model.parameters(), lr=my_lr)
54
        self.assertTrue(len(input) >= self.world_size)
55
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
56
        out = model(in_data)
57
        out.float().sum().backward()
58
        optim.step()
59
        optim.zero_grad()
60

61
        with model.summon_full_params(model):
62
            weight_out = model.module.weight.T.clone()
63
            self.assertEqual(ref_forward_output_my_rank, out)
64
            self.assertEqual(ref_weight_out, weight_out)
65

66

67
if __name__ == "__main__":
68
    run_tests()
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.