intel-extension-for-pytorch
586 lines · 22.4 KB
1import unittest2import torch3import torch.nn as nn4from common_utils import TestCase5import time6import torch.nn.functional as F7import os8
9
def nms(dets, scores, threshold, sorted=False):
    """Thin wrapper around the IPEX fused NMS kernel.

    dets: (N, 4) candidate boxes; scores: (N,) confidences;
    threshold: IoU cutoff passed straight through to the kernel.
    sorted: the tests below pass True only after sorting scores in
    descending order, so it appears to declare the inputs pre-sorted.
    NOTE(review): the parameter name `sorted` shadows the builtin; kept
    as-is because callers may use it as a keyword argument.
    """
    return torch.ops.torch_ipex.nms(dets, scores, threshold, sorted)
13
# Direct handles to the IPEX fused detection ops exercised by these tests.
batch_score_nms = torch.ops.torch_ipex.batch_score_nms
parallel_scale_back_batch = torch.ops.torch_ipex.parallel_scale_back_batch
rpn_nms = torch.ops.torch_ipex.rpn_nms
box_head_nms = torch.ops.torch_ipex.box_head_nms
19
def get_rand_seed():
    """Return a nanosecond-resolution seed derived from the wall clock."""
    now_seconds = time.time()
    return int(now_seconds * 1000000000)
23
# This function is from https://github.com/kuangliu/pytorch-ssd.
def calc_iou_tensor(box1, box2):
    """Compute pairwise IoU between two sets of boxes in ltrb format.

    Reference to https://github.com/kuangliu/pytorch-ssd
    input:
        box1 (N, 4)
        box2 (M, 4)
    output:
        IoU (N, M)
    """
    n_boxes = box1.size(0)
    m_boxes = box2.size(0)
    # Broadcast both sets to (N, M, 4) so every pair gets compared.
    a = box1.unsqueeze(1).expand(-1, m_boxes, -1)
    b = box2.unsqueeze(0).expand(n_boxes, -1, -1)
    # Intersection rectangle: max of left-top corners, min of right-bottom.
    top_left = torch.max(a[:, :, :2], b[:, :, :2])
    bottom_right = torch.min(a[:, :, 2:], b[:, :, 2:])
    # Negative extents mean the pair does not overlap; clamp them to zero.
    wh = (bottom_right - top_left).clamp(min=0)
    intersection = wh[:, :, 0] * wh[:, :, 1]
    wh_a = a[:, :, 2:] - a[:, :, :2]
    wh_b = b[:, :, 2:] - b[:, :, :2]
    area_a = wh_a[:, :, 0] * wh_a[:, :, 1]
    area_b = wh_b[:, :, 0] * wh_b[:, :, 1]
    return intersection / (area_a + area_b - intersection)
