# test_ipex_optimize.py — intel-extension-for-pytorch test suite
# (web-viewer chrome removed; original file: 917 lines, 35.6 KB)
1
import torch
2
import torch.fx.experimental.optimization as optimization
3
import intel_extension_for_pytorch as ipex
4
import intel_extension_for_pytorch._C as core
5
from intel_extension_for_pytorch.nn.utils._weight_prepack import (
6
    _IPEXLinear as _IPEXLinear,
7
    _IPEXConv2d as _IPEXConv2d,
8
)
9
from torch.testing._internal.common_utils import TestCase
10
from torch.optim import (
11
    Adadelta,
12
    Adagrad,
13
    Adam,
14
    AdamW,
15
    Adamax,
16
    ASGD,
17
    RMSprop,
18
    Rprop,
19
    SGD,
20
)
21
import unittest
22
import itertools
23
import copy
24
from common_utils import TestModule, _empty_weight_bias_parameter_names
25
from intel_extension_for_pytorch.optim._lamb import Lamb
26
import os
27

28
try:
29
    import transformers
30

31
    HAS_TRANSFORMERS = True
32
except ImportError:
33
    HAS_TRANSFORMERS = False
34
skipIfNoTransformers = unittest.skipIf(not HAS_TRANSFORMERS, "no transformers")
35

36
curpath = os.path.abspath(os.path.dirname(__file__))
37

38

39
class ConvBatchNorm(torch.nn.Module):
    """A 7x7/stride-2 Conv2d followed by BatchNorm2d, bundling a sample input."""

    def __init__(self):
        super(ConvBatchNorm, self).__init__()
        # Sample input matching the conv's 3-channel, 224x224 expectation.
        self.input1 = torch.randn(1, 3, 224, 224)
        self.conv = torch.nn.Conv2d(
            3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
        )
        self.bn = torch.nn.BatchNorm2d(
            64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        )

    def forward(self, x):
        convolved = self.conv(x)
        return self.bn(convolved)
54

55

56
class TwoLayerMLP(torch.nn.Module):
    """Two independent Linear layers whose outputs are summed into one scalar."""

    def __init__(self):
        super(TwoLayerMLP, self).__init__()
        # One sample input per branch, shaped for the matching linear.
        self.input1 = torch.randn(2, 2)
        self.input2 = torch.randn(3, 3)
        self.l1 = torch.nn.Linear(2, 2)
        self.l2 = torch.nn.Linear(3, 3)

    def forward(self, x1, x2):
        total = self.l1(x1).sum()
        total = total + self.l2(x2).sum()
        return total
66

67

68
class OneLayerMLP(torch.nn.Module):
    """A single 2x2 Linear layer with a bundled sample input."""

    def __init__(self):
        super(OneLayerMLP, self).__init__()
        self.input1 = torch.randn(2, 2)
        self.l1 = torch.nn.Linear(2, 2)

    def forward(self, x1):
        out = self.l1(x1)
        return out
76

77

78
class ConvTranspose2d(torch.nn.Module):
    """A single 3x3 ConvTranspose2d layer plus a matching sample input."""

    def __init__(self):
        super(ConvTranspose2d, self).__init__()
        self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))
        self.input1 = torch.randn(5, 5, 3, 3)

    def forward(self, x):
        return self.conv_transpose2d(x)
89

90

91
class LinearBatchNormNd(torch.nn.Module):
    """Linear(32, 32) followed by a BatchNorm of rank ``dim`` (1, 2, or 3),
    with a rank-appropriate sample input.  For any other ``dim`` value,
    neither ``input1`` nor ``bn`` is created (same as the original chain)."""

    # Maps dim -> (batch-norm class, sample-input shape).
    _BN_BY_DIM = {
        1: (torch.nn.BatchNorm1d, (1, 32)),
        2: (torch.nn.BatchNorm2d, (1, 32, 32, 32)),
        3: (torch.nn.BatchNorm3d, (1, 32, 32, 32, 32)),
    }

    def __init__(self, dim):
        super(LinearBatchNormNd, self).__init__()
        self.linear = torch.nn.Linear(32, 32)
        if dim in self._BN_BY_DIM:
            bn_cls, shape = self._BN_BY_DIM[dim]
            self.input1 = torch.randn(*shape)
            self.bn = bn_cls(32)

    def forward(self, x):
        return self.bn(self.linear(x))
107

108

