import torch
import unittest

from intel_extension_for_pytorch.quantization.fp8 import (
    fp8_autocast,
    DelayedScaling,
    Format,
    prepare_fp8,
)

import intel_extension_for_pytorch._C as core

from torch.testing._internal.common_utils import TestCase
from torch.optim import SGD
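# Tests for IPEX FP8 emulation on CPU: FP8 linear training against an FP32
# reference, calibration-based FP8 inference, and non-contiguous weights.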
class TestFP8Cases(TestCase):
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_linear_base(self):
        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = torch.nn.LayerNorm(5, eps=1e-05)
                self.lin1 = torch.nn.Linear(5, 4, bias=False)
                self.lin2 = torch.nn.Linear(4, 3, bias=True)
                self.dropout = torch.nn.Dropout()

            def forward(self, x):
                x = self.ln(x)
                x = self.lin1(x)
                x = torch.nn.functional.gelu(x, approximate="tanh")
                x = self.lin2(x)
                z = self.dropout(x)
                return z

        torch.manual_seed(2024)

        my_linear = MyModel()
        my_linear.train()
        inp = torch.randn((10, 7, 3, 5), dtype=torch.float32)
        inp1 = inp.clone().requires_grad_(True)
        inp2 = inp.clone().requires_grad_(True)

        origin_optimizer = SGD(my_linear.parameters(), lr=0.01, momentum=0.9)
        fp8_linear, ipex_optimizer = prepare_fp8(my_linear, origin_optimizer)
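        # Train the FP8-prepared copy for 10 steps under fp8_autocast.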
        with fp8_autocast(
            enabled=True,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            for i in range(10):
                torch.manual_seed(2024)
                out = fp8_linear(inp2[i])
                ipex_optimizer.zero_grad()
                out.mean().backward()
                ipex_optimizer.step()
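        # Repeat the same 10 steps with the original FP32 model as a reference.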
        for i in range(10):
            torch.manual_seed(2024)
            out_nn = my_linear(inp1[i])
            origin_optimizer.zero_grad()
            out_nn.mean().backward()
            origin_optimizer.step()
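        # Final outputs and input gradients should agree within FP8 tolerances.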
        self.assertEqual(out, out_nn, atol=0.05, rtol=0.1)
        self.assertEqual(inp1[-1].grad, inp2[-1].grad, atol=0.01, rtol=0.1)

        origin_model_state = my_linear.state_dict()
        ipex_model_state = fp8_linear.state_dict()
        for var_name in origin_model_state:
            self.assertEqual(
                origin_model_state[var_name],
                ipex_model_state[var_name],
                atol=0.01,
                rtol=0.1,
            )
        for name, _ in fp8_linear.named_children():
            if hasattr(getattr(my_linear, name), "weight"):
                if getattr(my_linear, name).weight is not None:
                    self.assertEqual(
                        getattr(my_linear, name).weight.grad,
                        getattr(fp8_linear, name).weight.grad,
                        atol=0.01,
                        rtol=0.1,
                    )
            if hasattr(getattr(my_linear, name), "bias"):
                if getattr(my_linear, name).bias is not None:
                    self.assertEqual(
                        getattr(my_linear, name).bias.grad,
                        getattr(fp8_linear, name).bias.grad,
                        atol=0.01,
                        rtol=0.1,
                    )
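        # Optimizer state (e.g. SGD momentum buffers) should also stay in sync.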
        origin_optimizer_state = origin_optimizer.state_dict()
        ipex_optimizer_state = ipex_optimizer.state_dict()
        for var_name in origin_optimizer_state:
            if var_name == "state":
                print(origin_optimizer_state[var_name])
                print(ipex_optimizer_state[var_name])
                self.assertEqual(
                    origin_optimizer_state[var_name],
                    ipex_optimizer_state[var_name],
                    atol=0.01,
                    rtol=0.1,
                )
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_linear_calibration(self):
        class ClassA(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = torch.nn.LayerNorm(5, eps=1e-05)

            def forward(self, x):
                z = self.ln(x)
                return z

        class ClassC(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin2 = torch.nn.Linear(4, 3, bias=True)
                self.dropout = torch.nn.Dropout()

            def forward(self, x):
                x = self.lin2(x)
                z = self.dropout(x)
                return z

        class ClassB(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = torch.nn.Linear(5, 4, bias=False)
                self.lin2_dropout = ClassC()

            def forward(self, x):
                x = self.lin1(x)
                x = torch.nn.functional.gelu(x, approximate="tanh")
                z = self.lin2_dropout(x)
                return z

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ln = ClassA()
                self.lin1_gelu = ClassB()

            def forward(self, x):
                x = self.ln(x)
                z = self.lin1_gelu(x)
                return z

        # FP32 reference model
        my_linear = MyModel()
        my_linear.train()
        inp = torch.randn((5, 7, 3, 5), dtype=torch.float32)
        inp1 = inp.clone().requires_grad_(True)
        inp2 = inp.clone().requires_grad_(False)

        origin_optimizer = SGD(my_linear.parameters(), lr=0.01, momentum=0.9)
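        # Train the FP32 reference model for 4 steps.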
        for i in range(4):
            out_nn = my_linear(inp1[i])
            origin_optimizer.zero_grad()
            out_nn.mean().backward()
            origin_optimizer.step()
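        # Save the trained weights and reload them into a fresh model for inference.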
        torch.save(my_linear.state_dict(), "my_linear_inference.pt")
        my_linear_inference = MyModel()
        my_linear_inference.load_state_dict(torch.load("my_linear_inference.pt"))
        my_linear_inference.eval()
        out_nn_iter5 = my_linear_inference(inp1[4])

        fp8_linear_inference = prepare_fp8(my_linear_inference)
        # Do calibration to store amax of input and weight
        for i in range(4):
            with fp8_autocast(
                enabled=False,
                calibrating=True,
                fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
                device="cpu",
            ):
                _ = fp8_linear_inference(inp2[i])
        torch.save(fp8_linear_inference.state_dict(), "fp8_linear_inference.pt")

        # FP8 model with calibration
        fp8_linear_with_calibration = MyModel()
        fp8_linear_with_calibration = prepare_fp8(fp8_linear_with_calibration)
        fp8_linear_with_calibration.load_state_dict(
            torch.load("fp8_linear_inference.pt")
        )
        fp8_linear_with_calibration.eval()

        # Run model inference using calibration data
        with fp8_autocast(
            enabled=True,
            calibrating=False,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            out_fp8_iter5 = fp8_linear_with_calibration(inp2[4])
        self.assertEqual(out_fp8_iter5, out_nn_iter5, atol=0.01, rtol=0.1)
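    # FP8 linear should also handle a non-contiguous (transposed) weight tensor.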
    @unittest.skipIf(
        not core.onednn_has_fp8_support(),
        "IPEX FP8 is not supported on this CPU device",
    )
    def test_fp8_non_contiguous_weight(self):
        nn_linear = torch.nn.Linear(2, 2)
        nn_linear.weight = torch.nn.Parameter(nn_linear.weight.transpose(0, 1))
        inp = torch.ones(3, 2)
        fp8_linear = prepare_fp8(nn_linear)
        with fp8_autocast(
            enabled=True,
            fp8_recipe=DelayedScaling(fp8_format=Format.E4M3),
            device="cpu",
        ):
            fp8_out = fp8_linear(inp)
        nn_out = nn_linear(inp)
        self.assertEqual(nn_out, fp8_out, atol=0.01, rtol=0.1)
if __name__ == "__main__":
    test = unittest.main()