intel-extension-for-pytorch

Форк
0
159 строк · 5.6 Кб
1
import torch
2
import torch.nn as nn
3
import unittest
4
import itertools
5
import copy
6
from torch.testing._internal.common_utils import TestCase
7
import intel_extension_for_pytorch as ipex
8

9
ipex_emb_fn = ipex.nn.functional._embeddingbag._embeddingbag
10
aten_emb_fn = ipex.nn.functional._embeddingbag.torch_embedding_bag
11

12

13
class Embeddingbag(torch.nn.Module):
14
    def __init__(self):
15
        super(Embeddingbag, self).__init__()
16
        self.embeddingbag = nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
17

18
    def forward(self, input, offsets):
19
        return self.embeddingbag(input, offsets)
20

21

22
class TestEMB(TestCase):
23
    def _test_emb(
24
        self,
25
        mode,
26
        per_sample_weights=None,
27
        padding_idx=None,
28
        include_last_offset=False,
29
        sparse=True,
30
        test_int32=False,
31
    ):
32
        aten_emb = nn.EmbeddingBag(
33
            10,
34
            33,
35
            mode=mode,
36
            sparse=sparse,
37
            padding_idx=padding_idx,
38
            include_last_offset=include_last_offset,
39
        )
40
        aten_emb = aten_emb.bfloat16().float()
41
        ipex_emb = copy.deepcopy(aten_emb)
42
        bf16_emb = copy.deepcopy(aten_emb).bfloat16()
43
        # a batch of 2 samples of 4 indices each
44

45
        tensor_create_fn = torch.IntTensor if test_int32 else torch.LongTensor
46
        input = tensor_create_fn([1, 2, 4, 5, 4, 3, 2, 9])
47
        if per_sample_weights is not None:
48
            per_sample_weights = torch.rand_like(input.float())
49
        if include_last_offset:
50
            offsets = tensor_create_fn([0, 4, 8])
51
        else:
52
            offsets = tensor_create_fn([0, 4])
53
        # aten path
54
        torch.embedding_bag = aten_emb_fn
55
        aten_out = aten_emb(input, offsets, per_sample_weights)
56
        aten_out.sum().backward()
57

58
        # ipex fast path (both fp32/bf16)
59
        torch.embedding_bag = ipex_emb_fn
60
        ipex_out = ipex_emb(input, offsets, per_sample_weights)
61
        ipex_out.sum().backward()
62

63
        self.assertEqual(aten_out, ipex_out)
64
        if sparse:
65
            self.assertEqual(
66
                aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz()
67
            )
68
            self.assertEqual(
69
                aten_emb.weight.grad.data.sparse_dim(),
70
                ipex_emb.weight.grad.data.sparse_dim(),
71
            )
72
            self.assertEqual(
73
                aten_emb.weight.grad.data.dense_dim(),
74
                ipex_emb.weight.grad.data.dense_dim(),
75
            )
76
            self.assertEqual(
77
                aten_emb.weight.grad.data.is_coalesced(),
78
                ipex_emb.weight.grad.data.is_coalesced(),
79
            )
80
            self.assertEqual(
81
                aten_emb.weight.grad.data._indices(),
82
                ipex_emb.weight.grad.data._indices(),
83
            )
84
            self.assertEqual(
85
                aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values()
86
            )
87

88
        if mode == "sum" and padding_idx is None and per_sample_weights is None:
89
            bf16_out = bf16_emb(input, offsets)
90
            bf16_out.sum().backward()
91
            self.assertEqual(aten_out.bfloat16(), bf16_out)
92
            if sparse:
93
                self.assertEqual(
94
                    bf16_emb.weight.grad.data._values().dtype, torch.bfloat16
95
                )
96
                self.assertEqual(
97
                    aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz()
98
                )
99
                self.assertEqual(
100
                    aten_emb.weight.grad.data.sparse_dim(),
101
                    ipex_emb.weight.grad.data.sparse_dim(),
102
                )
103
                self.assertEqual(
104
                    aten_emb.weight.grad.data.dense_dim(),
105
                    ipex_emb.weight.grad.data.dense_dim(),
106
                )
107
                self.assertEqual(
108
                    aten_emb.weight.grad.data.is_coalesced(),
109
                    ipex_emb.weight.grad.data.is_coalesced(),
110
                )
111
                self.assertEqual(
112
                    aten_emb.weight.grad.data._indices(),
113
                    ipex_emb.weight.grad.data._indices(),
114
                )
115
                self.assertEqual(
116
                    aten_emb.weight.grad.data._values().bfloat16().float(),
117
                    ipex_emb.weight.grad.data._values().float(),
118
                )
119

120
    def test_emb_fallback_path(self):
121
        self._test_emb(mode="mean")
122
        for options in itertools.product(
123
            [2, None], [True, None], [True, False], [True, False], [True, False]
124
        ):
125
            (
126
                padding_idx,
127
                per_sample_weights,
128
                include_last_offset,
129
                sparse,
130
                test_int32,
131
            ) = options
132
            self._test_emb(
133
                mode="sum",
134
                per_sample_weights=per_sample_weights,
135
                padding_idx=padding_idx,
136
                include_last_offset=include_last_offset,
137
                sparse=sparse,
138
                test_int32=test_int32,
139
            )
140

141
    def test_emb_fast_path(self):
142
        for options in itertools.product([True, False], [True, False]):
143
            include_last_offset, sparse = options
144
            self._test_emb(
145
                mode="sum", sparse=sparse, include_last_offset=include_last_offset
146
            )
147

148
    def test_emb_jit_scriptable(self):
149
        emb = nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
150
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
151
        offsets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
152
        ref_out = emb(input, offsets)
153
        script_emb = torch.jit.script(emb)
154
        out = script_emb(input, offsets)
155
        self.assertEqual(out, ref_out)
156

157

158
if __name__ == "__main__":
159
    test = unittest.main()
160

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.