109
class ConvBatchNormLinearBatchNorm(torch.nn.Module):
    """Conv -> BN -> Linear -> BN chain over 32-channel square inputs,
    giving one foldable Conv+BN pair and one foldable Linear+BN pair."""

    def __init__(self):
        super(ConvBatchNormLinearBatchNorm, self).__init__()
        self.input1 = torch.randn(1, 32, 32, 32)
        self.conv = torch.nn.Conv2d(32, 32, 1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.linear = torch.nn.Linear(32, 32)
        self.bn2 = torch.nn.BatchNorm2d(32)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn1(out)
        out = self.linear(out)
        return self.bn2(out)
122

123

124
class TestOptimizeCases(TestCase):
125
    def test_optimize_conv_bn_parameters_behavior(self):
        """Conv+BN folding contract: the traced graph contains an
        ``ipex::batch_norm`` node iff conv_bn_folding was disabled, at both
        O0 and O1 levels.

        Fix: the texpr-fuser flag is now restored in a ``finally`` so a failing
        assertion no longer leaks the disabled-fuser state into later tests.
        """
        model = ConvBatchNorm().eval()
        # Disable the TE fuser so the jit graph shape is deterministic.
        pre_te_enable_status = torch._C._jit_texpr_fuser_enabled()
        torch._C._jit_set_texpr_fuser_enabled(False)
        try:
            for level, conv_bn_folding in itertools.product(
                ["O0", "O1"], [True, False]
            ):
                opt_M = ipex.optimize(
                    model,
                    level=level,
                    dtype=torch.float,
                    conv_bn_folding=conv_bn_folding,
                )
                with torch.no_grad():
                    x = model.input1
                    traced_model = torch.jit.trace(opt_M, x)
                    trace_graph = traced_model.graph_for(x)
                # A surviving batch_norm node means folding did not happen.
                self.assertEqual(
                    any(n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()),
                    not conv_bn_folding,
                )
            # TODO check weight_prepack.
        finally:
            torch._C._jit_set_texpr_fuser_enabled(pre_te_enable_status)
147

148
    def test_optimize_linear_bn_parameters_behavior(self):
        """Linear+BN folding across 1d/2d/3d batch norms: an
        ``ipex::batch_norm`` node survives in the traced graph iff folding
        was disabled."""
        for dim in [1, 2, 3]:
            model = LinearBatchNormNd(dim=dim).eval()
            for level, linear_bn_folding in itertools.product(
                ["O0", "O1"], [True, False]
            ):
                opt_M = ipex.optimize(
                    model,
                    level=level,
                    dtype=torch.float,
                    linear_bn_folding=linear_bn_folding,
                )
                with torch.no_grad():
                    sample = model.input1
                    trace_graph = torch.jit.trace(opt_M, sample).graph_for(sample)
                has_bn_node = any(
                    n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()
                )
                # Folding enabled -> no batch_norm node, and vice versa.
                self.assertEqual(has_bn_node, not linear_bn_folding)
169

170
    def test_optimize_conv_bn_linear_bn_parameters_behavior(self):
        """With one Conv+BN pair and one Linear+BN pair, the number of
        surviving ``ipex::batch_norm`` nodes equals the number of disabled
        foldings."""
        model = ConvBatchNormLinearBatchNorm().eval()
        max_num_folding = 2
        for level, conv_bn_folding, linear_bn_folding in itertools.product(
            ["O0", "O1"], [True, False], [True, False]
        ):
            opt_M = ipex.optimize(
                model,
                level=level,
                dtype=torch.float,
                conv_bn_folding=conv_bn_folding,
                linear_bn_folding=linear_bn_folding,
            )
            with torch.no_grad():
                sample = model.input1
                trace_graph = torch.jit.trace(opt_M, sample).graph_for(sample)
            surviving_bn_nodes = sum(
                1 for n in trace_graph.nodes() if n.kind() == "ipex::batch_norm"
            )
            # Each enabled folding removes exactly one batch_norm node.
            self.assertEqual(
                surviving_bn_nodes,
                max_num_folding - (conv_bn_folding + linear_bn_folding),
            )
197

198
    def test_optimize_bf16_model(self):
        """bf16 conversion: inference models keep no master-weight copy, while
        training models (with split_master_weight_for_bf16=False) keep an fp32
        master parameter tracked in the optimizer's params_attr."""
        model = ConvBatchNorm()
        inference_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)
        # Model should not have a master weight attr for the inference model.
        self.assertFalse(hasattr(inference_model.conv, "master_weight"))

        # Model should have a master weight attr for the training model.
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        train_model, train_sgd = ipex.optimize(
            model.train(),
            optimizer=sgd,
            dtype=torch.bfloat16,
            split_master_weight_for_bf16=False,
        )
        self.assertEqual(train_model.conv.weight.dtype, torch.bfloat16)

        def locate_wrapper(parameter, params_attr):
            # Return the parameter wrapper tracking `parameter`, if any.
            for wrapper in params_attr.values():
                if parameter is wrapper.parameter:
                    return wrapper
            return None

        wrapper = locate_wrapper(train_model.conv.weight, train_sgd.params_attr)
        self.assertIsNotNone(wrapper)
        # The master copy kept for the bf16 parameter must stay fp32.
        self.assertEqual(wrapper.master_parameter.dtype, torch.float)
222

