# intel-extension-for-pytorch
import torch
import torch.nn as nn
from common_utils import TestCase
import unittest
from typing import Tuple
import intel_extension_for_pytorch as ipex


class MaskedMHA(torch.nn.Module):
    def __init__(self, hidden_size=4096, n_head=16, n_head_kv=16, head_dim=256):
        super().__init__()
        self.num_heads = n_head
        self.num_kv = n_head_kv
        self.head_dim = head_dim
        self.query_key_value = nn.Linear(
            hidden_size, (n_head_kv * 2 + n_head) * head_dim
        )

    def _split_heads(
        self, fused_qkv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim); the results share
        the same memory storage as `fused_qkv`.

        Args:
            fused_qkv (`torch.Tensor`, *required*):
                [batch_size, seq_length, (num_heads + num_kv * 2) * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim]
            key: [batch_size, seq_length, num_kv, head_dim]
            value: [batch_size, seq_length, num_kv, head_dim]
        """
        bs = fused_qkv.shape[0]
        query_layer = fused_qkv[:, :, : self.num_heads * self.head_dim]
        query_layer = query_layer.view(bs, -1, self.num_heads, self.head_dim)
        key_layer = fused_qkv[
            :,
            :,
            self.num_heads
            * self.head_dim : (self.num_heads + self.num_kv)
            * self.head_dim,
        ]
        key_layer = key_layer.view(bs, -1, self.num_kv, self.head_dim)
        value_layer = fused_qkv[:, :, (self.num_heads + self.num_kv) * self.head_dim :]
        value_layer = value_layer.view(bs, -1, self.num_kv, self.head_dim)
        return query_layer, key_layer, value_layer
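
    # Example shapes for illustration: with num_heads=16, num_kv=4, and
    # head_dim=256, a fused tensor of width (16 + 2 * 4) * 256 = 6144 splits
    # into query [bs, seq, 16, 256] and key/value [bs, seq, 4, 256] views.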

    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
        bs, slen, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )
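
    # Note: _repeat_kv is the grouped-query attention broadcast. Each KV head
    # is shared by n_rep = num_heads // num_kv query heads, so with 16 query
    # heads and 4 KV heads a [bs, seq, 4, 256] tensor expands to
    # [bs, seq, 16, 256]; the expand is a view and only the final reshape copies.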

    def forward(
        self,
        input_t,
        key_cache,
        value_cache,
        max_position,
        attention_mask,
        beam_idx,
        indirect_access_kv_cache=False,
        offset=0,
        enable_linear=True,
    ):
        head_size = self.head_dim
        origin_type = input_t.dtype
        if enable_linear:
            query, key, value = self._split_heads(self.query_key_value(input_t))
        else:
            query, key, value = self._split_heads(input_t)
        if indirect_access_kv_cache:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            return torch.ops.torch_ipex.masked_multihead_self_attention(
                query,
                key,
                value,
                key_cache,
                value_cache,
                beam_idx,
                offset,
                head_size**0.5,
                max_position,
                None,
                attention_mask,
            )
        else:
            # get the concatenated key and value
            if key_cache is not None:
                key = torch.cat([key_cache, key], dim=1)
                value = torch.cat([value_cache, value], dim=1)
            key_cache = key
            value_cache = value
            n_rep = self.num_heads // self.num_kv
            key = self._repeat_kv(key, n_rep)
            value = self._repeat_kv(value, n_rep)

            key = key.transpose(1, 2)
            query = query.transpose(1, 2)
            value = value.transpose(1, 2)
            if origin_type == torch.half:
                key = key.to(torch.float32)
                query = query.to(torch.float32)
                value = value.to(torch.float32)
            # matmul query and key to get the attention scores
            attention_scores = torch.matmul(query, key.transpose(-1, -2))
            # scale the attention scores
            attention_scores = attention_scores / (head_size**0.5)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            # softmax the attention scores
            attention_probs = attention_scores.softmax(dim=-1)
            # matmul the attention probabilities and value to get the context
            attention_output = torch.matmul(attention_probs, value)
            if origin_type == torch.half:
                attention_output = attention_output.to(origin_type)
            return attention_output, None, key_cache, value_cache, None
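

# A minimal, runnable sketch (added for illustration; not part of the original
# test suite). It pushes one prefill step through the naive reference path of
# MaskedMHA above; the helper name is hypothetical and the shapes follow the
# conventions used in this file.
def _demo_naive_prefill():
    mha = MaskedMHA(hidden_size=16 * 256, n_head=16, n_head_kv=16, head_dim=256)
    x = torch.randn(1, 8, 16 * 256)  # [batch, seq, hidden]
    with torch.no_grad():
        # naive path: indirect_access_kv_cache defaults to False
        out, _, k_cache, v_cache, _ = mha(x, None, None, 64, None, None)
    # out: [1, 16, 8, 256]; k_cache / v_cache: [1, 8, 16, 256]
    return out, k_cache, v_cache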


class MaskedMHATest(TestCase):
    def _test_mha(self, torchcompile=False):
        beam_size_list = [1, 4]
        batch_size_list = [1, 2, 4]
        head_size = 256
        head_num = 16
        head_num_kv_list = [1, 4, 16]
        max_seq_len = 64
        first_seq_len = 32
        for batch_size in batch_size_list:
            for beam_size in beam_size_list:
                for head_num_kv in head_num_kv_list:
                    key_cache = None
                    value_cache = None
                    offset = 0
                    mha = MaskedMHA(
                        n_head=head_num, n_head_kv=head_num_kv, head_dim=head_size
                    )

                    if torchcompile:
                        torch._dynamo.reset()
                        ipex._set_compiler_backend("inductor")
                        mha = torch.compile(mha, backend="ipex")

                    # first token decode
                    input_t = torch.randn(
                        batch_size,
                        first_seq_len,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    key_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    value_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    beam_idx = torch.zeros(
                        max_seq_len, beam_size * batch_size, dtype=torch.int64
                    )
                    # create attention mask and causal mask
                    attention_mask = torch.zeros(
                        batch_size, 1, first_seq_len, first_seq_len, dtype=torch.float32
                    )
                    causal_mask = torch.full(
                        (first_seq_len, first_seq_len), -1e6, dtype=input_t.dtype
                    )
                    causal_mask = causal_mask.triu(1)
                    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                    attention_mask = (
                        attention_mask + causal_mask
                    )  # combine the attention mask and causal mask
                    # UT for first token with fp32
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t, None, None, max_seq_len, attention_mask, None, None
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        # self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        key_cache = key_cache.repeat_interleave(beam_size, dim=0)
                        value_cache = value_cache.repeat_interleave(beam_size, dim=0)
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv[0:first_seq_len, i * beam_size, :, :],
                            )
                            self.assertEqual(
                                value_cache.transpose(0, 1)[:, i * beam_size, :, :],
                                value_cache_iakv[0:first_seq_len, i * beam_size, :, :],
                            )
                        if beam_size == 4:
                            beam_idx_t = torch.zeros(
                                beam_size * batch_size, dtype=torch.int64
                            )
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        # reorder cache for the naive implementation
                        key_cache = torch.index_select(key_cache, 0, beam_idx_t)
                        value_cache = torch.index_select(value_cache, 0, beam_idx_t)

                    # UT for first token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    key_cache_iakv_bf16 = key_cache_iakv.bfloat16()
                    value_cache_iakv_bf16 = value_cache_iakv.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            None,
                            None,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=2e-2,
                        )
                        key_cache_bf16 = key_cache_bf16.repeat_interleave(
                            beam_size, dim=0
                        )
                        value_cache_bf16 = value_cache_bf16.repeat_interleave(
                            beam_size, dim=0
                        )
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache_bf16.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv_bf16[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                            self.assertEqual(
                                value_cache_bf16.transpose(0, 1)[
                                    :, i * beam_size, :, :
                                ],
                                value_cache_iakv_bf16[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                        key_cache_bf16 = torch.index_select(
                            key_cache_bf16, 0, beam_idx_t
                        )
                        value_cache_bf16 = torch.index_select(
                            value_cache_bf16, 0, beam_idx_t
                        )

                    offset = offset + first_seq_len
                    # UT for next token with fp32
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t,
                            key_cache,
                            value_cache,
                            max_seq_len,
                            attention_mask,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        self.assertEqual(
                            key_cache.transpose(0, 1)[offset],
                            key_cache_iakv[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache.transpose(0, 1)[offset],
                            value_cache_iakv[offset, :, :, :],
                        )
                    # UT for next token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            key_cache_bf16,
                            value_cache_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_bf16.transpose(0, 1)[offset],
                            key_cache_iakv_bf16[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_bf16.transpose(0, 1)[offset],
                            value_cache_iakv_bf16[offset, :, :, :],
                        )
                        if beam_size == 4:
                            beam_idx_t = torch.tensor([1, 3, 0, 0]).repeat(batch_size)
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        offset = offset + 1
                        # reorder cache for the naive implementation
                        key_cache = torch.index_select(key_cache, 0, beam_idx_t)
                        value_cache = torch.index_select(value_cache, 0, beam_idx_t)
                        key_cache_bf16 = torch.index_select(
                            key_cache_bf16, 0, beam_idx_t
                        )
                        value_cache_bf16 = torch.index_select(
                            value_cache_bf16, 0, beam_idx_t
                        )

                    # UT for the second next token with fp32
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t,
                            key_cache,
                            value_cache,
                            max_seq_len,
                            attention_mask,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        self.assertEqual(
                            key_cache.transpose(0, 1)[offset],
                            key_cache_iakv[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache.transpose(0, 1)[offset],
                            value_cache_iakv[offset, :, :, :],
                        )
                    # UT for the second next token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            key_cache_bf16,
                            value_cache_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_bf16.transpose(0, 1)[offset],
                            key_cache_iakv_bf16[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_bf16.transpose(0, 1)[offset],
                            value_cache_iakv_bf16[offset, :, :, :],
                        )
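
    # Layout note (inferred from the checks above; added for readability): the
    # indirect-access caches are [max_seq_len, beam_size * batch_size, head_num,
    # head_size] with the sequence dimension leading, while the naive path keeps
    # [batch, seq, head, head_size]; hence the transpose(0, 1) before every
    # comparison. beam_idx records, per decoded position, which cache slot each
    # beam should follow on the next step.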

    def _test_mha_fp16(self, torchcompile=False):
        beam_size_list = [1, 4]
        batch_size_list = [1, 2, 4]
        head_size = 256
        head_num = 16
        head_num_kv_list = [1, 4, 16]
        max_seq_len = 64
        first_seq_len = 32
        for batch_size in batch_size_list:
            for beam_size in beam_size_list:
                for head_num_kv in head_num_kv_list:
                    offset = 0
                    mha = MaskedMHA(
                        n_head=head_num, n_head_kv=head_num_kv, head_dim=head_size
                    )

                    if torchcompile:
                        torch._dynamo.reset()
                        ipex._set_compiler_backend("inductor")
                        mha = torch.compile(mha, backend="ipex")

                    # first token decode
                    input_t = torch.randn(
                        batch_size,
                        first_seq_len,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    key_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    value_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    beam_idx = torch.zeros(
                        max_seq_len, beam_size * batch_size, dtype=torch.int64
                    )
                    # create attention mask and causal mask
                    attention_mask = torch.zeros(
                        batch_size, 1, first_seq_len, first_seq_len, dtype=torch.float32
                    )
                    causal_mask = torch.full(
                        (first_seq_len, first_seq_len), -1e6, dtype=input_t.dtype
                    )
                    causal_mask = causal_mask.triu(1)
                    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                    attention_mask = (
                        attention_mask + causal_mask
                    )  # combine the attention mask and causal mask
                    if beam_size == 4:
                        beam_idx_t = torch.zeros(
                            beam_size * batch_size, dtype=torch.int64
                        )
                        for i in range(1, batch_size):
                            beam_idx_t[i * beam_size : i * beam_size + beam_size] = (
                                beam_idx_t[i * beam_size : i * beam_size + beam_size]
                                + i * beam_size
                            )
                    elif beam_size == 1:
                        beam_idx_t = torch.arange(batch_size)
                    beam_idx[offset] = beam_idx_t
                    # UT for first token with fp16
                    input_t_half = input_t.half()
                    key_cache_iakv_half = key_cache_iakv.half()
                    value_cache_iakv_half = value_cache_iakv.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            None,
                            None,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=2e-2,
                        )
                        key_cache_half = key_cache_half.repeat_interleave(
                            beam_size, dim=0
                        )
                        value_cache_half = value_cache_half.repeat_interleave(
                            beam_size, dim=0
                        )
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache_half.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv_half[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                            self.assertEqual(
                                value_cache_half.transpose(0, 1)[
                                    :, i * beam_size, :, :
                                ],
                                value_cache_iakv_half[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                        key_cache_half = torch.index_select(
                            key_cache_half, 0, beam_idx_t
                        )
                        value_cache_half = torch.index_select(
                            value_cache_half, 0, beam_idx_t
                        )

                    offset = offset + first_seq_len
                    # UT for next token with fp16
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    input_t_half = input_t.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            key_cache_half,
                            value_cache_half,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_half.transpose(0, 1)[offset],
                            key_cache_iakv_half[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_half.transpose(0, 1)[offset],
                            value_cache_iakv_half[offset, :, :, :],
                        )
                        if beam_size == 4:
                            beam_idx_t = torch.tensor([1, 3, 0, 0]).repeat(batch_size)
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        offset = offset + 1
                        key_cache_half = torch.index_select(
                            key_cache_half, 0, beam_idx_t
                        )
                        value_cache_half = torch.index_select(
                            value_cache_half, 0, beam_idx_t
                        )

                    # UT for the second next token with fp16
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    input_t_half = input_t.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            key_cache_half,
                            value_cache_half,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_half.transpose(0, 1)[offset],
                            key_cache_iakv_half[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_half.transpose(0, 1)[offset],
                            value_cache_iakv_half[offset, :, :, :],
                        )
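
    # Unlike _test_mha, the fp16 variant feeds pre-fused QKV activations of
    # width (head_num + 2 * head_num_kv) * head_size and passes
    # enable_linear=False, so the nn.Linear projection inside MaskedMHA is
    # bypassed entirely.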

    def test_mha(self):
        self._test_mha(torchcompile=False)
        self._test_mha_fp16(torchcompile=False)


if __name__ == "__main__":
    test = unittest.main()