56
class TestScaleBackBatch(TestCase):
    """Validate torch_ipex's parallel_scale_back_batch against a Python reference."""

    def scale_back_batch(self, bboxes_in, scores_in, dboxes_xywh, scale_xy, scale_wh):
        """
        Python implementation of Encoder::scale_back_batch, refer to \
https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py
        """
        # Undo the variance scaling, then decode the regressed offsets
        # against the default boxes (all updates are in-place on bboxes_in).
        bboxes_in[:, :, :2] = scale_xy * bboxes_in[:, :, :2]
        bboxes_in[:, :, 2:] = scale_wh * bboxes_in[:, :, 2:]
        bboxes_in[:, :, :2] = (
            bboxes_in[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2]
        )
        bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp() * dboxes_xywh[:, :, 2:]
        # Transform format to ltrb; compute all four corner tensors before
        # overwriting any column of bboxes_in.
        half_w = 0.5 * bboxes_in[:, :, 2]
        half_h = 0.5 * bboxes_in[:, :, 3]
        left = bboxes_in[:, :, 0] - half_w
        top = bboxes_in[:, :, 1] - half_h
        right = bboxes_in[:, :, 0] + half_w
        bottom = bboxes_in[:, :, 1] + half_h
        bboxes_in[:, :, 0] = left
        bboxes_in[:, :, 1] = top
        bboxes_in[:, :, 2] = right
        bboxes_in[:, :, 3] = bottom
        return bboxes_in, F.softmax(scores_in, dim=-1)

    def test_scale_back_batch_result(self):
        """Reference vs. fused kernel on random inputs, plus the autocast
        and float64 code paths."""
        batch_size = 16
        number_boxes = 1024
        scale_xy = 0.1
        scale_wh = 0.2
        # Random fp32 predictions against fp64 default boxes.
        predicted_loc = (
            torch.randn((batch_size, number_boxes, 4)).contiguous().to(torch.float32)
        )
        predicted_score = (
            torch.randn((batch_size, number_boxes, 81)).contiguous().to(torch.float32)
        )
        dboxes_xywh = torch.randn((1, number_boxes, 4)).contiguous().to(torch.float64)

        ref_bbox, ref_score = self.scale_back_batch(
            predicted_loc.clone(),
            predicted_score.clone(),
            dboxes_xywh.clone(),
            scale_xy,
            scale_wh,
        )
        ipex_bbox, ipex_score = parallel_scale_back_batch(
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
        )
        # test autocast
        with torch.cpu.amp.autocast():
            amp_bbox, amp_score = parallel_scale_back_batch(
                predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
            )
        for got_bbox, got_score in ((ipex_bbox, ipex_score), (amp_bbox, amp_score)):
            self.assertTrue(torch.allclose(ref_bbox, got_bbox, rtol=1e-4, atol=1e-4))
            self.assertTrue(torch.allclose(ref_score, got_score, rtol=1e-4, atol=1e-4))

        # test double
        dbl_bbox, dbl_score = parallel_scale_back_batch(
            predicted_loc.clone().double(),
            predicted_score,
            dboxes_xywh,
            scale_xy,
            scale_wh,
        )
        self.assertEqual(dbl_bbox, ipex_bbox)
        self.assertEqual(dbl_score, ipex_score)
        self.assertTrue(dbl_bbox.dtype == torch.float64)