223
    @skipIfNoTransformers
    def test_optimize_bf16_AlbertMLMHead(self):
        """Shared-parameter bookkeeping: the word-embedding weight (shared with
        AlbertMLMHead) must be reported as bf16-castable for inference and
        actually cast when requested."""
        from transformers.models import albert
        from intel_extension_for_pytorch.nn.utils import _parameter_wrapper

        config = transformers.AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/albert-base-v1"
        )
        model = albert.modeling_albert.AlbertForMaskedLM(config)
        params_attr = {}
        _parameter_wrapper.get_shared_parameter_status(model, params_attr)
        target = "albert.embeddings.word_embeddings.weight"
        for name, param in model.named_parameters():
            if name != target:
                continue
            wrapper = params_attr[param]
            # The MLM head must be registered as a sharer of this parameter.
            self.assertTrue(
                albert.modeling_albert.AlbertMLMHead in wrapper.modules_cls
            )
            self.assertEqual(param.dtype, torch.float32)
            self.assertTrue(wrapper.can_cast_inference(torch.bfloat16))
            wrapper.cast_for_inference(torch.bfloat16)
            self.assertEqual(param.dtype, torch.bfloat16)
            break
245

246
    def test_optimize_pretrain_model(self):
        """One fine-tuning step on an ipex-optimized ConvBatchNorm must match
        the unoptimized reference within tolerance, across a range of
        optimizers and both fp32/bf16 autocast, while ``requires_grad`` of the
        frozen conv parameters is preserved by ``ipex.optimize``.
        """
        optimizer_options = [
            Lamb,
            Adadelta,
            Adagrad,
            Adam,
            AdamW,
            Adamax,
            ASGD,
            # RMSprop, # TODO: accuracy fails on SPR starting from oneDNN commit 0f354d
            Rprop,
            SGD,
        ]

        options = itertools.product([torch.float, torch.bfloat16], optimizer_options)
        for dtype, optimizer in options:
            model = ConvBatchNorm().to(memory_format=torch.channels_last).train()
            # Freeze the conv branch to emulate partial fine-tuning.
            model.conv.weight.requires_grad_(False)
            model.conv.bias.requires_grad_(False)
            origin_model = copy.deepcopy(model)
            # SGD gets a smaller learning rate than the other optimizers.
            lr = 1e-4 if optimizer is SGD else 1e-2
            origin_optimizer = optimizer(origin_model.parameters(), lr=lr)
            ipex_model, ipex_optimizer = ipex.optimize(
                origin_model, optimizer=origin_optimizer, dtype=dtype
            )
            # requires_grad flags must survive ipex.optimize unchanged.
            self.assertEqual(
                origin_model.conv.weight.requires_grad,
                ipex_model.conv.weight.requires_grad,
            )
            self.assertEqual(
                origin_model.conv.bias.requires_grad, ipex_model.conv.bias.requires_grad
            )
            self.assertEqual(
                origin_model.bn.weight.requires_grad, ipex_model.bn.weight.requires_grad
            )
            self.assertEqual(
                origin_model.bn.bias.requires_grad, ipex_model.bn.bias.requires_grad
            )

            # Identical input clones so both models see the same data.
            x = model.input1.to(memory_format=torch.channels_last)
            origin_x = x.clone()
            ipex_x = x.clone()
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                # train one step for the reference model.
                y1 = origin_model(origin_x)
                grad_y = torch.ones_like(y1)
                origin_optimizer.zero_grad()
                y1.backward(grad_y)
                origin_optimizer.step()
                # train one step for ipex.
                y2 = ipex_model(ipex_x)
                ipex_optimizer.zero_grad()
                y2.backward(grad_y)
                ipex_optimizer.step()
                # Loose tolerances: bf16 autocast rounds intermediates.
                self.assertEqual(y1, y2, rtol=1e-4, atol=5e-02)
                origin_model_state = origin_model.state_dict()
                ipex_model_state = ipex_model.state_dict()
                for var_name in origin_model_state:
                    self.assertEqual(
                        origin_model_state[var_name],
                        ipex_model_state[var_name],
                        rtol=1e-4,
                        atol=5e-02,
                    )
                # Conv params were frozen above, so no grad may be created.
                self.assertTrue(origin_model.conv.weight.grad is None)
                self.assertTrue(ipex_model.conv.weight.grad is None)
311

312
    def test_optimize_unsupport_dtype_conversion(self):
        """Casting a float64 model to bf16 is unsupported: ipex.optimize must
        warn instead of converting the parameters."""

        class Conv(torch.nn.Module):
            def __init__(self):
                super(Conv, self).__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

            def forward(self, x):
                return self.conv(x)

        model = Conv().double()
        with self.assertWarnsRegex(
            UserWarning, "WARNING: Can't convert model's parameters dtype"
        ):
            ipex.optimize(model.eval(), dtype=torch.bfloat16)
330

