intel-extension-for-pytorch
159 lines · 5.6 KB
1import torch
2import torch.nn as nn
3import unittest
4import itertools
5import copy
6from torch.testing._internal.common_utils import TestCase
7import intel_extension_for_pytorch as ipex
8
# Handles to the two embedding_bag implementations shipped with IPEX.
# The tests below monkey-patch `torch.embedding_bag` with one of these to
# force either the ATen fallback path or the IPEX fast path.
ipex_emb_fn = ipex.nn.functional._embeddingbag._embeddingbag
aten_emb_fn = ipex.nn.functional._embeddingbag.torch_embedding_bag
11
12
class Embeddingbag(torch.nn.Module):
    """Thin module wrapping a sparse, sum-mode ``nn.EmbeddingBag(10, 3)``.

    Used by the JIT-scriptability style of tests: it exposes the bag through
    a standard ``forward(input, offsets)`` signature.
    """

    def __init__(self):
        super().__init__()
        self.embeddingbag = nn.EmbeddingBag(10, 3, mode="sum", sparse=True)

    def forward(self, input, offsets):
        # Delegate directly to the wrapped EmbeddingBag.
        pooled = self.embeddingbag(input, offsets)
        return pooled
20
21
class TestEMB(TestCase):
    def _test_emb(
        self,
        mode,
        per_sample_weights=None,
        padding_idx=None,
        include_last_offset=False,
        sparse=True,
        test_int32=False,
    ):
        """Compare the ATen and IPEX embedding_bag paths under one configuration.

        Builds three copies of the same EmbeddingBag(10, 33): one run through
        the ATen implementation, one through the IPEX implementation, and one
        cast to bfloat16. Forward outputs and (for sparse=True) the sparse
        gradient internals are asserted equal between the first two; the bf16
        copy is additionally checked in the sum/no-padding/no-weights case.

        Args:
            mode: EmbeddingBag pooling mode ("sum", "mean", ...).
            per_sample_weights: truthy flag; when not None, random weights of
                the input's shape are generated below and passed through.
            padding_idx: forwarded to nn.EmbeddingBag.
            include_last_offset: forwarded to nn.EmbeddingBag; also switches
                the offsets tensor to include a trailing end offset.
            sparse: use sparse gradients; enables the sparse-grad assertions.
            test_int32: build indices/offsets as int32 instead of int64.
        """
        aten_emb = nn.EmbeddingBag(
            10,
            33,
            mode=mode,
            sparse=sparse,
            padding_idx=padding_idx,
            include_last_offset=include_last_offset,
        )
        # Round-trip the fp32 weights through bfloat16 so they are exactly
        # representable in bf16 — keeps the later bf16 comparisons lossless.
        aten_emb = aten_emb.bfloat16().float()
        ipex_emb = copy.deepcopy(aten_emb)
        bf16_emb = copy.deepcopy(aten_emb).bfloat16()
        # a batch of 2 samples of 4 indices each

        tensor_create_fn = torch.IntTensor if test_int32 else torch.LongTensor
        input = tensor_create_fn([1, 2, 4, 5, 4, 3, 2, 9])
        if per_sample_weights is not None:
            # Caller passes only a truthy flag; the actual weights are
            # generated here with the same length as `input`.
            per_sample_weights = torch.rand_like(input.float())
        if include_last_offset:
            offsets = tensor_create_fn([0, 4, 8])
        else:
            offsets = tensor_create_fn([0, 4])
        # aten path: globally monkey-patch torch.embedding_bag.
        # NOTE(review): the patch is never restored, so it leaks across tests;
        # harmless here only because every path re-patches before use.
        torch.embedding_bag = aten_emb_fn
        aten_out = aten_emb(input, offsets, per_sample_weights)
        aten_out.sum().backward()

        # ipex fast path (both fp32/bf16)
        torch.embedding_bag = ipex_emb_fn
        ipex_out = ipex_emb(input, offsets, per_sample_weights)
        ipex_out.sum().backward()

        self.assertEqual(aten_out, ipex_out)
        if sparse:
            # Sparse gradients: compare the COO internals field by field.
            self.assertEqual(
                aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz()
            )
            self.assertEqual(
                aten_emb.weight.grad.data.sparse_dim(),
                ipex_emb.weight.grad.data.sparse_dim(),
            )
            self.assertEqual(
                aten_emb.weight.grad.data.dense_dim(),
                ipex_emb.weight.grad.data.dense_dim(),
            )
            self.assertEqual(
                aten_emb.weight.grad.data.is_coalesced(),
                ipex_emb.weight.grad.data.is_coalesced(),
            )
            self.assertEqual(
                aten_emb.weight.grad.data._indices(),
                ipex_emb.weight.grad.data._indices(),
            )
            self.assertEqual(
                aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values()
            )

        # bf16 check only for the plain sum configuration.
        if mode == "sum" and padding_idx is None and per_sample_weights is None:
            bf16_out = bf16_emb(input, offsets)
            bf16_out.sum().backward()
            self.assertEqual(aten_out.bfloat16(), bf16_out)
            if sparse:
                self.assertEqual(
                    bf16_emb.weight.grad.data._values().dtype, torch.bfloat16
                )
                # NOTE(review): the assertions below re-compare aten_emb vs
                # ipex_emb — identical to the block above — rather than
                # involving bf16_emb. This looks like a copy-paste slip;
                # confirm whether bf16_emb's sparse grad was intended here.
                self.assertEqual(
                    aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz()
                )
                self.assertEqual(
                    aten_emb.weight.grad.data.sparse_dim(),
                    ipex_emb.weight.grad.data.sparse_dim(),
                )
                self.assertEqual(
                    aten_emb.weight.grad.data.dense_dim(),
                    ipex_emb.weight.grad.data.dense_dim(),
                )
                self.assertEqual(
                    aten_emb.weight.grad.data.is_coalesced(),
                    ipex_emb.weight.grad.data.is_coalesced(),
                )
                self.assertEqual(
                    aten_emb.weight.grad.data._indices(),
                    ipex_emb.weight.grad.data._indices(),
                )
                self.assertEqual(
                    aten_emb.weight.grad.data._values().bfloat16().float(),
                    ipex_emb.weight.grad.data._values().float(),
                )

    def test_emb_fallback_path(self):
        """Sweep configurations that force the ATen fallback implementation."""
        self._test_emb(mode="mean")
        for options in itertools.product(
            [2, None], [True, None], [True, False], [True, False], [True, False]
        ):
            (
                padding_idx,
                per_sample_weights,
                include_last_offset,
                sparse,
                test_int32,
            ) = options
            self._test_emb(
                mode="sum",
                per_sample_weights=per_sample_weights,
                padding_idx=padding_idx,
                include_last_offset=include_last_offset,
                sparse=sparse,
                test_int32=test_int32,
            )

    def test_emb_fast_path(self):
        """Sweep the sum-mode configurations eligible for the IPEX fast path."""
        for options in itertools.product([True, False], [True, False]):
            include_last_offset, sparse = options
            self._test_emb(
                mode="sum", sparse=sparse, include_last_offset=include_last_offset
            )

    def test_emb_jit_scriptable(self):
        """EmbeddingBag must remain torch.jit.script-able with matching output."""
        emb = nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
        ref_out = emb(input, offsets)
        script_emb = torch.jit.script(emb)
        out = script_emb(input, offsets)
        self.assertEqual(out, ref_out)
156
157
if __name__ == "__main__":
    # Run all TestEMB cases; the TestProgram result is bound but unused.
    test = unittest.main()
160