126
127class TestNMS(TestCase):128def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200):129"""130Python implementation of Encoder::decode_single, refer to \
131https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py
132"""
133# perform non-maximum suppression134# Reference to https://github.com/amdegroot/ssd.pytorch135
136bboxes_out = []137scores_out = []138labels_out = []139for i, score in enumerate(scores_in.split(1, 1)):140# skip background141# print(score[score>0.90])142if i == 0:143continue144score = score.squeeze(1)145mask = score > 0.05146bboxes, score = bboxes_in[mask, :], score[mask]147if score.size(0) == 0:148continue149score_sorted, score_idx_sorted = score.sort(dim=0)150# select max_output indices151score_idx_sorted = score_idx_sorted[-max_num:]152candidates = []153while score_idx_sorted.numel() > 0:154idx = score_idx_sorted[-1].item()155bboxes_sorted = bboxes[score_idx_sorted, :]156bboxes_idx = bboxes[idx, :].unsqueeze(dim=0)157iou_sorted = calc_iou_tensor(bboxes_sorted, bboxes_idx).squeeze()158# we only need iou < criteria159score_idx_sorted = score_idx_sorted[iou_sorted < criteria]160candidates.append(idx)161
162bboxes_out.append(bboxes[candidates, :])163scores_out.append(score[candidates])164labels_out.extend([i] * len(candidates))165bboxes_out, labels_out, scores_out = (166torch.cat(bboxes_out, dim=0),167torch.tensor(labels_out, dtype=torch.long),168torch.cat(scores_out, dim=0),169)170_, max_ids = scores_out.sort(dim=0)171max_ids = max_ids[-max_output:]172return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids]173
174def test_batch_nms_result(self):175batch_size = 1176number_boxes = 15130177scale_xy = 0.1178scale_wh = 0.2179criteria = 0.50180max_output = 200181predicted_loc = torch.load(182os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")183) # sizes: [1, 15130, 4]184predicted_score = torch.load(185os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")186) # sizes: [1, 15130, 81]187dboxes_xywh = torch.load(188os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")189)190bboxes, probs = parallel_scale_back_batch(191predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh192)193bboxes_clone = bboxes.clone()194probs_clone = probs.clone()195
196output = []197for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)):198bbox = bbox.squeeze(0)199prob = prob.squeeze(0)200output.append(self.decode_single(bbox, prob, criteria, max_output))201output2_raw = batch_score_nms(bboxes_clone, probs_clone, criteria, max_output)202
203# test autocast204with torch.cpu.amp.autocast():205for datatype in (torch.bfloat16, torch.float32):206bboxes_autocast = bboxes.clone().to(datatype)207probs_autocast = probs.clone().to(datatype)208output2_raw_autocast = batch_score_nms(209bboxes_autocast, probs_autocast, criteria, max_output210)211for i in range(3):212self.assertTrue(output2_raw_autocast[i].dtype == torch.float32)213
214# Re-assembly the result215output2 = []216idx = 0217for i in range(output2_raw[3].size(0)):218output2.append(219(220output2_raw[0][idx : idx + output2_raw[3][i]],221output2_raw[1][idx : idx + output2_raw[3][i]],222output2_raw[2][idx : idx + output2_raw[3][i]],223)224)225idx += output2_raw[3][i]226
227for i in range(batch_size):228loc, label, prob = list(r for r in output[i])229loc2, label2, prob2 = list(r for r in output2[i])230self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))231self.assertEqual(label, label2)232self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))233
234# test double235output2_raw_double = batch_score_nms(236bboxes.clone().double(), probs.clone().double(), criteria, max_output237)238self.assertEqual(output2_raw_double, output2_raw)239self.assertTrue(output2_raw_double[0].dtype == torch.float64)240
241def test_jit_trace_batch_nms(self):242class Batch_NMS(nn.Module):243def __init__(self, criteria, max_output):244super(Batch_NMS, self).__init__()245self.criteria = criteria246self.max_output = max_output247
248def forward(self, bboxes_clone, probs_clone):249return batch_score_nms(250bboxes_clone, probs_clone, self.criteria, self.max_output251)252
253batch_size = 1254number_boxes = 15130255scale_xy = 0.1256scale_wh = 0.2257criteria = 0.50258max_output = 200259predicted_loc = torch.load(260os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")261) # sizes: [1, 15130, 4]262predicted_score = torch.load(263os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")264) # sizes: [1, 15130, 81]265dboxes_xywh = torch.load(266os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")267)268bboxes, probs = parallel_scale_back_batch(269predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh270)271bboxes_clone = bboxes.clone()272probs_clone = probs.clone()273
274output = []275for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)):276bbox = bbox.squeeze(0)277prob = prob.squeeze(0)278output.append(self.decode_single(bbox, prob, criteria, max_output))279
280batch_score_nms_module = Batch_NMS(criteria, max_output)281model_decode = torch.jit.trace(282batch_score_nms_module, (bboxes_clone, probs_clone)283)284output2_raw = model_decode(bboxes_clone, probs_clone)285
286# Re-assembly the result287output2 = []288idx = 0289for i in range(output2_raw[3].size(0)):290output2.append(291(292output2_raw[0][idx : idx + output2_raw[3][i]],293output2_raw[1][idx : idx + output2_raw[3][i]],294output2_raw[2][idx : idx + output2_raw[3][i]],295)296)297idx += output2_raw[3][i]298
299for i in range(batch_size):300loc, label, prob = list(r for r in output[i])301loc2, label2, prob2 = list(r for r in output2[i])302self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))303self.assertEqual(label, label2)304self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))305
306def test_nms_kernel_result(self):307batch_size = 1308class_number = 81309scale_xy = 0.1310scale_wh = 0.2311criteria = 0.50312max_output = 200313predicted_loc = torch.load(314os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")315) # sizes: [1, 15130, 4]316predicted_score = torch.load(317os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")318) # sizes: [1, 15130, 81]319dboxes_xywh = torch.load(320os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")321)322bboxes, probs = parallel_scale_back_batch(323predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh324)325
326for bs in range(batch_size):327loc = bboxes[bs].squeeze(0)328for class_id in range(class_number):329if class_id == 0:330# Skip the background331continue332score = probs[bs, :, class_id]333
334score_sorted, indices = torch.sort(score, descending=True)335loc_sorted = torch.index_select(loc, 0, indices)336
337result = nms(loc_sorted.clone(), score_sorted.clone(), criteria, True)338result_ref = nms(loc.clone(), score.clone(), criteria, False)339result_ref2 = nms(340loc_sorted.clone().to(dtype=torch.float64),341score_sorted.clone().to(dtype=torch.float64),342criteria,343True,344)345
346bbox_keep, _ = torch.sort(347torch.index_select(loc_sorted, 0, result).squeeze(0), 0348)349bbox_keep_ref, _ = torch.sort(350torch.index_select(loc, 0, result_ref).squeeze(0), 0351)352bbox_keep_ref2, _ = torch.sort(353torch.index_select(loc_sorted, 0, result_ref2).squeeze(0), 0354)355
356score_keep, _ = torch.sort(357torch.index_select(score_sorted, 0, result).squeeze(0), 0358)359score_keep_ref, _ = torch.sort(360torch.index_select(score, 0, result_ref).squeeze(0), 0361)362score_keep_ref2, _ = torch.sort(363torch.index_select(score_sorted, 0, result_ref2).squeeze(0), 0364)365
366self.assertEqual(result.size(0), result_ref.size(0))367self.assertTrue(368torch.allclose(bbox_keep, bbox_keep_ref, rtol=1e-4, atol=1e-4)369)370self.assertTrue(371torch.allclose(score_keep, score_keep_ref, rtol=1e-4, atol=1e-4)372)373self.assertTrue(374torch.allclose(bbox_keep, bbox_keep_ref2, rtol=1e-4, atol=1e-4)375)376self.assertTrue(377torch.allclose(score_keep, score_keep_ref2, rtol=1e-4, atol=1e-4)378)379
380# test autocast381with torch.cpu.amp.autocast():382result_autocast = nms(loc.clone(), score.clone(), criteria, False)383self.assertEqual(result_autocast, result_ref)384
385# test double386result_double = nms(387loc.clone().double(), score.clone().double(), criteria, False388)389self.assertEqual(result_double, result_ref)390
391def test_rpn_nms_result(self):392image_shapes = [(800, 824), (800, 1199)]393min_size = 0394nms_thresh = 0.7395post_nms_top_n = 1000396proposals = torch.load(397os.path.join(os.path.dirname(__file__), "data/rpn_nms_proposals.pt")398)399objectness = torch.load(400os.path.join(os.path.dirname(__file__), "data/rpn_nms_objectness.pt")401)402
403new_proposal = []404new_score = []405for proposal, score, im_shape in zip(406proposals.clone(), objectness.clone(), image_shapes407):408proposal[:, 0].clamp_(min=0, max=im_shape[0] - 1)409proposal[:, 1].clamp_(min=0, max=im_shape[1] - 1)410proposal[:, 2].clamp_(min=0, max=im_shape[0] - 1)411proposal[:, 3].clamp_(min=0, max=im_shape[1] - 1)412keep = (413(414(proposal[:, 2] - proposal[:, 0] >= min_size)415& (proposal[:, 3] - proposal[:, 1] >= min_size)416)417.nonzero()418.squeeze(1)419)420proposal = proposal[keep]421score = score[keep]422if nms_thresh > 0:423keep = nms(proposal, score, nms_thresh)424if post_nms_top_n > 0:425keep = keep[:post_nms_top_n]426new_proposal.append(proposal[keep])427new_score.append(score[keep])428
429new_proposal_, new_score_ = rpn_nms(430proposals, objectness, image_shapes, min_size, nms_thresh, post_nms_top_n431)432
433self.assertEqual(new_proposal, new_proposal_)434self.assertEqual(new_score, new_score_)435
436# test autocast437with torch.cpu.amp.autocast():438for datatype in (torch.bfloat16, torch.float32):439proposals_autocast = proposals.clone().to(datatype)440objectness_autocast = objectness.clone().to(datatype)441new_proposal_autocast, new_score_autocast = rpn_nms(442proposals_autocast,443objectness_autocast,444image_shapes,445min_size,446nms_thresh,447post_nms_top_n,448)449self.assertTrue(new_proposal_autocast[0].dtype == torch.float32)450self.assertTrue(new_score_autocast[0].dtype == torch.float32)451
452# test double453new_proposal_double, new_score_double = rpn_nms(454proposals.clone().double(),455objectness.clone().double(),456image_shapes,457min_size,458nms_thresh,459post_nms_top_n,460)461self.assertEqual(new_proposal_double, new_proposal)462self.assertEqual(new_score_double, new_score)463self.assertTrue(new_proposal_double[0].dtype == torch.float64)464self.assertTrue(new_score_double[0].dtype == torch.float64)465
466def test_box_head_nms_result(self):467image_shapes = [(800, 824), (800, 1199)]468score_thresh = 0.05469nms_ = 0.5470detections_per_img = 100471num_classes = 81472proposals = torch.load(473os.path.join(os.path.dirname(__file__), "data/box_head_nms_proposals.pt")474)475class_prob = torch.load(476os.path.join(os.path.dirname(__file__), "data/box_head_nms_class_prob.pt")477)478
479boxes_out = []480scores_out = []481labels_out = []482for scores, boxes, image_shape in zip(class_prob, proposals, image_shapes):483boxes = boxes.reshape(-1, 4)484boxes[:, 0].clamp_(min=0, max=image_shape[0] - 1)485boxes[:, 1].clamp_(min=0, max=image_shape[1] - 1)486boxes[:, 2].clamp_(min=0, max=image_shape[0] - 1)487boxes[:, 3].clamp_(min=0, max=image_shape[1] - 1)488boxes = boxes.reshape(-1, num_classes * 4)489scores = scores.reshape(-1, num_classes)490
491inds_all = scores > score_thresh492new_boxes = []493new_scores = []494new_labels = []495for j in range(1, num_classes):496inds = inds_all[:, j].nonzero().squeeze(1)497scores_j = scores[inds, j]498boxes_j = boxes[inds, j * 4 : (j + 1) * 4]499if nms_ > 0:500keep = nms(boxes_j, scores_j, nms_)501new_boxes.append(boxes_j[keep])502new_scores.append(scores_j[keep])503new_labels.append(torch.full((len(keep),), j, dtype=torch.int64))504
505new_boxes, new_scores, new_labels = (506torch.cat(new_boxes, dim=0),507torch.cat(new_scores, dim=0),508torch.cat(new_labels, dim=0),509)510number_of_detections = new_boxes.size(0)511if number_of_detections > detections_per_img > 0:512image_thresh, _ = torch.kthvalue(513new_scores, number_of_detections - detections_per_img + 1514)515keep = new_scores >= image_thresh.item()516keep = torch.nonzero(keep).squeeze(1)517boxes_out.append(new_boxes[keep])518scores_out.append(new_scores[keep])519labels_out.append(new_labels[keep])520else:521boxes_out.append(new_boxes)522scores_out.append(new_scores)523labels_out.append(new_labels)524
525boxes_out_, scores_out_, labels_out_ = box_head_nms(526proposals,527class_prob,528image_shapes,529score_thresh,530nms_,531detections_per_img,532num_classes,533)534
535self.assertEqual(boxes_out, boxes_out_)536self.assertEqual(scores_out, scores_out_)537self.assertEqual(labels_out, labels_out_)538
539# test autocast540with torch.cpu.amp.autocast():541for datatype in (torch.bfloat16, torch.float32):542proposals_autocast = (543proposals[0].to(datatype),544proposals[1].to(datatype),545)546class_prob_autocast = (547class_prob[0].to(datatype),548class_prob[1].to(datatype),549)550(551boxes_out_autocast,552scores_out_autocast,553labels_out_autocast,554) = box_head_nms(555proposals_autocast,556class_prob_autocast,557image_shapes,558score_thresh,559nms_,560detections_per_img,561num_classes,562)563self.assertTrue(boxes_out_autocast[0].dtype == torch.float32)564self.assertTrue(scores_out_autocast[0].dtype == torch.float32)565
566# test double567proposals_double = (proposals[0].double(), proposals[1].double())568class_prob_double = (class_prob[0].double(), class_prob[1].double())569boxes_out_double, scores_out_double, labels_out_double = box_head_nms(570proposals_double,571class_prob_double,572image_shapes,573score_thresh,574nms_,575detections_per_img,576num_classes,577)578self.assertEqual(boxes_out_double, boxes_out)579self.assertEqual(scores_out_double, scores_out)580self.assertEqual(labels_out_double, labels_out)581self.assertTrue(boxes_out_double[0].dtype == torch.float64)582self.assertTrue(scores_out_double[0].dtype == torch.float64)583
584
if __name__ == "__main__":
    # unittest.main() runs all Test* classes in this module.
    test = unittest.main()