# intel-extension-for-pytorch: test_fp8_autocast.py
import torch
import unittest

from intel_extension_for_pytorch.quantization.fp8 import (
    fp8_autocast,
    DelayedScaling,
    Format,
    prepare_fp8,
)
import intel_extension_for_pytorch._C as core

from torch.testing._internal.common_utils import TestCase
from torch.optim import SGD


class TestFP8Cases(TestCase):
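    # End-to-end check: train an FP8-prepared copy of a small
    # LayerNorm/Linear/GELU/Dropout model for 10 steps and compare outputs,
    # gradients, model state, and optimizer state against the FP32 reference.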
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_linear_base(self):
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = torch.nn.LayerNorm(5, eps=1e-05)
                self.lin1 = torch.nn.Linear(5, 4, bias=False)
                self.lin2 = torch.nn.Linear(4, 3, bias=True)
                self.dropout = torch.nn.Dropout()

            def forward(self, x):
                x = self.ln(x)
                x = self.lin1(x)
                x = torch.nn.functional.gelu(x, approximate="tanh")
                x = self.lin2(x)
                z = self.dropout(x)
                return z

        torch.manual_seed(2024)

        my_linear = MyModel()
        my_linear.train()
        inp = torch.randn((10, 7, 3, 5), dtype=torch.float32)
        inp1 = inp.clone().requires_grad_(True)
        inp2 = inp.clone().requires_grad_(True)

        origin_optimizer = SGD(my_linear.parameters(), lr=0.01, momentum=0.9)
        fp8_linear, ipex_optimizer = prepare_fp8(my_linear, origin_optimizer)

        # Train the FP8-prepared model; reseed each step so the dropout masks
        # match the FP32 reference loop below.
        with fp8_autocast(
            enabled=True,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            for i in range(10):
                torch.manual_seed(2024)
                out = fp8_linear(inp2[i])
                ipex_optimizer.zero_grad()
                out.mean().backward()
                ipex_optimizer.step()

        # FP32 reference training loop with identical seeding.
        for i in range(10):
            torch.manual_seed(2024)
            out_nn = my_linear(inp1[i])
            origin_optimizer.zero_grad()
            out_nn.mean().backward()
            origin_optimizer.step()

        self.assertEqual(out, out_nn, atol=0.05, rtol=0.1)
        self.assertEqual(inp1[-1].grad, inp2[-1].grad, atol=0.01, rtol=0.1)

        # Parameters, per-layer gradients, and optimizer state should track
        # the FP32 reference.
        origin_model_state = my_linear.state_dict()
        ipex_model_state = fp8_linear.state_dict()
        for var_name in origin_model_state:
            self.assertEqual(
                origin_model_state[var_name],
                ipex_model_state[var_name],
                atol=0.01,
                rtol=0.1,
            )
        for name, _ in fp8_linear.named_children():
            if hasattr(getattr(my_linear, name), "weight"):
                if getattr(my_linear, name).weight is not None:
                    self.assertEqual(
                        getattr(my_linear, name).weight.grad,
                        getattr(fp8_linear, name).weight.grad,
                        atol=0.01,
                        rtol=0.1,
                    )
            if hasattr(getattr(my_linear, name), "bias"):
                if getattr(my_linear, name).bias is not None:
                    self.assertEqual(
                        getattr(my_linear, name).bias.grad,
                        getattr(fp8_linear, name).bias.grad,
                        atol=0.01,
                        rtol=0.1,
                    )

        origin_optimizer_state = origin_optimizer.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in origin_optimizer_state:
            if var_name == "state":
                print(origin_optimizer_state[var_name])
                print(ipex_optimizer_state[var_name])
                self.assertEqual(
                    origin_optimizer_state[var_name],
                    ipex_optimizer_state[var_name],
                    atol=0.01,
                    rtol=0.1,
                )

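    # Calibration flow: run the prepared model with fp8_autocast(enabled=False,
    # calibrating=True) to record amax statistics, reload that state dict into
    # a fresh prepared model, and check FP8 inference against the FP32 result.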
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_linear_calibration(self):
        class ClassA(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = torch.nn.LayerNorm(5, eps=1e-05)

            def forward(self, x):
                z = self.ln(x)
                return z

        class ClassC(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin2 = torch.nn.Linear(4, 3, bias=True)
                self.dropout = torch.nn.Dropout()

            def forward(self, x):
                x = self.lin2(x)
                z = self.dropout(x)
                return z

        class ClassB(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = torch.nn.Linear(5, 4, bias=False)
                self.lin2_dropout = ClassC()

            def forward(self, x):
                x = self.lin1(x)
                x = torch.nn.functional.gelu(x, approximate="tanh")
                z = self.lin2_dropout(x)
                return z

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = ClassA()
                self.lin1_gelu = ClassB()

            def forward(self, x):
                x = self.ln(x)
                z = self.lin1_gelu(x)
                return z

        # FP32 reference model
        my_linear = MyModel()
        my_linear.train()
        inp = torch.randn((5, 7, 3, 5), dtype=torch.float32)
        inp1 = inp.clone().requires_grad_(True)
        inp2 = inp.clone().requires_grad_(False)

        origin_optimizer = SGD(my_linear.parameters(), lr=0.01, momentum=0.9)

        for i in range(4):
            out_nn = my_linear(inp1[i])
            origin_optimizer.zero_grad()
            out_nn.mean().backward()
            origin_optimizer.step()

        torch.save(my_linear.state_dict(), "my_linear_inference.pt")
        my_linear_inference = MyModel()
        my_linear_inference.load_state_dict(torch.load("my_linear_inference.pt"))
        my_linear_inference.eval()
        out_nn_iter5 = my_linear_inference(inp1[4])

        fp8_linear_inference = prepare_fp8(my_linear_inference)
        # Do calibration to store amax of input and weight
        for i in range(4):
            with fp8_autocast(
                enabled=False,
                calibrating=True,
                fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
                device="cpu",
            ):
                _ = fp8_linear_inference(inp2[i])
        torch.save(fp8_linear_inference.state_dict(), "fp8_linear_inference.pt")

        # FP8 model with calibration
        fp8_linear_with_calibration = MyModel()
        fp8_linear_with_calibration = prepare_fp8(fp8_linear_with_calibration)
        fp8_linear_with_calibration.load_state_dict(
            torch.load("fp8_linear_inference.pt")
        )
        fp8_linear_with_calibration.eval()

        # Run model inference using calibration data
        with fp8_autocast(
            enabled=True,
            calibrating=False,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            out_fp8_iter5 = fp8_linear_with_calibration(inp2[4])
        self.assertEqual(out_fp8_iter5, out_nn_iter5, atol=0.01, rtol=0.1)

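    # Regression check: prepare_fp8 should handle a Linear whose weight has
    # been replaced by a non-contiguous (transposed) Parameter.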
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_non_contiguous_weight(self):
        nn_linear = torch.nn.Linear(2, 2)
        nn_linear.weight = torch.nn.Parameter(nn_linear.weight.transpose(0, 1))
        inp = torch.ones(3, 2)
        fp8_linear = prepare_fp8(nn_linear)
        with fp8_autocast(
            enabled=True,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            fp8_out = fp8_linear(inp)
        nn_out = nn_linear(inp)
        self.assertEqual(nn_out, fp8_out, atol=0.01, rtol=0.1)


if __name__ == "__main__":
    test = unittest.main()