331
    def test_optimize_bf16_upsupported(self):
        """On CPUs without oneDNN bf16 support, requesting bf16 weight prepack
        must raise an AssertionError.

        Fix: ``forward`` was mis-indented as a function local to the test
        method, leaving the inner ``Conv`` class without a forward method; it
        is now a proper method of ``Conv``.
        """

        class Conv(torch.nn.Module):
            def __init__(
                self,
            ):
                super(Conv, self).__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

            def forward(self, x):
                return self.conv(x)

        model = Conv()
        if not core.onednn_has_bf16_support():
            # NOTE(review): the raw string contains a backslash-newline plus the
            # continuation indentation; it presumably mirrors the multi-line
            # assertion message raised by ipex — verify against that message.
            msg = r"BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq, \
                please set dtype to torch.float or set weights_prepack to False."
            with self.assertRaisesRegex(AssertionError, msg):
                ipex.optimize(model.eval(), dtype=torch.bfloat16)
350

351
    def test_optimize_unsupport_freeze_optimization(self):
        """ipex.optimize on an already-frozen ScriptModule cannot optimize
        further and must hand the module back unchanged."""
        model = ConvBatchNorm().eval()
        sample = model.input1
        with torch.no_grad():
            traced = torch.jit.trace(model, sample)
            frozen = torch.jit.freeze(traced)
        optimized = ipex.optimize(frozen)
        self.assertTrue(frozen == optimized)
359

360
    def test_optimize_inplace_behavior_eval_mode(self):
        """Eval-mode inplace semantics: inplace=False clones every parameter
        storage; inplace=True shares whatever weight prepacking does not
        replace."""
        base = TestModule()
        for dtype, level in itertools.product(
            [torch.float32, torch.bfloat16], ["O0", "O1"]
        ):
            # non-inplace: every parameter buffer must be a fresh copy
            module = copy.deepcopy(base).eval()
            optimized = ipex.optimize(module, dtype=dtype, level=level, inplace=False)
            self.assertNotEqual(
                module.linear.weight.data_ptr(), optimized.linear.weight.data_ptr()
            )
            self.assertNotEqual(
                module.conv.weight.data_ptr(), optimized.conv.weight.data_ptr()
            )
            self.assertNotEqual(
                module.embeddingbag.weight.data_ptr(),
                optimized.embeddingbag.weight.data_ptr(),
            )

            # inplace: after ConvBN folding, `optimized` is a Graph Module
            # while `module` remains the original nn.Module. They share
            # parameters, but changes on the Graph Module are not reflected
            # back, so only the un-optimized weights keep the same buffer.
            module = copy.deepcopy(base).eval()
            optimized = ipex.optimize(module, dtype=dtype, level=level, inplace=True)
            if level == "O1":
                self.assertNotEqual(
                    module.conv.weight.data_ptr(), optimized.conv.weight.data_ptr()
                )
                # linear is optimized yet still shares the parameter object
                self.assertTrue(module.linear.weight is optimized.linear.weight)
                self.assertTrue(isinstance(optimized.linear, _IPEXLinear))
            # un-optimized parts must be shared in place
            self.assertEqual(
                module.embeddingbag.weight.data_ptr(),
                optimized.embeddingbag.weight.data_ptr(),
            )
391

392
    def test_optimize_inplace_behavior_training_mode_with_optimizer(self):
        """Training-mode inplace semantics: inplace=False clones parameters and
        leaves the original module fp32 at O1; inplace=True shares storages,
        so the original module's prepack-able weights take the target dtype.
        BatchNorm weights stay fp32 in every case."""
        base = TestModule()
        for dtype, level in itertools.product(
            [torch.float32, torch.bfloat16], ["O0", "O1"]
        ):
            # non-inplace
            module = copy.deepcopy(base).train()
            sgd = torch.optim.SGD(module.parameters(), lr=0.1)
            optimized, _ = ipex.optimize(
                module, dtype=dtype, optimizer=sgd, level=level, inplace=False
            )
            self.assertNotEqual(
                module.linear.weight.data_ptr(), optimized.linear.weight.data_ptr()
            )
            self.assertNotEqual(
                module.conv.weight.data_ptr(), optimized.conv.weight.data_ptr()
            )
            self.assertNotEqual(
                module.embeddingbag.weight.data_ptr(),
                optimized.embeddingbag.weight.data_ptr(),
            )
            if level == "O1":
                # original stays fp32; the optimized copy carries `dtype`
                self.assertEqual(module.linear.weight.dtype, torch.float)
                self.assertEqual(module.conv.weight.dtype, torch.float)
                self.assertEqual(module.embeddingbag.weight.dtype, torch.float)
                self.assertEqual(module.bn.weight.dtype, torch.float)
                self.assertEqual(optimized.linear.weight.dtype, dtype)
                self.assertEqual(optimized.conv.weight.dtype, dtype)
                self.assertEqual(optimized.embeddingbag.weight.dtype, dtype)
                self.assertEqual(optimized.bn.weight.dtype, torch.float)

            # inplace
            module = copy.deepcopy(base).train()
            sgd = torch.optim.SGD(module.parameters(), lr=0.1)
            optimized, _ = ipex.optimize(
                module, dtype=dtype, optimizer=sgd, level=level, inplace=True
            )
            self.assertEqual(
                module.linear.weight.data_ptr(), optimized.linear.weight.data_ptr()
            )
            self.assertEqual(
                module.conv.weight.data_ptr(), optimized.conv.weight.data_ptr()
            )
            self.assertEqual(
                module.embeddingbag.weight.data_ptr(),
                optimized.embeddingbag.weight.data_ptr(),
            )
            if level == "O1":
                # sharing storage means the original now carries `dtype` too
                self.assertEqual(module.linear.weight.dtype, dtype)
                self.assertEqual(module.conv.weight.dtype, dtype)
                self.assertEqual(module.embeddingbag.weight.dtype, dtype)
                self.assertEqual(module.bn.weight.dtype, torch.float)
