intel-extension-for-pytorch
74 строки · 3.1 Кб
1import torch2import intel_extension_for_pytorch # noqa3from torch.testing._internal.common_utils import run_tests, TestCase4
5
6class TestFloat8(TestCase):7def test_creation_with_zeros(self):8x = torch.zeros(8, dtype=torch.float8_e4m3fn)9
10def test_e4m3fn_casts(self):11for dtype in (torch.float32, torch.float16):12x = torch.randn(16, dtype=torch.float)13x_fp8 = x.to(torch.float8_e4m3fn)14x_orig_dtype = x_fp8.to(torch.float)15
16def test_e4m3fn_numerics(self):17# ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 118
19def _compare(bits_str, expected_fp32, comp_name):20bits_int = int(bits_str, 2)21tensor_int = torch.tensor([bits_int], dtype=torch.uint8)22tensor_fp8 = tensor_int.view(torch.float8_e4m3fn)23tensor_fp32 = tensor_fp8.float()24ref_tensor_fp32 = torch.tensor([expected_fp32], dtype=torch.float)25self.assertTrue(26torch.allclose(tensor_fp32, ref_tensor_fp32),27f"{comp_name} failed: expected {expected_fp32}, got {tensor_fp32.item()}",28)29
30_compare("00000000", 0.0, "zero")31_compare("10000000", -0.0, "neg_zero")32_compare("01111110", 448.0, "max_normal")33_compare("11111110", -448.0, "neg_max_normal")34_compare("00001000", 2**-6, "min_normal")35_compare("10001000", -1 * (2**-6), "neg_min_normal")36_compare("00000111", 0.875 * (2**-6), "max_subnorm")37_compare("10000111", -0.875 * (2**-6), "neg_max_subnorm")38_compare("00000001", 2**-9, "min_subnorm")39_compare("10000001", -1 * (2**-9), "neg_min_subnorm")40
41def test_e5m2fn_casts(self):42for dtype in (torch.float32, torch.float16):43x = torch.randn(16, dtype=torch.float)44x_fp8 = x.to(torch.float8_e5m2)45x_orig_dtype = x_fp8.to(torch.float)46
47def test_e5m2fn_numerics(self):48# ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 149
50def _compare(bits_str, expected_fp32, comp_name):51bits_int = int(bits_str, 2)52tensor_int = torch.tensor([bits_int], dtype=torch.uint8)53tensor_fp8 = tensor_int.view(torch.float8_e5m2)54tensor_fp32 = tensor_fp8.float()55ref_tensor_fp32 = torch.tensor([expected_fp32], dtype=torch.float)56self.assertTrue(57torch.allclose(tensor_fp32, ref_tensor_fp32),58f"{comp_name} failed: expected {expected_fp32}, got {tensor_fp32.item()}",59)60
61_compare("00000000", 0.0, "zero")62_compare("10000000", -0.0, "neg_zero")63_compare("01111011", 57344.0, "max_normal")64_compare("11111011", -57344.0, "neg_max_normal")65_compare("00000100", 2**-14, "min_normal")66_compare("10000100", -1 * (2**-14), "neg_min_normal")67_compare("00000011", 0.75 * (2**-14), "max_subnorm")68_compare("10000011", -0.75 * (2**-14), "neg_max_subnorm")69_compare("00000001", 2**-16, "min_subnorm")70_compare("10000001", -1 * (2**-16), "neg_min_subnorm")71
72
# Script entry point: when this file is executed directly (rather than
# imported), hand control to the test runner imported at the top of the file.
if __name__ == "__main__":
    run_tests()