intel-extension-for-pytorch

Форк
0
74 строки · 3.1 Кб
1
import torch
2
import intel_extension_for_pytorch  # noqa
3
from torch.testing._internal.common_utils import run_tests, TestCase
4

5

6
class TestFloat8(TestCase):
7
    def test_creation_with_zeros(self):
8
        x = torch.zeros(8, dtype=torch.float8_e4m3fn)
9

10
    def test_e4m3fn_casts(self):
11
        for dtype in (torch.float32, torch.float16):
12
            x = torch.randn(16, dtype=torch.float)
13
            x_fp8 = x.to(torch.float8_e4m3fn)
14
            x_orig_dtype = x_fp8.to(torch.float)
15

16
    def test_e4m3fn_numerics(self):
17
        # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1
18

19
        def _compare(bits_str, expected_fp32, comp_name):
20
            bits_int = int(bits_str, 2)
21
            tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
22
            tensor_fp8 = tensor_int.view(torch.float8_e4m3fn)
23
            tensor_fp32 = tensor_fp8.float()
24
            ref_tensor_fp32 = torch.tensor([expected_fp32], dtype=torch.float)
25
            self.assertTrue(
26
                torch.allclose(tensor_fp32, ref_tensor_fp32),
27
                f"{comp_name} failed: expected {expected_fp32}, got {tensor_fp32.item()}",
28
            )
29

30
        _compare("00000000", 0.0, "zero")
31
        _compare("10000000", -0.0, "neg_zero")
32
        _compare("01111110", 448.0, "max_normal")
33
        _compare("11111110", -448.0, "neg_max_normal")
34
        _compare("00001000", 2**-6, "min_normal")
35
        _compare("10001000", -1 * (2**-6), "neg_min_normal")
36
        _compare("00000111", 0.875 * (2**-6), "max_subnorm")
37
        _compare("10000111", -0.875 * (2**-6), "neg_max_subnorm")
38
        _compare("00000001", 2**-9, "min_subnorm")
39
        _compare("10000001", -1 * (2**-9), "neg_min_subnorm")
40

41
    def test_e5m2fn_casts(self):
42
        for dtype in (torch.float32, torch.float16):
43
            x = torch.randn(16, dtype=torch.float)
44
            x_fp8 = x.to(torch.float8_e5m2)
45
            x_orig_dtype = x_fp8.to(torch.float)
46

47
    def test_e5m2fn_numerics(self):
48
        # ensure that our format matches https://arxiv.org/pdf/2209.05433.pdf, Table 1
49

50
        def _compare(bits_str, expected_fp32, comp_name):
51
            bits_int = int(bits_str, 2)
52
            tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
53
            tensor_fp8 = tensor_int.view(torch.float8_e5m2)
54
            tensor_fp32 = tensor_fp8.float()
55
            ref_tensor_fp32 = torch.tensor([expected_fp32], dtype=torch.float)
56
            self.assertTrue(
57
                torch.allclose(tensor_fp32, ref_tensor_fp32),
58
                f"{comp_name} failed: expected {expected_fp32}, got {tensor_fp32.item()}",
59
            )
60

61
        _compare("00000000", 0.0, "zero")
62
        _compare("10000000", -0.0, "neg_zero")
63
        _compare("01111011", 57344.0, "max_normal")
64
        _compare("11111011", -57344.0, "neg_max_normal")
65
        _compare("00000100", 2**-14, "min_normal")
66
        _compare("10000100", -1 * (2**-14), "neg_min_normal")
67
        _compare("00000011", 0.75 * (2**-14), "max_subnorm")
68
        _compare("10000011", -0.75 * (2**-14), "neg_max_subnorm")
69
        _compare("00000001", 2**-16, "min_subnorm")
70
        _compare("10000001", -1 * (2**-16), "neg_min_subnorm")
71

72

73
if __name__ == "__main__":
74
    run_tests()
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.