437

438
    def _test_tensor_convert(self, tensor, bf16_tensor):
        """Check that split_float_bfloat16 / cat_bfloat16_float round-trip
        ``tensor`` exactly, preserving both values and strides."""
        top_half, bot_half = torch.ops.torch_ipex.split_float_bfloat16(tensor)
        # The top half equals a plain ".bfloat16()" truncation of the input.
        self.assertEqual(bf16_tensor, top_half)
        # Concatenating both halves reconstructs the original fp32 tensor.
        restored = torch.ops.torch_ipex.cat_bfloat16_float(top_half, bot_half)
        self.assertEqual(tensor, restored)
        # Layout (strides) must be preserved by both directions.
        self.assertEqual(tensor.stride(), top_half.stride())
        self.assertEqual(tensor.stride(), restored.stride())
447

448
    def test_tensor_convert(self):
        """Exercise the fp32 <-> (bf16, bf16) split/cat round-trip across
        several memory layouts."""
        base = torch.rand(100, 100)
        # contiguous layout
        self._test_tensor_convert(base, base.bfloat16())
        # transposed view
        self._test_tensor_convert(base.t(), base.bfloat16().t())
        # sliced-out (non-contiguous) view
        self._test_tensor_convert(base[2:5, 2:5], base.bfloat16()[2:5, 2:5])
        # nc11 channels-last layout
        nc11 = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)
        self._test_tensor_convert(nc11, nc11.bfloat16())
459

460
    def test_module_conversion(self):
        """Module-swap rules: O0 keeps stock nn modules; O1 swaps in IPEX
        modules, except that an fp32 Linear stays stock unless
        auto_kernel_selection is enabled."""
        base_model = TestModule()
        for dtype, level, auto_select in itertools.product(
            [torch.bfloat16, torch.float32], ["O0", "O1"], [True, False]
        ):
            sgd = torch.optim.SGD(base_model.parameters(), lr=0.1)
            converted, _ = ipex.optimize(
                base_model,
                dtype=dtype,
                optimizer=sgd,
                level=level,
                auto_kernel_selection=auto_select,
            )
            if level == "O0":
                self.assertTrue(isinstance(converted.linear, torch.nn.Linear))
                self.assertTrue(isinstance(converted.conv, torch.nn.Conv2d))
            else:
                if not auto_select and dtype == torch.float32:
                    # fp32 linear without auto kernel selection stays stock
                    self.assertTrue(isinstance(converted.linear, torch.nn.Linear))
                else:
                    self.assertTrue(isinstance(converted.linear, _IPEXLinear))
                self.assertTrue(isinstance(converted.conv, _IPEXConv2d))
483

484
    def test_record_shape(self):
        """Passing sample_input must record the collapsed batch size on each
        linear, in both inference and training flows."""
        for module_cls, inference_only in itertools.product(
            [OneLayerMLP, TwoLayerMLP], [True, False]
        ):
            M = module_cls()
            sample = M.input1
            if isinstance(M, TwoLayerMLP):
                sample = (M.input1, M.input2)
            if inference_only:
                M.eval()
                opt_M = ipex.optimize(
                    M, sample_input=sample, auto_kernel_selection=True
                )
            else:
                sgd = torch.optim.SGD(M.parameters(), lr=0.01)
                opt_M, _ = ipex.optimize(
                    M,
                    optimizer=sgd,
                    sample_input=sample,
                    auto_kernel_selection=True,
                )
            self.assertEqual(opt_M.l1.batch_size_collapsed, 2)
            if isinstance(M, TwoLayerMLP):
                self.assertEqual(opt_M.l2.batch_size_collapsed, 3)
505

