intel-extension-for-pytorch
import torch
import torch.fx.experimental.optimization as optimization
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch._C as core
from intel_extension_for_pytorch.nn.utils._weight_prepack import (
    _IPEXLinear as _IPEXLinear,
    _IPEXConv2d as _IPEXConv2d,
)
from torch.testing._internal.common_utils import TestCase
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    AdamW,
    Adamax,
    ASGD,
    RMSprop,
    Rprop,
    SGD,
)
import unittest
import itertools
import copy
from common_utils import TestModule, _empty_weight_bias_parameter_names
from intel_extension_for_pytorch.optim._lamb import Lamb
import os

try:
    import transformers

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
skipIfNoTransformers = unittest.skipIf(not HAS_TRANSFORMERS, "no transformers")

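# curpath is used below to locate test assets shipped alongside this file
# (e.g. the hf_configs directory used by the transformers-based test).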
curpath = os.path.abspath(os.path.dirname(__file__))


class ConvBatchNorm(torch.nn.Module):
    def __init__(
        self,
    ):
        super(ConvBatchNorm, self).__init__()
        self.input1 = torch.randn(1, 3, 224, 224)
        self.conv = torch.nn.Conv2d(
            3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
        )
        self.bn = torch.nn.BatchNorm2d(
            64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        )

    def forward(self, x):
        return self.bn(self.conv(x))


class TwoLayerMLP(torch.nn.Module):
    def __init__(self):
        super(TwoLayerMLP, self).__init__()
        self.input1 = torch.randn(2, 2)
        self.input2 = torch.randn(3, 3)
        self.l1 = torch.nn.Linear(2, 2)
        self.l2 = torch.nn.Linear(3, 3)

    def forward(self, x1, x2):
        return self.l1(x1).sum() + self.l2(x2).sum()

class OneLayerMLP(torch.nn.Module):
    def __init__(self):
        super(OneLayerMLP, self).__init__()
        self.input1 = torch.randn(2, 2)
        self.l1 = torch.nn.Linear(2, 2)

    def forward(self, x1):
        return self.l1(x1)


class ConvTranspose2d(torch.nn.Module):
    def __init__(
        self,
    ):
        super(ConvTranspose2d, self).__init__()
        self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))
        self.input1 = torch.randn(5, 5, 3, 3)

    def forward(self, x):
        x = self.conv_transpose2d(x)
        return x


class LinearBatchNormNd(torch.nn.Module):
    def __init__(self, dim):
        super(LinearBatchNormNd, self).__init__()
        self.linear = torch.nn.Linear(32, 32)
        if dim == 1:
            self.input1 = torch.randn(1, 32)
            self.bn = torch.nn.BatchNorm1d(32)
        elif dim == 2:
            self.input1 = torch.randn(1, 32, 32, 32)
            self.bn = torch.nn.BatchNorm2d(32)
        elif dim == 3:
            self.input1 = torch.randn(1, 32, 32, 32, 32)
            self.bn = torch.nn.BatchNorm3d(32)

    def forward(self, x):
        return self.bn(self.linear(x))


class ConvBatchNormLinearBatchNorm(torch.nn.Module):
    def __init__(
        self,
    ):
        super(ConvBatchNormLinearBatchNorm, self).__init__()
        self.input1 = torch.randn(1, 32, 32, 32)
        self.conv = torch.nn.Conv2d(32, 32, 1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.linear = torch.nn.Linear(32, 32)
        self.bn2 = torch.nn.BatchNorm2d(32)

    def forward(self, x):
        return self.bn2(self.linear(self.bn1(self.conv(x))))

class TestOptimizeCases(TestCase):
    def test_optimize_conv_bn_parameters_behavior(self):
        model = ConvBatchNorm().eval()
        pre_te_enable_status = torch._C._jit_texpr_fuser_enabled()
        torch._C._jit_set_texpr_fuser_enabled(False)
        for level in ["O0", "O1"]:
            for conv_bn_folding in [True, False]:
                opt_M = ipex.optimize(
                    model,
                    level=level,
                    dtype=torch.float,
                    conv_bn_folding=conv_bn_folding,
                )
                with torch.no_grad():
                    x = model.input1
                    traced_model = torch.jit.trace(opt_M, x)
                    trace_graph = traced_model.graph_for(x)
                self.assertEqual(
                    any(n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()),
                    not (conv_bn_folding),
                )
                # TODO check weight_prepack.
        torch._C._jit_set_texpr_fuser_enabled(pre_te_enable_status)

    def test_optimize_linear_bn_parameters_behavior(self):
        for dim in [1, 2, 3]:
            model = LinearBatchNormNd(dim=dim).eval()
            for level in ["O0", "O1"]:
                for linear_bn_folding in [True, False]:
                    opt_M = ipex.optimize(
                        model,
                        level=level,
                        dtype=torch.float,
                        linear_bn_folding=linear_bn_folding,
                    )
                    with torch.no_grad():
                        x = model.input1
                        traced_model = torch.jit.trace(opt_M, x)
                        trace_graph = traced_model.graph_for(x)
                    self.assertEqual(
                        any(
                            n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()
                        ),
                        not (linear_bn_folding),
                    )

    def test_optimize_conv_bn_linear_bn_parameters_behavior(self):
        model = ConvBatchNormLinearBatchNorm().eval()
        max_num_folding = 2
        for level in ["O0", "O1"]:
            for conv_bn_folding in [True, False]:
                for linear_bn_folding in [True, False]:
                    opt_M = ipex.optimize(
                        model,
                        level=level,
                        dtype=torch.float,
                        conv_bn_folding=conv_bn_folding,
                        linear_bn_folding=linear_bn_folding,
                    )
                    with torch.no_grad():
                        x = model.input1
                        traced_model = torch.jit.trace(opt_M, x)
                        trace_graph = traced_model.graph_for(x)
                    self.assertEqual(
                        len(
                            [
                                n
                                for n in trace_graph.nodes()
                                if n.kind() == "ipex::batch_norm"
                            ]
                        ),
                        max_num_folding - (conv_bn_folding + linear_bn_folding),
                    )

    def test_optimize_bf16_model(self):
        model = ConvBatchNorm()
        optimized_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)
        # the inference model should not have a master-weight attribute
        self.assertTrue(not hasattr(optimized_model.conv, "master_weight"))
        # the training model should keep an fp32 master weight (held by the optimizer)
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        optimized_model, optimized_sgd = ipex.optimize(
            model.train(),
            optimizer=sgd,
            dtype=torch.bfloat16,
            split_master_weight_for_bf16=False,
        )
        self.assertEqual(optimized_model.conv.weight.dtype, torch.bfloat16)
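        # Helper: look up the parameter wrapper that ipex.optimize registered in the
        # optimizer's params_attr for a given parameter (None if it is not wrapped).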
        def found_wrapper(parameter, params_attr):
            for _, v in params_attr.items():
                if parameter is v.parameter:
                    return v
            return None

        wrapper = found_wrapper(optimized_model.conv.weight, optimized_sgd.params_attr)
        self.assertTrue(wrapper is not None)
        self.assertEqual(wrapper.master_parameter.dtype, torch.float)

    @skipIfNoTransformers
    def test_optimize_bf16_AlbertMLMHead(self):
        from transformers.models import albert
        from intel_extension_for_pytorch.nn.utils import _parameter_wrapper

        config = transformers.AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/albert-base-v1"
        )
        model = albert.modeling_albert.AlbertForMaskedLM(config)
        params_attr = {}
        _parameter_wrapper.get_shared_parameter_status(model, params_attr)
        for name, param in model.named_parameters():
            if name == "albert.embeddings.word_embeddings.weight":
                self.assertTrue(
                    albert.modeling_albert.AlbertMLMHead
                    in params_attr[param].modules_cls
                )
                self.assertEqual(param.dtype, torch.float32)
                self.assertTrue(params_attr[param].can_cast_inference(torch.bfloat16))
                params_attr[param].cast_for_inference(torch.bfloat16)
                self.assertEqual(param.dtype, torch.bfloat16)
                break

    def test_optimize_pretrain_model(self):
        optimizer_options = [
            Lamb,
            Adadelta,
            Adagrad,
            Adam,
            AdamW,
            Adamax,
            ASGD,
            # RMSprop, # TODO: accuracy fails on SPR starting from oneDNN commit 0f354d
            Rprop,
            SGD,
        ]

        options = itertools.product([torch.float, torch.bfloat16], optimizer_options)
        for dtype, optimizer in options:
            model = ConvBatchNorm().to(memory_format=torch.channels_last).train()
            model.conv.weight.requires_grad_(False)
            model.conv.bias.requires_grad_(False)
            origin_model = copy.deepcopy(model)
            lr = 1e-4 if optimizer is SGD else 1e-2
            origin_optimizer = optimizer(origin_model.parameters(), lr=lr)
            ipex_model, ipex_optimizer = ipex.optimize(
                origin_model, optimizer=origin_optimizer, dtype=dtype
            )
            self.assertEqual(
                origin_model.conv.weight.requires_grad,
                ipex_model.conv.weight.requires_grad,
            )
            self.assertEqual(
                origin_model.conv.bias.requires_grad, ipex_model.conv.bias.requires_grad
            )
            self.assertEqual(
                origin_model.bn.weight.requires_grad, ipex_model.bn.weight.requires_grad
            )
            self.assertEqual(
                origin_model.bn.bias.requires_grad, ipex_model.bn.bias.requires_grad
            )

            x = model.input1.to(memory_format=torch.channels_last)
            origin_x = x.clone()
            ipex_x = x.clone()
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y1 = origin_model(origin_x)
                grad_y = torch.ones_like(y1)
                origin_optimizer.zero_grad()
                y1.backward(grad_y)
                origin_optimizer.step()
                # train one step for ipex.
                y2 = ipex_model(ipex_x)
                ipex_optimizer.zero_grad()
                y2.backward(grad_y)
                ipex_optimizer.step()
                self.assertEqual(y1, y2, rtol=1e-4, atol=5e-02)
            origin_model_state = origin_model.state_dict()
            ipex_model_state = ipex_model.state_dict()
            for var_name in origin_model_state:
                self.assertEqual(
                    origin_model_state[var_name],
                    ipex_model_state[var_name],
                    rtol=1e-4,
                    atol=5e-02,
                )
            self.assertTrue(origin_model.conv.weight.grad is None)
            self.assertTrue(ipex_model.conv.weight.grad is None)

    def test_optimize_unsupport_dtype_conversion(self):
        class Conv(torch.nn.Module):
            def __init__(
                self,
            ):
                super(Conv, self).__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

            def forward(self, x):
                return self.conv(x)

        model = Conv().double()
        with self.assertWarnsRegex(
            UserWarning, "WARNING: Can't convert model's parameters dtype"
        ):
            optimized_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)

    def test_optimize_bf16_upsupported(self):
        class Conv(torch.nn.Module):
            def __init__(
                self,
            ):
                super(Conv, self).__init__()
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
                )

            def forward(self, x):
                return self.conv(x)

        model = Conv()
        if not core.onednn_has_bf16_support():
            msg = r"BF16 weight prepack needs the cpu support avx512bw, avx512vl and avx512dq, \
please set dtype to torch.float or set weights_prepack to False."
            with self.assertRaisesRegex(AssertionError, msg):
                optimized_model = ipex.optimize(model.eval(), dtype=torch.bfloat16)

    def test_optimize_unsupport_freeze_optimization(self):
        model = ConvBatchNorm().eval()
        x = model.input1
        with torch.no_grad():
            traced_model = torch.jit.trace(model, x)
            frozen_model = torch.jit.freeze(traced_model)
        optimized_model = ipex.optimize(frozen_model)
        self.assertTrue(frozen_model == optimized_model)

    def test_optimize_inplace_behavior_eval_mode(self):
        M_ori = TestModule()
        options = itertools.product([torch.float32, torch.bfloat16], ["O0", "O1"])
        for dtype, level in options:
            # non-inplace
            M = copy.deepcopy(M_ori).eval()
            opt_M = ipex.optimize(M, dtype=dtype, level=level, inplace=False)
            self.assertTrue(
                M.linear.weight.data_ptr() != opt_M.linear.weight.data_ptr()
            )
            self.assertTrue(M.conv.weight.data_ptr() != opt_M.conv.weight.data_ptr())
            self.assertTrue(
                M.embeddingbag.weight.data_ptr() != opt_M.embeddingbag.weight.data_ptr()
            )

            # inplace
            M = copy.deepcopy(M_ori).eval()
            opt_M = ipex.optimize(M, dtype=dtype, level=level, inplace=True)
            # After conv-bn folding, opt_M is a GraphModule while M is the original
            # nn.Module, and the two share parameters. Changes made on the GraphModule
            # are not reflected on the original module, so only the un-optimized
            # weights keep sharing the same memory buffer with the original module.
            if level == "O1":
                self.assertTrue(
                    M.conv.weight.data_ptr() != opt_M.conv.weight.data_ptr()
                )
                # linear is optimized and uses the same parameter as the original model
                self.assertTrue(M.linear.weight is opt_M.linear.weight)
                self.assertTrue(isinstance(opt_M.linear, _IPEXLinear))
                # the un-optimized part should be updated in place
                self.assertTrue(
                    M.embeddingbag.weight.data_ptr()
                    == opt_M.embeddingbag.weight.data_ptr()
                )

    def test_optimize_inplace_behavior_training_mode_with_optimizer(self):
        M_ori = TestModule()
        options = itertools.product([torch.float32, torch.bfloat16], ["O0", "O1"])
        for dtype, level in options:
            # non-inplace
            M = copy.deepcopy(M_ori).train()
            sgd = torch.optim.SGD(M.parameters(), lr=0.1)
            opt_M, _ = ipex.optimize(
                M, dtype=dtype, optimizer=sgd, level=level, inplace=False
            )
            self.assertTrue(
                M.linear.weight.data_ptr() != opt_M.linear.weight.data_ptr()
            )
            self.assertTrue(M.conv.weight.data_ptr() != opt_M.conv.weight.data_ptr())
            self.assertTrue(
                M.embeddingbag.weight.data_ptr() != opt_M.embeddingbag.weight.data_ptr()
            )
            if level == "O1":
                self.assertEqual(M.linear.weight.dtype, torch.float)
                self.assertEqual(M.conv.weight.dtype, torch.float)
                self.assertEqual(M.embeddingbag.weight.dtype, torch.float)
                self.assertEqual(M.bn.weight.dtype, torch.float)
                self.assertEqual(opt_M.linear.weight.dtype, dtype)
                self.assertEqual(opt_M.conv.weight.dtype, dtype)
                self.assertEqual(opt_M.embeddingbag.weight.dtype, dtype)
                self.assertEqual(opt_M.bn.weight.dtype, torch.float)

            # inplace
            M = copy.deepcopy(M_ori).train()
            sgd = torch.optim.SGD(M.parameters(), lr=0.1)
            opt_M, _ = ipex.optimize(
                M, dtype=dtype, optimizer=sgd, level=level, inplace=True
            )
            self.assertTrue(
                M.linear.weight.data_ptr() == opt_M.linear.weight.data_ptr()
            )
            self.assertTrue(M.conv.weight.data_ptr() == opt_M.conv.weight.data_ptr())
            self.assertTrue(
                M.embeddingbag.weight.data_ptr() == opt_M.embeddingbag.weight.data_ptr()
            )
            if level == "O1":
                self.assertEqual(M.linear.weight.dtype, dtype)
                self.assertEqual(M.conv.weight.dtype, dtype)
                self.assertEqual(M.embeddingbag.weight.dtype, dtype)
                self.assertEqual(M.bn.weight.dtype, torch.float)

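    # The two custom ops exercised below split an fp32 tensor into a bf16 "top half"
    # plus a "bottom half" carrying the remaining low-order bits, and recombine them
    # to recover the original fp32 values; the round trip is asserted in the helper.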
    def _test_tensor_convert(self, tensor, bf16_tensor):
        top_half, bot_half = torch.ops.torch_ipex.split_float_bfloat16(tensor)
        # the truncated top half should equal the fp32 tensor converted via ".bfloat16()"
        self.assertEqual(bf16_tensor, top_half)
        # recover the float tensor from the top half and bottom half
        float_tensor = torch.ops.torch_ipex.cat_bfloat16_float(top_half, bot_half)
        self.assertEqual(tensor, float_tensor)
        self.assertEqual(tensor.stride(), top_half.stride())
        self.assertEqual(tensor.stride(), float_tensor.stride())

    def test_tensor_convert(self):
        # contiguous case
        tensor = torch.rand(100, 100)
        self._test_tensor_convert(tensor, tensor.bfloat16())
        # transposed case
        self._test_tensor_convert(tensor.t(), tensor.bfloat16().t())
        # sliced-out case
        self._test_tensor_convert(tensor[2:5, 2:5], tensor.bfloat16()[2:5, 2:5])
        # nc11 channel-last case
        tensor = torch.rand(128, 256, 1, 1).to(memory_format=torch.channels_last)
        self._test_tensor_convert(tensor, tensor.bfloat16())

    def test_module_conversion(self):
        M_ori = TestModule()
        options = itertools.product(
            [torch.bfloat16, torch.float32], ["O0", "O1"], [True, False]
        )
        for dtype, level, auto_kernel_selection in options:
            sgd = torch.optim.SGD(M_ori.parameters(), lr=0.1)
            opt_M, _ = ipex.optimize(
                M_ori,
                dtype=dtype,
                optimizer=sgd,
                level=level,
                auto_kernel_selection=auto_kernel_selection,
            )
            if level == "O0":
                self.assertTrue(isinstance(opt_M.linear, torch.nn.Linear))
                self.assertTrue(isinstance(opt_M.conv, torch.nn.Conv2d))
            else:
                if not auto_kernel_selection and dtype == torch.float32:
                    self.assertTrue(isinstance(opt_M.linear, torch.nn.Linear))
                else:
                    self.assertTrue(isinstance(opt_M.linear, _IPEXLinear))
                self.assertTrue(isinstance(opt_M.conv, _IPEXConv2d))

    def test_record_shape(self):
        options = itertools.product([OneLayerMLP, TwoLayerMLP], [True, False])
        for module, inference_only in options:
            M = module()
            input = M.input1
            if isinstance(M, TwoLayerMLP):
                input = (M.input1, M.input2)
            if inference_only:
                M.eval()
                opt_M = ipex.optimize(M, sample_input=input, auto_kernel_selection=True)
            else:
                optimizer = torch.optim.SGD(M.parameters(), lr=0.01)
                opt_M, _ = ipex.optimize(
                    M,
                    optimizer=optimizer,
                    sample_input=input,
                    auto_kernel_selection=True,
                )
            self.assertEqual(opt_M.l1.batch_size_collapsed, 2)
            if isinstance(M, TwoLayerMLP):
                self.assertEqual(opt_M.l2.batch_size_collapsed, 3)

    def test_traced_model_serialization(self):
        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            for dtype in [torch.float, torch.bfloat16]:
                M = module().eval()
                input = M.input1.to(dtype)
                opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
                with torch.no_grad():
                    traced_M = torch.jit.trace(opt_M, input).eval()
                    traced_M.save("traced_m.pt")
                    loaded_M = torch.jit.load("traced_m.pt")
                    self.assertEqual(traced_M(input), loaded_M(input))
                    os.remove("traced_m.pt")

    def test_optimized_model_with_fx(self):
        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            for dtype in [torch.float, torch.bfloat16]:
                M = module().eval()
                input = M.input1.to(dtype)
                opt_M = ipex.optimize(M, dtype=dtype, auto_kernel_selection=True)
                ref_out = opt_M(input)
                fx_M = optimization.fuse(opt_M)
                fx_out = fx_M(input)
                self.assertEqual(ref_out, fx_out)
                with torch.no_grad():
                    traced_M = torch.jit.trace(fx_M, input).eval()
                    traced_M = torch.jit.freeze(traced_M)
                    # do graph opt
                    traced_M(input)
                    # get optimized results
                    out = traced_M(input)
                    self.assertEqual(ref_out, out)

    def test_optimized_model_with_sample_input(self):
        for module in [ConvBatchNorm, OneLayerMLP, ConvTranspose2d]:
            model = module().train()
            input = model.input1
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
            origin_model_state = copy.deepcopy(model.state_dict())
            ipex_model, _ = ipex.optimize(
                model,
                dtype=torch.float32,
                inplace=False,
                optimizer=optimizer,
                sample_input=input,
            )
            ipex_model_state = ipex_model.state_dict()
            for var_name in origin_model_state:
                self.assertEqual(
                    origin_model_state[var_name], ipex_model_state[var_name]
                )

    def test_partial_model_update(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.L1 = torch.nn.Linear(10, 10)
                self.L2 = torch.nn.Linear(10, 10)

            def forward(self, x):
                return (self.L1(x), self.L2(x))

        model = M()
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
        model.train()
        model, optimizer = ipex.optimize(
            model, optimizer=optimizer, dtype=torch.bfloat16
        )

        with torch.cpu.amp.autocast():
            loss = model(torch.rand(10, 10))[0].sum()

        loss.backward()
        optimizer.step()

    def _test_load_after_ipex_optimize_inference(
        self, model_class, dtype, optimizer_class, level, inplace
    ):
        model = model_class().train()
        input = model.input
        if optimizer_class == SGD:
            optimizer = optimizer_class(model.parameters(), lr=10.01, momentum=0.1)
        else:
            optimizer = optimizer_class(model.parameters(), lr=10.01)
        ipex_model, ipex_optimizer = ipex.optimize(
            model,
            dtype=dtype,
            optimizer=optimizer,
            sample_input=input,
            level=level,
            inplace=inplace,
        )
        # train 2 iters to save something in optimizer's state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()

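        # Build a fresh eval-mode model and optimize it for inference; its parameters
        # should differ from the trained model's until the trained state_dict is
        # loaded further below.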
        inf_model = model_class().eval()
        inf_model_state = inf_model.state_dict()
        ipex_inf_model = ipex.optimize(
            inf_model, dtype=dtype, sample_input=input, level=level, inplace=inplace
        )
        # check that the parameters are not the same before the load
        ipex_model_state = ipex_model.state_dict()
        for var_name in ipex_model_state:
            self.assertNotEqual(ipex_model_state[var_name], inf_model_state[var_name])
        for p1 in ipex_model.named_parameters():
            prefix, attr = p1[0].split(".")
            sub_m = getattr(ipex_inf_model, prefix)
            param = getattr(sub_m, attr)
            # the empty weight and bias tensors will always be Tensor()
            assert_fn = (
                self.assertEqual
                if p1[0]
                in _empty_weight_bias_parameter_names(
                    prefixes=["conv", "linear", "conv_transpose2d"]
                )
                else self.assertNotEqual
            )
            assert_fn(p1[1], param)

        # check that the parameters are the same after the load
        ipex_inf_model.load_state_dict(ipex_model_state)
        inf_model_state = ipex_inf_model.state_dict()
        for var_name in ipex_model_state:
            self.assertEqual(
                ipex_model_state[var_name].to(dtype).float(), inf_model_state[var_name]
            )
        for p1 in ipex_model.named_parameters():
            if p1[0] == "linear.weight":
                # Do not compare linear.weight in the block format, since
                # linear.weight in ipex_model (the training model) is plain.
                continue
            prefix, attr = p1[0].split(".")
            sub_m = getattr(ipex_inf_model, prefix)
            param = getattr(sub_m, attr)
            self.assertEqual(p1[1], param)

    def _test_load_after_ipex_optimize_training(
        self, model_class, dtype, optimizer_class, level, inplace
    ):
        model = model_class().train()
        input = model.input
        if optimizer_class == SGD:
            optimizer = optimizer_class(model.parameters(), lr=10.01, momentum=0.1)
        else:
            optimizer = optimizer_class(model.parameters(), lr=10.01)
        ipex_model, ipex_optimizer = ipex.optimize(
            model,
            dtype=dtype,
            optimizer=optimizer,
            sample_input=input,
            level=level,
            inplace=inplace,
        )
        # train 2 iters to save something in optimizer's state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()
        ref_ipex_model = copy.deepcopy(ipex_model)
        ref_ipex_optimizer = copy.deepcopy(ipex_optimizer)
        ref_ipex_model_state = copy.deepcopy(ipex_model.state_dict())
        ref_ipex_optimizer_state = copy.deepcopy(ipex_optimizer.state_dict())

        # train 2 more iters to change the model/optimizer state
        for _ in range(2):
            with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                y = ipex_model(*input).sum()
            ipex_optimizer.zero_grad()
            y.backward()
            ipex_optimizer.step()
        # check that the state changed (with the public format)
        ipex_model_state = ipex_model.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in ipex_model_state:
            self.assertNotEqual(
                ipex_model_state[var_name], ref_ipex_model_state[var_name]
            )
        for var_name in ipex_optimizer_state:
            if var_name == "state":
                self.assertNotEqual(
                    ipex_optimizer_state[var_name], ref_ipex_optimizer_state[var_name]
                )
        # check values before load (with the block format)
        for p1, p2 in zip(
            ipex_model.named_parameters(), ref_ipex_model.named_parameters()
        ):
            # the empty weight and bias tensors will always be Tensor()
            assert_fn = (
                self.assertEqual
                if p1[0]
                in _empty_weight_bias_parameter_names(
                    prefixes=["conv", "linear", "conv_transpose2d"]
                )
                else self.assertNotEqual
            )
            assert_fn(p1[1], p2[1])
        for (_, v1), (_, v2) in zip(
            ipex_optimizer.state.items(), ref_ipex_optimizer.state.items()
        ):
            self.assertNotEqual(v1, v2)
        ipex_model.load_state_dict(ref_ipex_model_state)
        ipex_optimizer.load_state_dict(ref_ipex_optimizer_state)
        # check that values are the same after load (with the block format)
        for p1, p2 in zip(
            ipex_model.named_parameters(), ref_ipex_model.named_parameters()
        ):
            self.assertEqual(p1[1], p2[1])
        for (_, v1), (_, v2) in zip(
            ipex_optimizer.state.items(), ref_ipex_optimizer.state.items()
        ):
            if "step_size" in v1:
                # For Rprop there is a "clamp" operation on step_size which changes the
                # "zero" attribute of the packed positions. The zero positions change
                # after "clamp" and become zero again after packing and repacking. So in
                # ipex_optimizer the packed positions of "step_size" are zero while in
                # ref_ipex_optimizer they are not, and assertEqual would fail:
                #   step_sizes=(1e-6, 50)
                #   step_size_min, step_size_max = group['step_sizes']
                #   step_size.mul_(sign).clamp_(step_size_min, step_size_max)
                #   param.addcmul_(grad.sign(), step_size, value=-1)
                #   (param = param - grad.sign() * step_size)
                # This step_size has no impact here since the grads are zero.
                v1 = copy.deepcopy(v1)
                v1.pop("step_size")
                v2 = copy.deepcopy(v2)
                v2.pop("step_size")
            self.assertEqual(v1, v2)

        # check that the state is the same after load (with the plain format)
        ipex_model_state = ipex_model.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in ipex_model_state:
            self.assertEqual(ipex_model_state[var_name], ref_ipex_model_state[var_name])
        for var_name in ipex_optimizer_state:
            self.assertEqual(
                ipex_optimizer_state[var_name], ref_ipex_optimizer_state[var_name]
            )

    # This test case simulates the Stable Diffusion fine-tuning use case
    def test_eval_backward(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv = torch.nn.Conv2d(3, 2, kernel_size=(2, 2))

            def forward(self, x):
                return self.conv(x)

        x = torch.randn(1, 3, 8, 8)
        x_optimized = copy.deepcopy(x)
        x.requires_grad_()
        x_optimized.requires_grad_()

        m = Model().eval()
        optimized_m = ipex.optimize(m)

        y = m(x)
        y.sum().backward()

        y_optimized = optimized_m(x_optimized)
        y_optimized.sum().backward()

        grad = x.grad
        grad_optimized = x_optimized.grad

        self.assertEqual(grad, grad_optimized)

    def test_load_after_optimize(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.input = (
                    torch.randn(1, 3, 224, 224),
                    torch.randn(100, 100),
                    torch.randn(5, 5, 3, 3),
                )
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
                )
                self.linear = torch.nn.Linear(100, 100)
                self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))

            def forward(self, x1, x2, x3):
                return (
                    self.conv(x1).sum()
                    + self.linear(x2).sum()
                    + self.conv_transpose2d(x3)
                )

        params_dict = {
            "dtype": [torch.float, torch.bfloat16],
            "optimizer": [
                Lamb,
                Adadelta,
                Adagrad,
                Adam,
                AdamW,
                Adamax,
                ASGD,
                RMSprop,
                Rprop,
                SGD,
            ],
            "level": ["O0", "O1"],
            "inplace": [True, False],
        }
        for dtype, optimizer, level, inplace in list(
            itertools.product(*params_dict.values())
        ):
            self._test_load_after_ipex_optimize_training(
                Model, dtype, optimizer, level, inplace
            )
            self._test_load_after_ipex_optimize_inference(
                Model, dtype, optimizer, level, inplace
            )

    def test_reentrancy_of_ipex_optimize(self):
        CALL_NUM = 3

        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.input = (
                    torch.randn(1, 3, 224, 224),
                    torch.randn(100, 100),
                    torch.randn(5, 5, 3, 3),
                )
                self.conv = torch.nn.Conv2d(
                    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
                )
                self.linear = torch.nn.Linear(100, 100)
                self.conv_transpose2d = torch.nn.ConvTranspose2d(5, 5, (3, 3))

            def forward(self, x1, x2, x3):
                return (
                    self.conv(x1).sum()
                    + self.linear(x2).sum()
                    + self.conv_transpose2d(x3)
                )

        def run_and_recursively_call_ipex_optimize(
            model_class,
            dtype,
            level,
            inplace,
            weights_prepack,
            split_master_weight_for_bf16,
            fuse_update_step,
            graph_mode,
        ):
            model = model_class().train()
            input = model.input
            optimizer = torch.optim.SGD(model.parameters(), lr=10.01)
            for _ in range(CALL_NUM):
                # recursively calling ipex.optimize CALL_NUM times
                model, optimizer = ipex.optimize(
                    model,
                    dtype=dtype,
                    optimizer=optimizer,
                    level=level,
                    inplace=inplace,
                    weights_prepack=weights_prepack,
                    split_master_weight_for_bf16=split_master_weight_for_bf16,
                    fuse_update_step=fuse_update_step,
                    graph_mode=graph_mode,
                )
                with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
                    y = model(*input).sum()
                optimizer.zero_grad()
                y.backward()
                optimizer.step()

        params_dict = {
            "dtype": [torch.float32, torch.bfloat16],
            "level": ["O1"],
            "inplace": [True, False],
            "weights_prepack": [True, False],
            "split_master_weight_for_bf16": [True, False],
            "fuse_update_step": [True, False],
            "graph_mode": [True, False],
        }

        for (
            dtype,
            level,
            inplace,
            weights_prepack,
            split_master_weight_for_bf16,
            fuse_update_step,
            graph_mode,
        ) in list(itertools.product(*params_dict.values())):
            run_and_recursively_call_ipex_optimize(
                Model,
                dtype,
                level,
                inplace,
                weights_prepack,
                split_master_weight_for_bf16,
                fuse_update_step,
                graph_mode,
            )


if __name__ == "__main__":
    test = unittest.main()