test_apply_optimizer_in_backward.py
# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import (
    _apply_optimizer_in_backward,
    _get_in_backward_optimizers,
)


# TODO (rohan-varma): Add FSDP & DDP tests once supported


def _validate_params(params_list, fn):
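    """Apply ``fn`` pairwise to the parameters of the first entry of
    ``params_list`` and the corresponding parameters of every other entry."""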
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):
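    """Tests for running optimizer steps inside the backward pass via
    _apply_optimizer_in_backward (the "overlapped" optimizer)."""
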
    def _run_training_loop_and_validate(self, inp, models, optimizers):
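        """Run several forward/backward/step iterations on every model and
        assert that all models have matching parameters after each iteration."""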
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

    def _test_apply_optimizer_in_backward(self, share_params) -> None:
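        """Check that SGD applied in backward via _apply_optimizer_in_backward
        (with different hyperparameters for weights and biases, and optionally
        shared weights) matches stepping regular SGD optimizers on an
        identical reference model."""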
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply a different in-backward optimizer for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs,
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs,
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

    def test_no_register_hook(self):
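        """With register_hook=False no hook is registered, so backward leaves
        the parameters unchanged; with the default register_hook=True the
        parameters are updated during backward."""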
        model_with_hook = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        initial_model = deepcopy(model_with_hook)
        model_no_hook = deepcopy(model_with_hook)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_no_hook.parameters(),
            optimizer_kwargs={"lr": 0.03},
            register_hook=False,
        )
        inp = torch.randn(4, 10)
        model_with_hook(inp).sum().backward()
        model_no_hook(inp).sum().backward()

        for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
            with self.assertRaises(AssertionError):
                torch.testing.assert_allclose(p1, p2)

        for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
            torch.testing.assert_allclose(p1, p2)

    def test_multiple_optim_for_params(self) -> None:
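        """Applying _apply_optimizer_in_backward twice attaches two in-backward
        optimizers to the same parameters; training should then match stepping
        two regular SGD optimizers on the reference model."""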
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )

    def test_get_optimizers_in_backward(self):
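        """_get_in_backward_optimizers should return exactly the optimizer
        instances attached to the module's parameters by
        _apply_optimizer_in_backward."""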
        # Create a simple test model
        class TestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 5)
                self.linear2 = torch.nn.Linear(5, 2)

        model = TestModel()

        # Apply optimizers in backward
        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
        in_backward_optims = _get_in_backward_optimizers(model)
        self.assertEqual(len(list(model.parameters())), len(in_backward_optims))
        result = set(in_backward_optims)
        expected = {
            optim for p in model.parameters() for optim in p._in_backward_optimizers
        }
        self.assertEqual(result, expected)