506
    def test_traced_model_serialization(self):
        """Traced ipex-optimized models must round-trip through
        torch.jit save/load with identical outputs.

        Fix: the serialized file is now removed in a ``finally`` so a failing
        comparison no longer leaves ``traced_m.pt`` behind in the CWD.
        """
        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            for dtype in [torch.float, torch.bfloat16]:
                M = module().eval()
                input = M.input1.to(dtype)
                opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
                with torch.no_grad():
                    traced_M = torch.jit.trace(opt_M, input).eval()
                    traced_M.save("traced_m.pt")
                    try:
                        loaded_M = torch.jit.load("traced_m.pt")
                        self.assertEqual(traced_M(input), loaded_M(input))
                    finally:
                        os.remove("traced_m.pt")
518

519
    def test_optimized_model_with_fx(self):
        """fx optimization.fuse on an ipex-optimized model must not change its
        output, and the fused model must still trace/freeze into a working
        jit graph."""
        for module_cls in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            for dtype in [torch.float, torch.bfloat16]:
                model = module_cls().eval()
                sample = model.input1.to(dtype)
                opt_model = ipex.optimize(
                    model, dtype=dtype, auto_kernel_selection=True
                )
                expected = opt_model(sample)
                fused = optimization.fuse(opt_model)
                self.assertEqual(expected, fused(sample))
                with torch.no_grad():
                    jit_model = torch.jit.trace(fused, sample).eval()
                    jit_model = torch.jit.freeze(jit_model)
                    # first call runs the graph optimization passes
                    jit_model(sample)
                    # second call returns the optimized-graph result
                    self.assertEqual(expected, jit_model(sample))
537

538
    def test_optimized_model_with_sample_input(self):
        """Optimizing a training model with a sample_input must not perturb
        any parameter value in the state_dict."""
        for module_cls in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            model = module_cls().train()
            sample = model.input1
            sgd = torch.optim.SGD(model.parameters(), lr=0.01)
            state_before = copy.deepcopy(model.state_dict())
            opt_model, _ = ipex.optimize(
                model,
                dtype=torch.float32,
                inplace=False,
                optimizer=sgd,
                sample_input=sample,
            )
            state_after = opt_model.state_dict()
            for key in state_before:
                self.assertEqual(state_before[key], state_after[key])
556

557
    def test_partial_model_update(self):
        """Stepping the optimizer when only one branch received gradients must
        not fail (the other branch's parameters have grad=None)."""

        class TwoBranch(torch.nn.Module):
            def __init__(self):
                super(TwoBranch, self).__init__()
                self.L1 = torch.nn.Linear(10, 10)
                self.L2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                return (self.L1(x), self.L2(x))

        model = TwoBranch().train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
        model, optimizer = ipex.optimize(
            model, optimizer=optimizer, dtype=torch.bfloat16
        )

        with torch.cpu.amp.autocast():
            # Only the L1 branch contributes to the loss; L2 gets no grads.
            loss = model(torch.rand(10, 10))[0].sum()

        loss.backward()
        optimizer.step()
579

580
    def _test_load_after_ipex_optimize_inference(
        self, model_class, dtype, optimizer_class, level, inplace
    ):
        """Train an ipex-optimized model for two steps, then load its
        state_dict into a freshly optimized inference model.

        Before the load every entry must differ; after it, state_dict entries
        agree (modulo the fp32->dtype->fp32 round-trip) and parameters agree
        except the block-format ``linear.weight``.
        """
        model = model_class().train()
        input = model.input
        # SGD is the only optimizer here that accepts a momentum argument.
        if optimizer_class == SGD:
            optimizer = optimizer_class(model.parameters(), lr=10.01, momentum=0.1)
        else:
            optimizer = optimizer_class(model.parameters(), lr=10.01)
        ipex_model, ipex_optimizer = ipex.optimize(
            model,
            dtype=dtype,
            optimizer=optimizer,
            sample_input=input,
            level=level,
            inplace=inplace,
        )
        # train 2 iters to save something in optimizer's state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()

        # Fresh model, so its parameters cannot match the trained ones.
        inf_model = model_class().eval()
        inf_model_state = inf_model.state_dict()
        ipex_inf_model = ipex.optimize(
            inf_model, dtype=dtype, sample_input=input, level=level, inplace=inplace
        )
        # check parameters are not same before load
        ipex_model_state = ipex_model.state_dict()
        for var_name in ipex_model_state:
            self.assertNotEqual(ipex_model_state[var_name], inf_model_state[var_name])
        for p1 in ipex_model.named_parameters():
            # p1 is a (name, parameter) pair with name "<submodule>.<attr>".
            prefix, attr = p1[0].split(".")
            sub_m = getattr(ipex_inf_model, prefix)
            param = getattr(sub_m, attr)
            # the empty weight and bias tensor will always be Tensor()
            assert_fn = (
                self.assertEqual
                if p1[0]
                in _empty_weight_bias_parameter_names(
                    prefixes=["conv", "linear", "conv_transpose2d"]
                )
                else self.assertNotEqual
            )
            assert_fn(p1[1], param)

        # check parameters are same after load
        ipex_inf_model.load_state_dict(ipex_model_state)
        inf_model_state = ipex_inf_model.state_dict()
        for var_name in ipex_model_state:
            # Round-trip through `dtype` so fp32 values compare against what
            # the inference model can actually hold.
            self.assertEqual(
                ipex_model_state[var_name].to(dtype).float(), inf_model_state[var_name]
            )
        for p1 in ipex_model.named_parameters():
            if p1[0] == "linear.weight":
                # Do not compare linear.weight with block format since
                # linear.weight in ipex_model(training model) is plain
                continue
            prefix, attr = p1[0].split(".")
            sub_m = getattr(ipex_inf_model, prefix)
            param = getattr(sub_m, attr)
            self.assertEqual(p1[1], param)
645

646
    def _test_load_after_ipex_optimize_training(
        self, model_class, dtype, optimizer_class, level, inplace
    ):
        """Checkpoint/restore round-trip for a training model+optimizer pair:
        snapshot after two steps, train two more steps (state must diverge),
        then load the snapshot back and verify model parameters and optimizer
        state match the snapshot again — both through state_dict (plain
        format) and through the live, possibly block-format tensors.
        """
        model = model_class().train()
        input = model.input
        # SGD is the only optimizer here that accepts a momentum argument.
        if optimizer_class == SGD:
            optimizer = optimizer_class(model.parameters(), lr=10.01, momentum=0.1)
        else:
            optimizer = optimizer_class(model.parameters(), lr=10.01)
        ipex_model, ipex_optimizer = ipex.optimize(
            model,
            dtype=dtype,
            optimizer=optimizer,
            sample_input=input,
            level=level,
            inplace=inplace,
        )
        # train 2 iters to save something in optimizer's state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()
        # Snapshot both live objects and plain-format state dicts.
        ref_ipex_model = copy.deepcopy(ipex_model)
        ref_ipex_optimizer = copy.deepcopy(ipex_optimizer)
        ref_ipex_model_state = copy.deepcopy(ipex_model.state_dict())
        ref_ipex_optimizer_state = copy.deepcopy(ipex_optimizer.state_dict())

        # train 2 iters to change model/optimizer state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()
        # check state changed (with public format)
        ipex_model_state = ipex_model.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in ipex_model_state:
            self.assertNotEqual(
                ipex_model_state[var_name], ref_ipex_model_state[var_name]
            )
        for var_name in ipex_optimizer_state:
            # only "state" is expected to change; hyperparams stay fixed
            if var_name == "state":
                self.assertNotEqual(
                    ipex_optimizer_state[var_name], ref_ipex_optimizer_state[var_name]
                )
        # check values before load (with block format)
        for p1, p2 in zip(
            ipex_model.named_parameters(), ref_ipex_model.named_parameters()
        ):
            # the empty weight and bias tensor will always be Tensor()
            assert_fn = (
                self.assertEqual
                if p1[0]
                in _empty_weight_bias_parameter_names(
                    prefixes=["conv", "linear", "conv_transpose2d"]
                )
                else self.assertNotEqual
            )
            assert_fn(p1[1], p2[1])
        for (_, v1), (_, v2) in zip(
            ipex_optimizer.state.items(), ref_ipex_optimizer.state.items()
        ):
            self.assertNotEqual(v1, v2)
        # Restore the snapshot into the live (block-format) objects.
        ipex_model.load_state_dict(ref_ipex_model_state)
        ipex_optimizer.load_state_dict(ref_ipex_optimizer_state)
        # check values same after load (with block format)
        for p1, p2 in zip(
            ipex_model.named_parameters(), ref_ipex_model.named_parameters()
        ):
            self.assertEqual(p1[1], p2[1])
        for (_, v1), (_, v2) in zip(
            ipex_optimizer.state.items(), ref_ipex_optimizer.state.items()
        ):
            # NOTE(review): entries WITHOUT "step_size" are never compared in
            # this loop — possibly an assertEqual was meant in an else branch;
            # verify the intent.
            if "step_size" in v1:
                # For Rprop, there is a "clamp" operation on step_size which will change the "zero"
                # attribute for packed position.
                # The zero pos will be changed after "clamp", and will be zero again after pack and
                # repack it. So in ipex_optimizer, the packed pos of "step_size" will be zero but in
                # ref_ipex_optimizer, the packed pos of "step_size" will not be zero. Thus the
                # assertEqual will be failed.
                #    step_sizes=(1e-6, 50)
                #    step_size_min, step_size_max = group['step_sizes']
                #    step_size.mul_(sign).clamp_(step_size_min, step_size_max)
                #    param.addcmul_(grad.sign(), step_size, value=-1)
                #    (param = param - grad.sign() * step_size)
                # but this step_size will not have impact since grad are zero
                v1 = copy.deepcopy(v1)
                v1.pop("step_size")
                v2 = copy.deepcopy(v2)
                v2.pop("step_size")
                self.assertEqual(v1, v2)

        # check state same after load (with plain format)
        ipex_model_state = ipex_model.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in ipex_model_state:
            self.assertEqual(ipex_model_state[var_name], ref_ipex_model_state[var_name])
        for var_name in ipex_optimizer_state:
            self.assertEqual(
                ipex_optimizer_state[var_name], ref_ipex_optimizer_state[var_name]
            )
750

751
    # Simulates the Stable Diffusion fine-tuning use case: the model is in
    # eval() mode but the *input* requires grad, so backward must still work
    # and produce the same input gradient after ipex.optimize.
    def test_eval_backward(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv = torch.nn.Conv2d(3, 2, kernel_size=(2, 2))

            def forward(self, x):
                return self.conv(x)

        # Two identical inputs: one for the reference model, one for the
        # optimized model, each tracking gradients independently.
        x_ref = torch.randn(1, 3, 8, 8)
        x_opt = copy.deepcopy(x_ref)
        for t in (x_ref, x_opt):
            t.requires_grad_()

        model = Model().eval()
        optimized_model = ipex.optimize(model)

        # Backprop through both the reference and the optimized model.
        model(x_ref).sum().backward()
        optimized_model(x_opt).sum().backward()

        # Input gradients must match between reference and optimized paths.
        self.assertEqual(x_ref.grad, x_opt.grad)
779

780
    def test_load_after_optimize(self):
        """Exercise state_dict save/load after ipex.optimize across dtypes,
        optimizers, levels and inplace modes, for both training and
        inference flows (delegated to the two _test_load_after_* helpers)."""

        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                # One sample input per submodule, consumed as forward(*input).
                self.input = (
                    torch.randn(1, 3, 224, 224),
                    torch.randn(100, 100),
                    torch.randn(5, 5, 3, 3),
                )
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
                )
                self.linear = torch.nn.Linear(100, 100)
                self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))

            def forward(self, x1, x2, x3):
                return (
                    self.conv(x1).sum()
                    + self.linear(x2).sum()
                    + self.conv_transpose2d(x3)
                )

        dtypes = [torch.float, torch.bfloat16]
        optimizers = [
            Lamb,
            Adadelta,
            Adagrad,
            Adam,
            AdamW,
            Adamax,
            ASGD,
            RMSprop,
            Rprop,
            SGD,
        ]
        levels = ["O0", "O1"]
        inplace_options = [True, False]

        # Sweep the full cartesian product of configurations.
        for dtype, optimizer, level, inplace in itertools.product(
            dtypes, optimizers, levels, inplace_options
        ):
            self._test_load_after_ipex_optimize_training(
                Model, dtype, optimizer, level, inplace
            )
            self._test_load_after_ipex_optimize_inference(
                Model, dtype, optimizer, level, inplace
            )
828

829
    def test_reentrancy_of_ipex_optimize(self):
        """Calling ipex.optimize repeatedly on its own output must keep the
        model/optimizer pair usable for forward + backward + step."""
        CALL_NUM = 3

        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                # One sample input per submodule, consumed as forward(*input).
                self.input = (
                    torch.randn(1, 3, 224, 224),
                    torch.randn(100, 100),
                    torch.randn(5, 5, 3, 3),
                )
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
                )
                self.linear = torch.nn.Linear(100, 100)
                self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))

            def forward(self, x1, x2, x3):
                return (
                    self.conv(x1).sum()
                    + self.linear(x2).sum()
                    + self.conv_transpose2d(x3)
                )

        def run_with_repeated_optimize(
            model_class,
            dtype,
            level,
            inplace,
            weights_prepack,
            split_master_weight_for_bf16,
            fuse_update_step,
            graph_mode,
        ):
            model = model_class().train()
            sample_input = model.input
            optimizer = torch.optim.SGD(model.parameters(), lr=10.01)
            for _ in range(CALL_NUM):
                # Re-apply ipex.optimize to the already-optimized pair and
                # verify a full training iteration still succeeds.
                model, optimizer = ipex.optimize(
                    model,
                    dtype=dtype,
                    optimizer=optimizer,
                    level=level,
                    inplace=inplace,
                    weights_prepack=weights_prepack,
                    split_master_weight_for_bf16=split_master_weight_for_bf16,
                    fuse_update_step=fuse_update_step,
                    graph_mode=graph_mode,
                )
                with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                    loss = model(*sample_input).sum()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Cartesian product over all optimize() knobs under test.
        configurations = itertools.product(
            [torch.float32, torch.bfloat16],  # dtype
            ["O1"],  # level
            [True, False],  # inplace
            [True, False],  # weights_prepack
            [True, False],  # split_master_weight_for_bf16
            [True, False],  # fuse_update_step
            [True, False],  # graph_mode
        )
        for config in configurations:
            run_with_repeated_optimize(Model, *config)
914

915

916
if __name__ == "__main__":
    # unittest.main() discovers and runs all tests in this module, then calls
    # sys.exit() with the result status; the previous `test = ...` binding was
    # therefore dead code (never reachable/observable) and has been dropped.
    unittest.main()
918

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.